diff --git a/phasm/build/typerouter.py b/phasm/build/typerouter.py index 44cfbf5..19e946a 100644 --- a/phasm/build/typerouter.py +++ b/phasm/build/typerouter.py @@ -8,6 +8,7 @@ from ..type5.typeexpr import ( TypeApplication, TypeConstructor, TypeExpr, + TypeLevelNat, TypeVariable, ) from ..type5.typerouter import TypeRouter @@ -102,6 +103,9 @@ class TypeName(BuildTypeRouter[str]): def when_tuple(self, tp_args: list[TypeExpr]) -> str: return '(' + ', '.join(map(self, tp_args)) + ', )' + def when_type_level_nat(self, typ: TypeLevelNat) -> str: + return str(typ.value) + def when_variable(self, typ: TypeVariable) -> str: return typ.name diff --git a/phasm/type5/constraints.py b/phasm/type5/constraints.py index 657b049..5a76756 100644 --- a/phasm/type5/constraints.py +++ b/phasm/type5/constraints.py @@ -9,13 +9,15 @@ from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt6 from .kindexpr import KindExpr, Star from .record import Record from .typeexpr import ( + AtomicType, TypeApplication, + TypeConstructor, TypeExpr, TypeVariable, is_concrete, + occurs, replace_variable, ) -from .unify import Action, ActionList, Failure, ReplaceVariable, unify class ExpressionProtocol(Protocol): @@ -43,35 +45,51 @@ class Context: self.placeholder_update[res] = arg return res +@dataclasses.dataclass +class Failure: + """ + Both types are already different - cannot be unified. + """ + msg: str + +@dataclasses.dataclass +class ReplaceVariable: + var: TypeVariable + typ: TypeExpr + @dataclasses.dataclass class CheckResult: + # TODO: Refactor this, don't think we use most of the variants _: dataclasses.KW_ONLY done: bool = True - actions: ActionList = dataclasses.field(default_factory=ActionList) + replace: ReplaceVariable | None = None new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list) failures: list[Failure] = dataclasses.field(default_factory=list) def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str: - if not self.done and not self.actions and not self.new_constraints and not self.failures: + if not self.done and not self.replace and not self.new_constraints and not self.failures: return '(skip for now)' - if self.done and not self.actions and not self.new_constraints and not self.failures: + if self.done and not self.replace and not self.new_constraints and not self.failures: return '(ok)' - if self.done and self.actions and not self.new_constraints and not self.failures: - return self.actions.to_str(type_namer) + if self.done and self.replace and not self.new_constraints and not self.failures: + return f'{{{self.replace.var.name} := {type_namer(self.replace.typ)}}}' - if self.done and not self.actions and self.new_constraints and not self.failures: + if self.done and not self.replace and self.new_constraints and not self.failures: return f'(got {len(self.new_constraints)} new constraints)' - if self.done and not self.actions and not self.new_constraints and self.failures: + if self.done and not self.replace and not self.new_constraints and self.failures: return 'ERR: ' + '; '.join(x.msg for x in self.failures) - return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}' + return f'{self.done} {self.replace} {self.new_constraints} {self.failures}' def skip_for_now() -> CheckResult: return CheckResult(done=False) +def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult: + return CheckResult(replace=ReplaceVariable(var, typ)) + def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult: return CheckResult(new_constraints=list(lst)) @@ -94,12 +112,8 @@ class ConstraintBase: def check(self) -> CheckResult: raise NotImplementedError(self) - def apply(self, action: Action) -> None: - if isinstance(action, ReplaceVariable): - self.replace_variable(action.var, action.typ) - return - - raise NotImplementedError(action) + def apply(self, action: ReplaceVariable) -> None: + self.replace_variable(action.var, action.typ) def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None: pass @@ -216,12 +230,82 @@ class UnifyTypesConstraint(ConstraintBase): self.rgt = rgt def check(self) -> CheckResult: - result = unify(self.lft, self.rgt) + lft = self.lft + rgt = self.rgt - if isinstance(result, Failure): - return CheckResult(failures=[result]) + if lft == self.rgt: + return ok() - return CheckResult(actions=result) + if lft.kind != rgt.kind: + return fail("Kind mismatch") + + + + if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType): + return fail("Not the same type") + + if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable): + return replace(rgt, lft) + + if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor): + raise NotImplementedError # Should have been caught by kind check above + + if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication): + return fail("Not the same type" if is_concrete(rgt) else "Type shape mismatch") + + + + if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType): + return replace(lft, rgt) + + if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable): + return replace(lft, rgt) + + if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor): + return replace(lft, rgt) + + if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication): + if occurs(lft, rgt): + return fail("One type occurs in the other") + + return replace(lft, rgt) + + + + if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType): + raise NotImplementedError # Should have been caught by kind check above + + if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable): + return replace(rgt, lft) + + if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor): + return fail("Not the same type constructor") + + if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication): + return fail("Not the same type constructor") + + + + if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType): + return fail("Not the same type" if is_concrete(lft) else "Type shape mismatch") + + if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable): + if occurs(rgt, lft): + return fail("One type occurs in the other") + + return replace(rgt, lft) + + if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor): + return fail("Not the same type constructor") + + if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication): + return new_constraints([ + UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor, rgt.constructor), + UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument), + ]) + + + raise NotImplementedError(lft, rgt) def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None: self.lft = replace_variable(self.lft, var, typ) diff --git a/phasm/type5/solver.py b/phasm/type5/solver.py index e4a4c1b..14d92d8 100644 --- a/phasm/type5/solver.py +++ b/phasm/type5/solver.py @@ -1,10 +1,9 @@ from typing import Any from ..ourlang import Module -from .constraints import ConstraintBase, Context +from .constraints import ConstraintBase, Context, ReplaceVariable from .fromast import phasm_type5_generate_constraints from .typeexpr import TypeExpr, TypeVariable, replace_variable -from .unify import ReplaceVariable MAX_RESTACK_COUNT = 100 @@ -44,56 +43,33 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None: # Means it checks out and we don't need do anything continue - while result.actions: - action = result.actions.pop(0) + if result.replace is not None: + action_var = result.replace.var + assert action_var not in placeholder_types # When does this happen? - if isinstance(action, ReplaceVariable): - action_var: TypeExpr = action.var - while isinstance(action_var, TypeVariable) and action_var in placeholder_types: - # TODO: Does this still happen? - action_var = placeholder_types[action_var] + action_typ: TypeExpr = result.replace.typ + assert not isinstance(action_typ, TypeVariable) or action_typ not in placeholder_types # When does this happen? - action_typ: TypeExpr = action.typ - while isinstance(action_typ, TypeVariable) and action_typ in placeholder_types: - # TODO: Does this still happen? - action_typ = placeholder_types[action_typ] + assert action_var != action_typ # When does this happen? - # print(inp.build.type5_name(action_var), ':=', inp.build.type5_name(action_typ)) + # Ensure all existing found types are updated + placeholder_types = { + k: replace_variable(v, action_var, action_typ) + for k, v in placeholder_types.items() + } + placeholder_types[action_var] = action_typ - if action_var == action_typ: + for oth_const in new_constraint_list + constraint_list: + if oth_const is constraint and result.done: continue - if not isinstance(action_var, TypeVariable) and isinstance(action_typ, TypeVariable): - action_typ, action_var = action_var, action_typ + old_str = str(oth_const) + oth_const.replace_variable(action_var, action_typ) + new_str = str(oth_const) - if isinstance(action_var, TypeVariable): - # Ensure all existing found types are updated - placeholder_types = { - k: replace_variable(v, action_var, action_typ) - for k, v in placeholder_types.items() - } - placeholder_types[action_var] = action_typ - - for oth_const in new_constraint_list + constraint_list: - if oth_const is constraint and result.done: - continue - - old_str = str(oth_const) - oth_const.replace_variable(action_var, action_typ) - new_str = str(oth_const) - - if verbose and old_str != new_str: - print(f"{oth_const.sourceref!s} => - {old_str!s}") - print(f"{oth_const.sourceref!s} => + {new_str!s}") - continue - - error_list.append((str(constraint.sourceref), str(constraint), "Not the same type", )) - if verbose: - print(f"{constraint.sourceref!s} => ERR: Conflict in applying {action.to_str(inp.build.type5_name)}") - continue - - # Action of unsupported type - raise NotImplementedError(action) + if verbose and old_str != new_str: + print(f"{oth_const.sourceref!s} => - {old_str!s}") + print(f"{oth_const.sourceref!s} => + {new_str!s}") for failure in result.failures: error_list.append((str(constraint.sourceref), str(constraint), failure.msg, )) diff --git a/phasm/type5/typerouter.py b/phasm/type5/typerouter.py index 1609b36..11a5f5b 100644 --- a/phasm/type5/typerouter.py +++ b/phasm/type5/typerouter.py @@ -3,6 +3,7 @@ from .typeexpr import ( AtomicType, TypeApplication, TypeConstructor, + TypeLevelNat, TypeExpr, TypeVariable, ) @@ -21,6 +22,9 @@ class TypeRouter[T]: def when_record(self, typ: Record) -> T: raise NotImplementedError(typ) + def when_type_level_nat(self, typ: TypeLevelNat) -> T: + raise NotImplementedError(typ) + def when_variable(self, typ: TypeVariable) -> T: raise NotImplementedError(typ) @@ -37,6 +41,9 @@ class TypeRouter[T]: if isinstance(typ, TypeConstructor): return self.when_constructor(typ) + if isinstance(typ, TypeLevelNat): + return self.when_type_level_nat(typ) + if isinstance(typ, TypeVariable): return self.when_variable(typ) diff --git a/phasm/type5/unify.py b/phasm/type5/unify.py deleted file mode 100644 index a8530e4..0000000 --- a/phasm/type5/unify.py +++ /dev/null @@ -1,128 +0,0 @@ -from dataclasses import dataclass -from typing import Callable - -from .typeexpr import ( - AtomicType, - TypeApplication, - TypeConstructor, - TypeExpr, - TypeVariable, - is_concrete, - occurs, -) - - -@dataclass -class Failure: - """ - Both types are already different - cannot be unified. - """ - msg: str - -@dataclass -class Action: - def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str: - raise NotImplementedError - -class ActionList(list[Action]): - def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str: - return '{' + ', '.join((x.to_str(type_namer) for x in self)) + '}' - -UnifyResult = Failure | ActionList - -@dataclass -class ReplaceVariable(Action): - var: TypeVariable - typ: TypeExpr - - def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str: - return f'{self.var.name} := {type_namer(self.typ)}' - -def unify(lft: TypeExpr, rgt: TypeExpr) -> UnifyResult: - """ - Be warned: This only matches type variables with other variables or types - - it does not apply substituions nor does it validate if the matching - pairs are correct. - - TODO: Remove this. It should be part of UnifyTypesConstraint - and should just generate new constraints for applications. - """ - if lft == rgt: - return ActionList() - - if lft.kind != rgt.kind: - return Failure("Kind mismatch") - - - - if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType): - return Failure("Not the same type") - - if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable): - return ActionList([ReplaceVariable(rgt, lft)]) - - if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor): - raise NotImplementedError # Should have been caught by kind check above - - if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication): - if is_concrete(rgt): - return Failure("Not the same type") - - return Failure("Type shape mismatch") - - - - if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType): - return unify(rgt, lft) - - if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable): - return ActionList([ReplaceVariable(lft, rgt)]) - - if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor): - return ActionList([ReplaceVariable(lft, rgt)]) - - if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication): - if occurs(lft, rgt): - return Failure("One type occurs in the other") - - return ActionList([ReplaceVariable(lft, rgt)]) - - - - if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType): - return unify(rgt, lft) - - if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable): - return unify(rgt, lft) - - if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor): - return Failure("Not the same type constructor") - - if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication): - return Failure("Not the same type constructor") - - - - if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType): - return unify(rgt, lft) - - if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable): - return unify(rgt, lft) - - if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor): - return unify(rgt, lft) - - if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication): - con_res = unify(lft.constructor, rgt.constructor) - if isinstance(con_res, Failure): - return con_res - - arg_res = unify(lft.argument, rgt.argument) - if isinstance(arg_res, Failure): - return arg_res - - return ActionList(con_res + arg_res) - - - - return Failure('Not implemented') diff --git a/tests/integration/test_lang/generator.md b/tests/integration/test_lang/generator.md index d70f653..a74e83b 100644 --- a/tests/integration/test_lang/generator.md +++ b/tests/integration/test_lang/generator.md @@ -305,8 +305,5 @@ def testEntry() -> i32: ``` ```py -if TYPE_NAME.startswith('tuple_') or TYPE_NAME.startswith('static_array_') or TYPE_NAME.startswith('dynamic_array_'): - expect_type_error('Not the same type constructor') -else: - expect_type_error('Not the same type') +expect_type_error('Not the same type') ``` diff --git a/tests/integration/test_lang/test_second_order_functions.py b/tests/integration/test_lang/test_second_order_functions.py index a598ce4..8e721fb 100644 --- a/tests/integration/test_lang/test_second_order_functions.py +++ b/tests/integration/test_lang/test_second_order_functions.py @@ -91,8 +91,7 @@ def testEntry() -> i32: return action(double, 13.0) """ - match = r'Callable\[i32, i32\] ~ Callable\[f32, [^]]+\]' - with pytest.raises(Type5SolverException, match=match): + with pytest.raises(Type5SolverException, match='f32 ~ i32'): Suite(code_py).run_code() @pytest.mark.integration_test @@ -109,8 +108,7 @@ def testEntry() -> i32: return action(double, 13) """ - match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[f32, i32\], p_[0-9]+, [^]]+\]' - with pytest.raises(Type5SolverException, match=match): + with pytest.raises(Type5SolverException, match='f32 ~ i32'): Suite(code_py).run_code() @pytest.mark.integration_test @@ -144,8 +142,7 @@ def testEntry() -> i32: return action(double, 13) """ - match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[i32, f32\], p_[0-9]+, [^]]+\]' - with pytest.raises(Type5SolverException, match=match): + with pytest.raises(Type5SolverException, match='f32 ~ i32'): Suite(code_py).run_code() @pytest.mark.integration_test @@ -179,12 +176,12 @@ def testEntry() -> i32: return action(double, 13, 14) """ - match = r'Callable\[Callable\[i32, i32, i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]' + match = r'i32 ~ Callable\[i32, i32\]' with pytest.raises(Type5SolverException, match=match): Suite(code_py).run_code() @pytest.mark.integration_test -def test_sof_too_many_args_use(): +def test_sof_too_many_args_use_0(): code_py = """ def thirteen() -> i32: return 13 @@ -197,12 +194,30 @@ def testEntry() -> i32: return action(thirteen, 13) """ - match = r'Callable\[i32\] ~ Callable\[i32, p_[0-9]+\]' + match = r'i32 ~ \(\)' with pytest.raises(Type5SolverException, match=match): Suite(code_py).run_code(verbose=True) @pytest.mark.integration_test -def test_sof_too_many_args_pass(): +def test_sof_too_many_args_use_1(): + code_py = """ +def thirteen(x: i32) -> i32: + return x + +def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32: + return applicable(left, right) + +@exported +def testEntry() -> i32: + return action(thirteen, 13, 26) +""" + + match = r'i32 ~ Callable\[i32, i32\]' + with pytest.raises(Type5SolverException, match=match): + Suite(code_py).run_code(verbose=True) + +@pytest.mark.integration_test +def test_sof_too_many_args_pass_0(): code_py = """ def double(left: i32) -> i32: return left * 2 @@ -215,6 +230,24 @@ def testEntry() -> i32: return action(double, 13, 14) """ - match = r'Callable\[Callable\[i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]' + match = r'i32 ~ \(\)' + with pytest.raises(Type5SolverException, match=match): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_sof_too_many_args_pass_1(): + code_py = """ +def double(left: i32, right: i32) -> i32: + return left * right + +def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32: + return applicable(left) + +@exported +def testEntry() -> i32: + return action(double, 13, 14) +""" + + match = r'i32 ~ Callable\[i32, i32\]' with pytest.raises(Type5SolverException, match=match): Suite(code_py).run_code() diff --git a/tests/integration/test_typeclasses/test_foldable.py b/tests/integration/test_typeclasses/test_foldable.py index d516ae5..476c09a 100644 --- a/tests/integration/test_typeclasses/test_foldable.py +++ b/tests/integration/test_typeclasses/test_foldable.py @@ -168,12 +168,9 @@ def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32: return foldl(x, y, z) """ - match = { - 'i8': 'Type shape mismatch', - 'i8[3]': 'Kind mismatch', - } + match = 'Type shape mismatch' - with pytest.raises(Type5SolverException, match=match[in_typ]): + with pytest.raises(Type5SolverException, match=match): Suite(code_py).run_code() @pytest.mark.integration_test