Compare commits

..

1 Commits

Author SHA1 Message Date
Johan B.W. de Vries
d1854d7a38 Cleanup to type5 solver
- Replaces weird CheckResult class with proper data classes.
- When a constraint results in new constraints, those are
  treated first
- Minor readability change to function return type
- Code cleanup
- TODO cleanup
- Typo fixes
2025-08-30 14:45:12 +02:00
15 changed files with 114 additions and 279 deletions

View File

@ -24,7 +24,4 @@
- Read https://bytecodealliance.org/articles/multi-value-all-the-wasm
- Implement type class 'inheritance'
- Remove FunctionInstance, replace with a substitutions dict
- See phft2 in fromast.py:expression_function_call
- Move unify into the typeconstraints (or other way around) - it's done on two levels now (partly in solver)
- Rework type classes - already started on a separate dir for those, but quite a few things are still in other places.

View File

@ -45,7 +45,6 @@ class BuildBase[G]:
__slots__ = (
'dynamic_array_type5_constructor',
'function_type5_constructor',
'io_type5_constructor',
'static_array_type5_constructor',
'tuple_type5_constructor_map',
@ -84,18 +83,6 @@ class BuildBase[G]:
See type5_make_function and type5_is_function.
"""
io_type5_constructor: type5typeexpr.TypeConstructor
"""
Constructor for IO.
An IO function is a function that can have side effects.
It can do input or output. Other functions cannot have
side effects, and can only return a value based on the
input given.
See type5_make_io and type5_is_io.
"""
static_array_type5_constructor: type5typeexpr.TypeConstructor
"""
Constructor for arrays of compiled time determined length.
@ -220,7 +207,6 @@ class BuildBase[G]:
self.dynamic_array_type5_constructor = type5typeexpr.TypeConstructor(kind=S >> S, name="dynamic_array")
self.function_type5_constructor = type5typeexpr.TypeConstructor(kind=S >> (S >> S), name="function")
self.io_type5_constructor = type5typeexpr.TypeConstructor(kind=S >> S, name="IO")
self.static_array_type5_constructor = type5typeexpr.TypeConstructor(kind=N >> (S >> S), name='static_array')
self.tuple_type5_constructor_map = {}
@ -358,19 +344,6 @@ class BuildBase[G]:
return my_args + more_args
def type5_make_io(self, arg: type5typeexpr.TypeExpr) -> type5typeexpr.TypeApplication:
return type5typeexpr.TypeApplication(
constructor=self.io_type5_constructor,
argument=arg
)
def type5_is_io(self, typeexpr: type5typeexpr.TypeExpr | type5constrainedexpr.ConstrainedExpr) -> type5typeexpr.TypeExpr | None:
if not isinstance(typeexpr, type5typeexpr.TypeApplication):
return None
if typeexpr.constructor != self.io_type5_constructor:
return None
return typeexpr.argument
def type5_make_tuple(self, args: Sequence[type5typeexpr.TypeExpr]) -> type5typeexpr.TypeApplication:
if not args:
raise TypeError("Tuples must at least one field")

View File

@ -22,7 +22,6 @@ from .typeclasses import (
fractional,
integral,
intnum,
monad,
natnum,
ord,
promotable,
@ -69,7 +68,6 @@ class BuildDefault(BuildBase[Generator]):
integral,
foldable, subscriptable,
sized,
monad,
]
for tc in tc_list:

View File

@ -1,26 +0,0 @@
"""
The Monad type class is defined for type constructors that cause one thing to happen /after/ another.
"""
from __future__ import annotations
from typing import Any
from ...type5.constrainedexpr import ConstrainedExpr
from ...type5.kindexpr import Star
from ...type5.typeexpr import TypeVariable
from ...typeclass import TypeClass, TypeClassConstraint
from ...wasmgenerator import Generator as WasmGenerator
from ..base import BuildBase
def load(build: BuildBase[Any]) -> None:
a = TypeVariable(kind=Star(), name='a')
Monad = TypeClass('Monad', (a, ), methods={}, operators={})
build.register_type_class(Monad)
def wasm(build: BuildBase[WasmGenerator]) -> None:
Monad = build.type_classes['Monad']
build.instance_type_class(Monad, build.io_type5_constructor)

View File

@ -36,10 +36,6 @@ class BuildTypeRouter[T](TypeRouter[T]):
if fn_args is not None:
return self.when_function(fn_args)
io_args = self.build.type5_is_io(typ)
if io_args is not None:
return self.when_io(io_args)
sa_args = self.build.type5_is_static_array(typ)
if sa_args is not None:
sa_len, sa_typ = sa_args
@ -63,9 +59,6 @@ class BuildTypeRouter[T](TypeRouter[T]):
def when_function(self, fn_args: list[TypeExpr]) -> T:
raise NotImplementedError
def when_io(self, io_arg: TypeExpr) -> T:
raise NotImplementedError
def when_struct(self, typ: Record) -> T:
raise NotImplementedError
@ -101,9 +94,6 @@ class TypeName(BuildTypeRouter[str]):
def when_function(self, fn_args: list[TypeExpr]) -> str:
return 'Callable[' + ', '.join(map(self, fn_args)) + ']'
def when_io(self, io_arg: TypeExpr) -> str:
return 'IO[' + self(io_arg) + ']'
def when_static_array(self, sa_len: int, sa_typ: TypeExpr) -> str:
return f'{self(sa_typ)}[{sa_len}]'

View File

@ -124,10 +124,6 @@ def statement(inp: ourlang.Statement) -> Statements:
yield ''
return
if isinstance(inp, ourlang.StatementCall):
yield expression(inp.call)
return
raise NotImplementedError(statement, inp)
def function(mod: ourlang.Module[Any], inp: ourlang.Function) -> str:

View File

@ -43,11 +43,6 @@ def type5(mod: ourlang.Module[WasmGenerator], inp: TypeExpr) -> wasm.WasmType:
Types are used for example in WebAssembly function parameters
and return types.
"""
io_arg = mod.build.type5_is_io(inp)
if io_arg is not None:
# IO is type constructor that only exists on the typing layer
inp = io_arg
typ_info = mod.build.type_info_map.get(inp.name)
if typ_info is None:
typ_info = mod.build.type_info_constructed
@ -409,9 +404,6 @@ def statement_if(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ou
# for stat in inp.else_statements:
# statement(wgn, stat)
def statement_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ourlang.Function, inp: ourlang.StatementCall) -> None:
expression(wgn, mod, inp.call)
def statement(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ourlang.Function, inp: ourlang.Statement) -> None:
"""
Compile: any statement
@ -424,10 +416,6 @@ def statement(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ourla
statement_if(wgn, mod, fun, inp)
return
if isinstance(inp, ourlang.StatementCall):
statement_call(wgn, mod, fun, inp)
return
if isinstance(inp, ourlang.StatementPass):
return

View File

@ -282,24 +282,6 @@ class StatementReturn(Statement):
def __repr__(self) -> str:
return f'StatementReturn({repr(self.value)})'
class StatementCall(Statement):
"""
A function call within a function.
Executing is deferred to the given function until it completes.
"""
__slots__ = ('call')
call: FunctionCall
def __init__(self, call: FunctionCall, sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.call = call
def __repr__(self) -> str:
return f'StatementCall({repr(self.call)})'
class StatementIf(Statement):
"""
An if statement within a function

View File

@ -25,7 +25,6 @@ from .ourlang import (
ModuleDataBlock,
SourceRef,
Statement,
StatementCall,
StatementIf,
StatementPass,
StatementReturn,
@ -362,9 +361,6 @@ class OurVisitor[G]:
return result
if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
return StatementCall(self.visit_Module_FunctionDef_Call(module, function, our_locals, node.value), srf(module, node))
if isinstance(node, ast.Pass):
return StatementPass(srf(module, node))
@ -488,7 +484,7 @@ class OurVisitor[G]:
raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.Call) -> FunctionCall:
def visit_Module_FunctionDef_Call(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -651,15 +647,6 @@ class OurVisitor[G]:
for e in func_arg_types
])
if isinstance(node.value, ast.Name) and node.value.id == 'IO':
assert isinstance(node.slice, ast.Name) or (isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 0)
return module.build.type5_make_io(
self.visit_type5(module, node.slice)
)
# TODO: This u32[...] business is messing up the other type constructors
if isinstance(node.slice, ast.Slice):
_raise_static_error(node, 'Must subscript using an index')
@ -685,9 +672,6 @@ class OurVisitor[G]:
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not node.elts:
return module.build.unit_type5
return module.build.type5_make_tuple(
[self.visit_type5(module, elt) for elt in node.elts],
)

View File

@ -10,7 +10,7 @@ from .typeexpr import TypeExpr, TypeVariable, instantiate
class TypeConstraint:
"""
Base class for type contraints
Base class for type constraints
"""
__slots__ = ()

View File

@ -75,6 +75,33 @@ class Context:
assert tvar not in self.ptst_update
self.ptst_update[tvar] = (arg, orig_var)
class Success:
"""
The contsraint checks out.
Nothing new was learned and nothing new needs to be checked.
"""
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return '(ok)'
class SkipForNow:
"""
Not enough information to resolve this constraint
"""
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return '(skip for now)'
class ConstraintList(list['ConstraintBase']):
"""
A new list of constraints.
Sometimes, checking one constraint means you get a list of new contraints.
"""
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return f'(got {len(self)} new constraints)'
@dataclasses.dataclass
class Failure:
"""
@ -82,52 +109,38 @@ class Failure:
"""
msg: str
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return f'ERR: {self.msg}'
@dataclasses.dataclass
class ReplaceVariable:
"""
A variable should be replaced.
Either by another variable or by a (concrete) type.
"""
var: TypeVariable
typ: TypeExpr
@dataclasses.dataclass
class CheckResult:
# TODO: Refactor this, don't think we use most of the variants
_: dataclasses.KW_ONLY
done: bool = True
replace: ReplaceVariable | None = None
new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
failures: list[Failure] = dataclasses.field(default_factory=list)
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
if not self.done and not self.replace and not self.new_constraints and not self.failures:
return '(skip for now)'
return f'{{{self.var.name} := {type_namer(self.typ)}}}'
if self.done and not self.replace and not self.new_constraints and not self.failures:
return '(ok)'
if self.done and self.replace and not self.new_constraints and not self.failures:
return f'{{{self.replace.var.name} := {type_namer(self.replace.typ)}}}'
if self.done and not self.replace and self.new_constraints and not self.failures:
return f'(got {len(self.new_constraints)} new constraints)'
if self.done and not self.replace and not self.new_constraints and self.failures:
return 'ERR: ' + '; '.join(x.msg for x in self.failures)
return f'{self.done} {self.replace} {self.new_constraints} {self.failures}'
def skip_for_now() -> CheckResult:
return CheckResult(done=False)
def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult:
return CheckResult(replace=ReplaceVariable(var, typ))
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
return CheckResult(new_constraints=list(lst))
CheckResult: TypeAlias = Success | SkipForNow | ConstraintList | Failure | ReplaceVariable
def ok() -> CheckResult:
return CheckResult(done=True)
return Success()
def skip_for_now() -> CheckResult:
return SkipForNow()
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
return ConstraintList(lst)
def fail(msg: str) -> CheckResult:
return CheckResult(failures=[Failure(msg)])
return Failure(msg)
def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult:
return ReplaceVariable(var, typ)
class ConstraintBase:
__slots__ = ("ctx", "sourceref", )
@ -399,7 +412,7 @@ class CanBeSubscriptedConstraint(ConstraintBase):
def check(self) -> CheckResult:
if not is_concrete(self.container_type5):
return CheckResult(done=False)
return skip_for_now()
tp_args = self.ctx.build.type5_is_tuple(self.container_type5)
if tp_args is not None:
@ -462,7 +475,7 @@ class CanAccessStructMemberConstraint(ConstraintBase):
def check(self) -> CheckResult:
if not is_concrete(self.struct_type5):
return CheckResult(done=False)
return skip_for_now()
st_args = self.ctx.build.type5_is_struct(self.struct_type5)
if st_args is None:
@ -514,11 +527,11 @@ class FromTupleConstraint(ConstraintBase):
def check(self) -> CheckResult:
if not is_concrete(self.ret_type5):
return CheckResult(done=False)
return skip_for_now()
da_arg = self.ctx.build.type5_is_dynamic_array(self.ret_type5)
if da_arg is not None:
return CheckResult(new_constraints=[
return new_constraints([
UnifyTypesConstraint(self.ctx, self.sourceref, da_arg, x)
for x in self.member_type5_list
])
@ -529,7 +542,7 @@ class FromTupleConstraint(ConstraintBase):
if sa_len != len(self.member_type5_list):
return fail('Tuple element count mismatch')
return CheckResult(new_constraints=[
return new_constraints([
UnifyTypesConstraint(self.ctx, self.sourceref, sa_typ, x)
for x in self.member_type5_list
])
@ -539,7 +552,7 @@ class FromTupleConstraint(ConstraintBase):
if len(tp_args) != len(self.member_type5_list):
return fail('Tuple element count mismatch')
return CheckResult(new_constraints=[
return new_constraints([
UnifyTypesConstraint(self.ctx, self.sourceref, act_typ, exp_typ)
for act_typ, exp_typ in zip(tp_args, self.member_type5_list, strict=True)
])

View File

@ -15,7 +15,7 @@ from .constraints import (
TypeClassInstanceExistsConstraint,
UnifyTypesConstraint,
)
from .kindexpr import KindExpr, Star
from .kindexpr import KindExpr
from .typeexpr import TypeApplication, TypeExpr, TypeVariable, is_concrete
ConstraintGenerator = Generator[ConstraintBase, None, None]
@ -233,16 +233,15 @@ def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement
if fun.type5 is None:
raise NotImplementedError("Deducing function type - you'll have to annotate it.")
args = ctx.build.type5_is_function(fun.type5)
assert args is not None
type5 = args[-1]
# This is a hack to allow return statement in pure and non pure functions
if (io_arg := ctx.build.type5_is_io(type5)):
type5 = io_arg
if isinstance(fun.type5, TypeApplication):
args = ctx.build.type5_is_function(fun.type5)
assert args is not None
type5 = args[-1]
else:
type5 = fun.type5.expr if isinstance(fun.type5, ConstrainedExpr) else fun.type5
yield from expression(ctx, inp.value, phft)
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft, prefix=f'{fun.name} returns')
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft, prefix=f'{fun.name}(...)')
def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf) -> ConstraintGenerator:
test_phft = ctx.make_placeholder(inp.test)
@ -257,27 +256,6 @@ def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf)
for stmt in inp.else_statements:
yield from statement(ctx, fun, stmt)
def statement_call(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementCall) -> ConstraintGenerator:
if fun.type5 is None:
raise NotImplementedError("Deducing function type - you'll have to annotate it.")
fn_args = ctx.build.type5_is_function(fun.type5)
assert fn_args is not None
fn_ret = fn_args[-1]
call_phft = ctx.make_placeholder(inp.call)
S = Star()
t_phft = ctx.make_placeholder(kind=S >> S)
a_phft = ctx.make_placeholder(kind=S)
yield from expression_function_call(ctx, inp.call, call_phft)
yield TypeClassInstanceExistsConstraint(ctx, inp.sourceref, 'Monad', [t_phft])
yield UnifyTypesConstraint(ctx, inp.sourceref, TypeApplication(constructor=t_phft, argument=a_phft), fn_ret)
yield UnifyTypesConstraint(ctx, inp.sourceref, TypeApplication(constructor=t_phft, argument=ctx.build.unit_type5), call_phft)
def statement(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement) -> ConstraintGenerator:
if isinstance(inp, ourlang.StatementReturn):
yield from statement_return(ctx, fun, inp)
@ -287,18 +265,12 @@ def statement(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement) -> Co
yield from statement_if(ctx, fun, inp)
return
if isinstance(inp, ourlang.StatementCall):
yield from statement_call(ctx, fun, inp)
return
raise NotImplementedError(inp)
def function(ctx: Context, inp: ourlang.Function) -> ConstraintGenerator:
for stmt in inp.statements:
yield from statement(ctx, inp, stmt)
# TODO: If function is imported or exported, it should be an IO[..] function
def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> ConstraintGenerator:
phft = ctx.make_placeholder(inp.constant)

View File

@ -1,7 +1,7 @@
from typing import Any
from ..ourlang import Module
from .constraints import ConstraintBase, Context
from .constraints import ConstraintBase, ConstraintList, Context, Failure, ReplaceVariable, SkipForNow, Success
from .fromast import phasm_type5_generate_constraints
from .typeexpr import TypeExpr, TypeVariable, is_concrete, replace_variable
@ -30,58 +30,68 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
print("Validating")
new_constraint_list: list[ConstraintBase] = []
for constraint in sorted(constraint_list, key=lambda x: x.complexity()):
# Iterate using a while and pop since on ReplaceVariable
# we want to iterate over the list as well, and since on
# ConstraintList we want to treat those first.
remaining_constraint_list = sorted(constraint_list, key=lambda x: x.complexity())
while remaining_constraint_list:
constraint = remaining_constraint_list.pop(0)
result = constraint.check()
if verbose:
print(f"{constraint.sourceref!s} {constraint!s}")
print(f"{constraint.sourceref!s} => {result.to_str(inp.build.type5_name)}")
if not result:
# None or empty list
# Means it checks out and we don't need do anything
continue
match result:
case Success():
# This constraint was valid
continue
case SkipForNow():
# We have to check later
new_constraint_list.append(constraint)
continue
case ConstraintList(items):
# This constraint was valid, but we have new once
# Do this as the first next items, so when users are reading the
# solver output they don't need to context switch.
remaining_constraint_list = items + remaining_constraint_list
if result.replace is not None:
action_var = result.replace.var
assert action_var not in placeholder_types # When does this happen?
if verbose:
for new_const in items:
print(f"{constraint.sourceref!s} => + {new_const!s}")
action_typ: TypeExpr = result.replace.typ
assert not isinstance(action_typ, TypeVariable) or action_typ not in placeholder_types # When does this happen?
continue
case Failure(msg):
error_list.append((str(constraint.sourceref), str(constraint), msg, ))
continue
case ReplaceVariable(action_var, action_typ):
assert action_var not in placeholder_types # When does this happen?
assert not isinstance(action_typ, TypeVariable) or action_typ not in placeholder_types # When does this happen?
assert action_var != action_typ # When does this happen?
assert action_var != action_typ # When does this happen?
# Ensure all existing found types are updated
# if they have this variable somewhere inside them.
placeholder_types = {
k: replace_variable(v, action_var, action_typ)
for k, v in placeholder_types.items()
}
# Ensure all existing found types are updated
placeholder_types = {
k: replace_variable(v, action_var, action_typ)
for k, v in placeholder_types.items()
}
placeholder_types[action_var] = action_typ
# Add the new variable to the registry
placeholder_types[action_var] = action_typ
for oth_const in new_constraint_list + constraint_list:
if oth_const is constraint and result.done:
continue
# Also update all constraints that may refer to this variable
# that they now have more detailed information.
for oth_const in new_constraint_list + remaining_constraint_list:
old_str = str(oth_const)
oth_const.replace_variable(action_var, action_typ)
new_str = str(oth_const)
old_str = str(oth_const)
oth_const.replace_variable(action_var, action_typ)
new_str = str(oth_const)
if verbose and old_str != new_str:
print(f"{oth_const.sourceref!s} => - {old_str!s}")
print(f"{oth_const.sourceref!s} => + {new_str!s}")
if verbose and old_str != new_str:
print(f"{oth_const.sourceref!s} => - {old_str!s}")
print(f"{oth_const.sourceref!s} => + {new_str!s}")
for failure in result.failures:
error_list.append((str(constraint.sourceref), str(constraint), failure.msg, ))
new_constraint_list.extend(result.new_constraints)
if verbose:
for new_const in result.new_constraints:
print(f"{oth_const.sourceref!s} => + {new_const!s}")
if result.done:
continue
new_constraint_list.append(constraint)
continue
if error_list:
raise Type5SolverException(error_list)

View File

@ -139,10 +139,6 @@ class Extractor(BuildTypeRouter[ExtractorFunc]):
return DynamicArrayExtractor(self.access, self(da_arg))
def when_io(self, io_arg: TypeExpr) -> ExtractorFunc:
# IO is a type only annotation, it is not related to allocation
return self(io_arg)
def when_static_array(self, sa_len: int, sa_typ: TypeExpr) -> ExtractorFunc:
return StaticArrayExtractor(self.access, sa_len, self(sa_typ))

View File

@ -1,38 +0,0 @@
import pytest
from ..helpers import Suite
@pytest.mark.integration_test
def test_io_use_type_class():
code_py = f"""
@exported
def testEntry() -> IO[u32]:
return 4
"""
result = Suite(code_py).run_code()
assert 4 == result.returned_value
@pytest.mark.integration_test
def test_io_call_io_function():
code_py = f"""
@imported
def log(val: u32) -> IO[()]:
pass
@exported
def testEntry() -> IO[u32]:
log(123)
return 4
"""
log_history: list[Any] = []
def my_log(val: int) -> None:
log_history.append(val)
result = Suite(code_py).run_code(imports={
'log': my_log,
})
assert 4 == result.returned_value
assert [123] == log_history