Removes the weird second step unify

It is now part of the normal constraints
This commit is contained in:
Johan B.W. de Vries 2025-08-23 16:22:51 +02:00
parent 3d504e3d79
commit 439ed97636
8 changed files with 182 additions and 212 deletions

View File

@ -8,6 +8,7 @@ from ..type5.typeexpr import (
TypeApplication, TypeApplication,
TypeConstructor, TypeConstructor,
TypeExpr, TypeExpr,
TypeLevelNat,
TypeVariable, TypeVariable,
) )
from ..type5.typerouter import TypeRouter from ..type5.typerouter import TypeRouter
@ -102,6 +103,9 @@ class TypeName(BuildTypeRouter[str]):
def when_tuple(self, tp_args: list[TypeExpr]) -> str: def when_tuple(self, tp_args: list[TypeExpr]) -> str:
return '(' + ', '.join(map(self, tp_args)) + ', )' return '(' + ', '.join(map(self, tp_args)) + ', )'
def when_type_level_nat(self, typ: TypeLevelNat) -> str:
return str(typ.value)
def when_variable(self, typ: TypeVariable) -> str: def when_variable(self, typ: TypeVariable) -> str:
return typ.name return typ.name

View File

@ -9,13 +9,15 @@ from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt6
from .kindexpr import KindExpr, Star from .kindexpr import KindExpr, Star
from .record import Record from .record import Record
from .typeexpr import ( from .typeexpr import (
AtomicType,
TypeApplication, TypeApplication,
TypeConstructor,
TypeExpr, TypeExpr,
TypeVariable, TypeVariable,
is_concrete, is_concrete,
occurs,
replace_variable, replace_variable,
) )
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
class ExpressionProtocol(Protocol): class ExpressionProtocol(Protocol):
@ -43,35 +45,51 @@ class Context:
self.placeholder_update[res] = arg self.placeholder_update[res] = arg
return res return res
@dataclasses.dataclass
class Failure:
"""
Both types are already different - cannot be unified.
"""
msg: str
@dataclasses.dataclass
class ReplaceVariable:
var: TypeVariable
typ: TypeExpr
@dataclasses.dataclass @dataclasses.dataclass
class CheckResult: class CheckResult:
# TODO: Refactor this, don't think we use most of the variants
_: dataclasses.KW_ONLY _: dataclasses.KW_ONLY
done: bool = True done: bool = True
actions: ActionList = dataclasses.field(default_factory=ActionList) replace: ReplaceVariable | None = None
new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list) new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
failures: list[Failure] = dataclasses.field(default_factory=list) failures: list[Failure] = dataclasses.field(default_factory=list)
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str: def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
if not self.done and not self.actions and not self.new_constraints and not self.failures: if not self.done and not self.replace and not self.new_constraints and not self.failures:
return '(skip for now)' return '(skip for now)'
if self.done and not self.actions and not self.new_constraints and not self.failures: if self.done and not self.replace and not self.new_constraints and not self.failures:
return '(ok)' return '(ok)'
if self.done and self.actions and not self.new_constraints and not self.failures: if self.done and self.replace and not self.new_constraints and not self.failures:
return self.actions.to_str(type_namer) return f'{{{self.replace.var.name} := {type_namer(self.replace.typ)}}}'
if self.done and not self.actions and self.new_constraints and not self.failures: if self.done and not self.replace and self.new_constraints and not self.failures:
return f'(got {len(self.new_constraints)} new constraints)' return f'(got {len(self.new_constraints)} new constraints)'
if self.done and not self.actions and not self.new_constraints and self.failures: if self.done and not self.replace and not self.new_constraints and self.failures:
return 'ERR: ' + '; '.join(x.msg for x in self.failures) return 'ERR: ' + '; '.join(x.msg for x in self.failures)
return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}' return f'{self.done} {self.replace} {self.new_constraints} {self.failures}'
def skip_for_now() -> CheckResult: def skip_for_now() -> CheckResult:
return CheckResult(done=False) return CheckResult(done=False)
def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult:
return CheckResult(replace=ReplaceVariable(var, typ))
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult: def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
return CheckResult(new_constraints=list(lst)) return CheckResult(new_constraints=list(lst))
@ -94,12 +112,8 @@ class ConstraintBase:
def check(self) -> CheckResult: def check(self) -> CheckResult:
raise NotImplementedError(self) raise NotImplementedError(self)
def apply(self, action: Action) -> None: def apply(self, action: ReplaceVariable) -> None:
if isinstance(action, ReplaceVariable):
self.replace_variable(action.var, action.typ) self.replace_variable(action.var, action.typ)
return
raise NotImplementedError(action)
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None: def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
pass pass
@ -216,12 +230,82 @@ class UnifyTypesConstraint(ConstraintBase):
self.rgt = rgt self.rgt = rgt
def check(self) -> CheckResult: def check(self) -> CheckResult:
result = unify(self.lft, self.rgt) lft = self.lft
rgt = self.rgt
if isinstance(result, Failure): if lft == self.rgt:
return CheckResult(failures=[result]) return ok()
return CheckResult(actions=result) if lft.kind != rgt.kind:
return fail("Kind mismatch")
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
return fail("Not the same type")
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
return replace(rgt, lft)
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
raise NotImplementedError # Should have been caught by kind check above
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
return fail("Not the same type" if is_concrete(rgt) else "Type shape mismatch")
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
return replace(lft, rgt)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
return replace(lft, rgt)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
return replace(lft, rgt)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
if occurs(lft, rgt):
return fail("One type occurs in the other")
return replace(lft, rgt)
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
raise NotImplementedError # Should have been caught by kind check above
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
return replace(rgt, lft)
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
return fail("Not the same type constructor")
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
return fail("Not the same type constructor")
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
return fail("Not the same type" if is_concrete(lft) else "Type shape mismatch")
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
if occurs(rgt, lft):
return fail("One type occurs in the other")
return replace(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
return fail("Not the same type constructor")
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
return new_constraints([
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor, rgt.constructor),
UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument),
])
raise NotImplementedError(lft, rgt)
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None: def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
self.lft = replace_variable(self.lft, var, typ) self.lft = replace_variable(self.lft, var, typ)

View File

@ -1,10 +1,9 @@
from typing import Any from typing import Any
from ..ourlang import Module from ..ourlang import Module
from .constraints import ConstraintBase, Context from .constraints import ConstraintBase, Context, ReplaceVariable
from .fromast import phasm_type5_generate_constraints from .fromast import phasm_type5_generate_constraints
from .typeexpr import TypeExpr, TypeVariable, replace_variable from .typeexpr import TypeExpr, TypeVariable, replace_variable
from .unify import ReplaceVariable
MAX_RESTACK_COUNT = 100 MAX_RESTACK_COUNT = 100
@ -44,29 +43,15 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
# Means it checks out and we don't need do anything # Means it checks out and we don't need do anything
continue continue
while result.actions: if result.replace is not None:
action = result.actions.pop(0) action_var = result.replace.var
assert action_var not in placeholder_types # When does this happen?
if isinstance(action, ReplaceVariable): action_typ: TypeExpr = result.replace.typ
action_var: TypeExpr = action.var assert not isinstance(action_typ, TypeVariable) or action_typ not in placeholder_types # When does this happen?
while isinstance(action_var, TypeVariable) and action_var in placeholder_types:
# TODO: Does this still happen?
action_var = placeholder_types[action_var]
action_typ: TypeExpr = action.typ assert action_var != action_typ # When does this happen?
while isinstance(action_typ, TypeVariable) and action_typ in placeholder_types:
# TODO: Does this still happen?
action_typ = placeholder_types[action_typ]
# print(inp.build.type5_name(action_var), ':=', inp.build.type5_name(action_typ))
if action_var == action_typ:
continue
if not isinstance(action_var, TypeVariable) and isinstance(action_typ, TypeVariable):
action_typ, action_var = action_var, action_typ
if isinstance(action_var, TypeVariable):
# Ensure all existing found types are updated # Ensure all existing found types are updated
placeholder_types = { placeholder_types = {
k: replace_variable(v, action_var, action_typ) k: replace_variable(v, action_var, action_typ)
@ -85,15 +70,6 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
if verbose and old_str != new_str: if verbose and old_str != new_str:
print(f"{oth_const.sourceref!s} => - {old_str!s}") print(f"{oth_const.sourceref!s} => - {old_str!s}")
print(f"{oth_const.sourceref!s} => + {new_str!s}") print(f"{oth_const.sourceref!s} => + {new_str!s}")
continue
error_list.append((str(constraint.sourceref), str(constraint), "Not the same type", ))
if verbose:
print(f"{constraint.sourceref!s} => ERR: Conflict in applying {action.to_str(inp.build.type5_name)}")
continue
# Action of unsupported type
raise NotImplementedError(action)
for failure in result.failures: for failure in result.failures:
error_list.append((str(constraint.sourceref), str(constraint), failure.msg, )) error_list.append((str(constraint.sourceref), str(constraint), failure.msg, ))

View File

@ -3,6 +3,7 @@ from .typeexpr import (
AtomicType, AtomicType,
TypeApplication, TypeApplication,
TypeConstructor, TypeConstructor,
TypeLevelNat,
TypeExpr, TypeExpr,
TypeVariable, TypeVariable,
) )
@ -21,6 +22,9 @@ class TypeRouter[T]:
def when_record(self, typ: Record) -> T: def when_record(self, typ: Record) -> T:
raise NotImplementedError(typ) raise NotImplementedError(typ)
def when_type_level_nat(self, typ: TypeLevelNat) -> T:
raise NotImplementedError(typ)
def when_variable(self, typ: TypeVariable) -> T: def when_variable(self, typ: TypeVariable) -> T:
raise NotImplementedError(typ) raise NotImplementedError(typ)
@ -37,6 +41,9 @@ class TypeRouter[T]:
if isinstance(typ, TypeConstructor): if isinstance(typ, TypeConstructor):
return self.when_constructor(typ) return self.when_constructor(typ)
if isinstance(typ, TypeLevelNat):
return self.when_type_level_nat(typ)
if isinstance(typ, TypeVariable): if isinstance(typ, TypeVariable):
return self.when_variable(typ) return self.when_variable(typ)

View File

@ -1,128 +0,0 @@
from dataclasses import dataclass
from typing import Callable
from .typeexpr import (
AtomicType,
TypeApplication,
TypeConstructor,
TypeExpr,
TypeVariable,
is_concrete,
occurs,
)
@dataclass
class Failure:
"""
Both types are already different - cannot be unified.
"""
msg: str
@dataclass
class Action:
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
raise NotImplementedError
class ActionList(list[Action]):
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return '{' + ', '.join((x.to_str(type_namer) for x in self)) + '}'
UnifyResult = Failure | ActionList
@dataclass
class ReplaceVariable(Action):
var: TypeVariable
typ: TypeExpr
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return f'{self.var.name} := {type_namer(self.typ)}'
def unify(lft: TypeExpr, rgt: TypeExpr) -> UnifyResult:
"""
Be warned: This only matches type variables with other variables or types
- it does not apply substituions nor does it validate if the matching
pairs are correct.
TODO: Remove this. It should be part of UnifyTypesConstraint
and should just generate new constraints for applications.
"""
if lft == rgt:
return ActionList()
if lft.kind != rgt.kind:
return Failure("Kind mismatch")
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
return Failure("Not the same type")
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
return ActionList([ReplaceVariable(rgt, lft)])
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
raise NotImplementedError # Should have been caught by kind check above
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
if is_concrete(rgt):
return Failure("Not the same type")
return Failure("Type shape mismatch")
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
return unify(rgt, lft)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
return ActionList([ReplaceVariable(lft, rgt)])
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
return ActionList([ReplaceVariable(lft, rgt)])
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
if occurs(lft, rgt):
return Failure("One type occurs in the other")
return ActionList([ReplaceVariable(lft, rgt)])
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
return unify(rgt, lft)
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
return unify(rgt, lft)
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
return Failure("Not the same type constructor")
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
return Failure("Not the same type constructor")
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
return unify(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
return unify(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
return unify(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
con_res = unify(lft.constructor, rgt.constructor)
if isinstance(con_res, Failure):
return con_res
arg_res = unify(lft.argument, rgt.argument)
if isinstance(arg_res, Failure):
return arg_res
return ActionList(con_res + arg_res)
return Failure('Not implemented')

View File

@ -305,8 +305,5 @@ def testEntry() -> i32:
``` ```
```py ```py
if TYPE_NAME.startswith('tuple_') or TYPE_NAME.startswith('static_array_') or TYPE_NAME.startswith('dynamic_array_'): expect_type_error('Not the same type')
expect_type_error('Not the same type constructor')
else:
expect_type_error('Not the same type')
``` ```

View File

@ -91,8 +91,7 @@ def testEntry() -> i32:
return action(double, 13.0) return action(double, 13.0)
""" """
match = r'Callable\[i32, i32\] ~ Callable\[f32, [^]]+\]' with pytest.raises(Type5SolverException, match='f32 ~ i32'):
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
@ -109,8 +108,7 @@ def testEntry() -> i32:
return action(double, 13) return action(double, 13)
""" """
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[f32, i32\], p_[0-9]+, [^]]+\]' with pytest.raises(Type5SolverException, match='f32 ~ i32'):
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
@ -144,8 +142,7 @@ def testEntry() -> i32:
return action(double, 13) return action(double, 13)
""" """
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[i32, f32\], p_[0-9]+, [^]]+\]' with pytest.raises(Type5SolverException, match='f32 ~ i32'):
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
@ -179,12 +176,12 @@ def testEntry() -> i32:
return action(double, 13, 14) return action(double, 13, 14)
""" """
match = r'Callable\[Callable\[i32, i32, i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]' match = r'i32 ~ Callable\[i32, i32\]'
with pytest.raises(Type5SolverException, match=match): with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
def test_sof_too_many_args_use(): def test_sof_too_many_args_use_0():
code_py = """ code_py = """
def thirteen() -> i32: def thirteen() -> i32:
return 13 return 13
@ -197,12 +194,30 @@ def testEntry() -> i32:
return action(thirteen, 13) return action(thirteen, 13)
""" """
match = r'Callable\[i32\] ~ Callable\[i32, p_[0-9]+\]' match = r'i32 ~ \(\)'
with pytest.raises(Type5SolverException, match=match): with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code(verbose=True) Suite(code_py).run_code(verbose=True)
@pytest.mark.integration_test @pytest.mark.integration_test
def test_sof_too_many_args_pass(): def test_sof_too_many_args_use_1():
code_py = """
def thirteen(x: i32) -> i32:
return x
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
return applicable(left, right)
@exported
def testEntry() -> i32:
return action(thirteen, 13, 26)
"""
match = r'i32 ~ Callable\[i32, i32\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code(verbose=True)
@pytest.mark.integration_test
def test_sof_too_many_args_pass_0():
code_py = """ code_py = """
def double(left: i32) -> i32: def double(left: i32) -> i32:
return left * 2 return left * 2
@ -215,6 +230,24 @@ def testEntry() -> i32:
return action(double, 13, 14) return action(double, 13, 14)
""" """
match = r'Callable\[Callable\[i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]' match = r'i32 ~ \(\)'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_too_many_args_pass_1():
code_py = """
def double(left: i32, right: i32) -> i32:
return left * right
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13, 14)
"""
match = r'i32 ~ Callable\[i32, i32\]'
with pytest.raises(Type5SolverException, match=match): with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code() Suite(code_py).run_code()

View File

@ -168,12 +168,9 @@ def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32:
return foldl(x, y, z) return foldl(x, y, z)
""" """
match = { match = 'Type shape mismatch'
'i8': 'Type shape mismatch',
'i8[3]': 'Kind mismatch',
}
with pytest.raises(Type5SolverException, match=match[in_typ]): with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test