diff --git a/phasm/build/typeclasses/sized.py b/phasm/build/typeclasses/sized.py index 9d39d6a..7235bd8 100644 --- a/phasm/build/typeclasses/sized.py +++ b/phasm/build/typeclasses/sized.py @@ -38,7 +38,6 @@ def load(build: BuildBase[Any]) -> None: build.register_type_class(Sized) def wasm_dynamic_array_len(g: WasmGenerator, tv_map: dict[str, TypeExpr]) -> None: - print('tv_map', tv_map) del tv_map # The length is stored in the first 4 bytes g.i32.load() diff --git a/phasm/build/typerouter.py b/phasm/build/typerouter.py index 44cfbf5..19e946a 100644 --- a/phasm/build/typerouter.py +++ b/phasm/build/typerouter.py @@ -8,6 +8,7 @@ from ..type5.typeexpr import ( TypeApplication, TypeConstructor, TypeExpr, + TypeLevelNat, TypeVariable, ) from ..type5.typerouter import TypeRouter @@ -102,6 +103,9 @@ class TypeName(BuildTypeRouter[str]): def when_tuple(self, tp_args: list[TypeExpr]) -> str: return '(' + ', '.join(map(self, tp_args)) + ', )' + def when_type_level_nat(self, typ: TypeLevelNat) -> str: + return str(typ.value) + def when_variable(self, typ: TypeVariable) -> str: return typ.name diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 965b345..cbd8cb7 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -67,7 +67,7 @@ def expression(inp: ourlang.Expression) -> str: return str(inp.variable.name) if isinstance(inp, ourlang.BinaryOp): - return f'{expression(inp.left)} {inp.operator.function.name} {expression(inp.right)}' + return f'{expression(inp.left)} {inp.operator.name} {expression(inp.right)}' if isinstance(inp, ourlang.FunctionCall): args = ', '.join( @@ -75,10 +75,10 @@ def expression(inp: ourlang.Expression) -> str: for arg in inp.arguments ) - if isinstance(inp.function_instance.function, ourlang.StructConstructor): - return f'{inp.function_instance.function.struct_type5.name}({args})' + if isinstance(inp.function, ourlang.StructConstructor): + return f'{inp.function.struct_type5.name}({args})' - return f'{inp.function_instance.function.name}({args})' + return f'{inp.function.name}({args})' if isinstance(inp, ourlang.FunctionReference): return str(inp.function.name) diff --git a/phasm/compiler.py b/phasm/compiler.py index 3aafd1d..6d34841 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -11,7 +11,14 @@ from .build.typerouter import BuildTypeRouter from .stdlib import alloc as stdlib_alloc from .stdlib import types as stdlib_types from .type5.constrainedexpr import ConstrainedExpr -from .type5.typeexpr import AtomicType, TypeApplication, TypeExpr, is_concrete +from .type5.typeexpr import ( + AtomicType, + TypeApplication, + TypeExpr, + TypeVariable, + is_concrete, + replace_variable, +) from .wasm import ( WasmTypeFloat32, WasmTypeFloat64, @@ -148,32 +155,94 @@ def expression_subscript_tuple(wgn: WasmGenerator, mod: ourlang.Module[WasmGener expression(wgn, mod, inp.varref) wgn.add_statement(el_type_info.wasm_load_func, f'offset={offset}') +def expression_subscript_operator(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Subscript) -> None: + assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR + + ftp5 = mod.build.type_classes['Subscriptable'].operators['[]'] + fn_args = mod.build.type5_is_function(ftp5) + assert fn_args is not None + t_a = fn_args[0] + assert isinstance(t_a, TypeApplication) + t = t_a.constructor + a = t_a.argument + + assert isinstance(t, TypeVariable) + assert isinstance(a, TypeVariable) + + assert isinstance(inp.varref.type5, TypeApplication) + t_expr = inp.varref.type5.constructor + a_expr = inp.varref.type5.argument + + _expression_binary_operator_or_function_call( + wgn, + mod, + ourlang.BuiltinFunction('[]', ftp5), + { + t: t_expr, + a: a_expr, + }, + [inp.varref, inp.index], + inp.type5, + ) + def expression_binary_op(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.BinaryOp) -> None: - expression_function_call(wgn, mod, _binary_op_to_function(inp)) + assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR + + _expression_binary_operator_or_function_call( + wgn, + mod, + inp.operator, + inp.polytype_substitutions, + [inp.left, inp.right], + inp.type5, + ) def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.FunctionCall) -> None: - for arg in inp.arguments: + assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR + + _expression_binary_operator_or_function_call( + wgn, + mod, + inp.function, + inp.polytype_substitutions, + inp.arguments, + inp.type5, + ) + +def _expression_binary_operator_or_function_call( + wgn: WasmGenerator, + mod: ourlang.Module[WasmGenerator], + function: ourlang.Function | ourlang.FunctionParam, + polytype_substitutions: dict[TypeVariable, TypeExpr], + arguments: list[ourlang.Expression], + ret_type5: TypeExpr, +) -> None: + for arg in arguments: expression(wgn, mod, arg) - if isinstance(inp.function_instance.function, ourlang.BuiltinFunction): - assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR + if isinstance(function, ourlang.BuiltinFunction): + ftp5 = function.type5 + if isinstance(ftp5, ConstrainedExpr): + cexpr = ftp5 + ftp5 = ftp5.expr + for tvar in cexpr.variables: + ftp5 = replace_variable(ftp5, tvar, polytype_substitutions[tvar]) + assert _is_concrete(ftp5), TYPE5_ASSERTION_ERROR try: - method_type, method_router = mod.build.methods[inp.function_instance.function.name] + method_type, method_router = mod.build.methods[function.name] except KeyError: - method_type, method_router = mod.build.operators[inp.function_instance.function.name] + method_type, method_router = mod.build.operators[function.name] - impl_lookup = method_router.get((inp.function_instance.type5, )) - assert impl_lookup is not None, (inp.function_instance.function.name, inp.function_instance.type5, ) + impl_lookup = method_router.get((ftp5, )) + assert impl_lookup is not None, (function.name, ftp5, ) kwargs, impl = impl_lookup impl(wgn, kwargs) return - if isinstance(inp.function_instance.function, ourlang.FunctionParam): - assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR - - fn_args = mod.build.type5_is_function(inp.function_instance.type5) - assert fn_args is not None + if isinstance(function, ourlang.FunctionParam): + fn_args = mod.build.type5_is_function(function.type5) + assert fn_args is not None, function.type5 params = [ type5(mod, x) @@ -182,11 +251,15 @@ def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerat result = params.pop() - wgn.add_statement('local.get', '${}'.format(inp.function_instance.function.name)) + wgn.add_statement('local.get', '${}'.format(function.name)) wgn.call_indirect(params=params, result=result) return - wgn.call(inp.function_instance.function.name) + # TODO: Do similar subsitutions like we do for BuiltinFunction + # when we get user space polymorphic functions + # And then do similar lookup, and ensure we generate code for that variant + + wgn.call(function.name) def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Expression) -> None: """ @@ -278,22 +351,7 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourl expression_subscript_tuple(wgn, mod, inp) return - inp_as_fc = ourlang.FunctionCall( - ourlang.FunctionInstance( - ourlang.BuiltinFunction('[]', mod.build.type_classes['Subscriptable'].operators['[]']), - inp.sourceref, - ), - inp.sourceref, - ) - inp_as_fc.arguments = [inp.varref, inp.index] - inp_as_fc.function_instance.type5 = mod.build.type5_make_function([ - inp.varref.type5, - inp.index.type5, - inp.type5, - ]) - inp_as_fc.type5 = inp.type5 - - expression_function_call(wgn, mod, inp_as_fc) + expression_subscript_operator(wgn, mod, inp) return if isinstance(inp, ourlang.AccessStructMember): @@ -321,11 +379,11 @@ def statement_return(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun # Support tail calls # https://github.com/WebAssembly/tail-call # These help a lot with some functional programming techniques - if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function_instance.function is fun: + if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function is fun: for arg in inp.value.arguments: expression(wgn, mod, arg) - wgn.add_statement('return_call', '${}'.format(inp.value.function_instance.function.name)) + wgn.add_statement('return_call', '${}'.format(inp.value.function.name)) return expression(wgn, mod, inp.value) @@ -595,14 +653,3 @@ def _type5_struct_offset( result += build.type5_alloc_size_member(memtyp) raise RuntimeError('Member not found') - -def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall: - """ - For compilation purposes, a binary operator is just a function call. - - It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)` - """ - assert inp.sourceref is not None # TODO: sourceref required - call = ourlang.FunctionCall(inp.operator, inp.sourceref) - call.arguments = [inp.left, inp.right] - return call diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 698f5b0..75ed7c6 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -157,52 +157,39 @@ class BinaryOp(Expression): """ A binary operator expression within a statement """ - __slots__ = ('operator', 'left', 'right', ) + __slots__ = ('operator', 'polytype_substitutions', 'left', 'right', ) - operator: FunctionInstance + operator: Function | FunctionParam + polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr] left: Expression right: Expression - def __init__(self, operator: FunctionInstance, left: Expression, right: Expression, sourceref: SourceRef) -> None: + def __init__(self, operator: Function | FunctionParam, left: Expression, right: Expression, sourceref: SourceRef) -> None: super().__init__(sourceref=sourceref) self.operator = operator + self.polytype_substitutions = {} self.left = left self.right = right def __repr__(self) -> str: return f'BinaryOp({repr(self.operator)}, {repr(self.left)}, {repr(self.right)})' -class FunctionInstance(Expression): - """ - When calling a polymorphic function with concrete arguments, we can generate - code for that specific instance of the function. - """ - __slots__ = ('function', ) - - function: Union['Function', 'FunctionParam'] - - def __init__(self, function: Union['Function', 'FunctionParam'], sourceref: SourceRef) -> None: - super().__init__(sourceref=sourceref) - - self.function = function - class FunctionCall(Expression): """ A function call expression within a statement """ - __slots__ = ('function_instance', 'arguments', ) + __slots__ = ('function', 'polytype_substitutions', 'arguments', ) - function_instance: FunctionInstance - # TODO: FunctionInstance is wrong - we should have - # substitutions: dict[TypeVariable, TypeExpr] - # And it should have the same variables as the polytype (ConstrainedExpr) for function + function: Function | FunctionParam + polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr] arguments: List[Expression] - def __init__(self, function_instance: FunctionInstance, sourceref: SourceRef) -> None: + def __init__(self, function: Function | FunctionParam, sourceref: SourceRef) -> None: super().__init__(sourceref=sourceref) - self.function_instance = function_instance + self.function = function + self.polytype_substitutions = {} self.arguments = [] class FunctionReference(Expression): diff --git a/phasm/parser.py b/phasm/parser.py index 151747f..adb43ed 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -18,7 +18,6 @@ from .ourlang import ( Expression, Function, FunctionCall, - FunctionInstance, FunctionParam, FunctionReference, Module, @@ -400,7 +399,7 @@ class OurVisitor[G]: raise NotImplementedError(f'Operator {operator}') return BinaryOp( - FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)), + BuiltinFunction(operator, module.operators[operator]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right), srf(module, node), @@ -429,7 +428,7 @@ class OurVisitor[G]: raise NotImplementedError(f'Operator {operator}') return BinaryOp( - FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)), + BuiltinFunction(operator, module.operators[operator]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]), srf(module, node), @@ -506,7 +505,7 @@ class OurVisitor[G]: func = module.functions[node.func.id] - result = FunctionCall(FunctionInstance(func, srf(module, node)), sourceref=srf(module, node)) + result = FunctionCall(func, sourceref=srf(module, node)) result.arguments.extend( self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) for arg_expr in node.args diff --git a/phasm/type5/constrainedexpr.py b/phasm/type5/constrainedexpr.py index 5199930..14d42ff 100644 --- a/phasm/type5/constrainedexpr.py +++ b/phasm/type5/constrainedexpr.py @@ -42,9 +42,8 @@ class ConstrainedExpr: def instantiate_constrained( constrainedexpr: ConstrainedExpr, - known_map: dict[TypeVariable, TypeVariable], make_variable: Callable[[KindExpr, str], TypeVariable], - ) -> ConstrainedExpr: + ) -> tuple[ConstrainedExpr, dict[TypeVariable, TypeVariable]]: """ Instantiates a type expression and its constraints """ @@ -61,4 +60,4 @@ def instantiate_constrained( x.instantiate(known_map) for x in constrainedexpr.constraints ) - return ConstrainedExpr(constrainedexpr.variables, expr, constraints) + return ConstrainedExpr(constrainedexpr.variables, expr, constraints), known_map diff --git a/phasm/type5/constraints.py b/phasm/type5/constraints.py index 657b049..f95d817 100644 --- a/phasm/type5/constraints.py +++ b/phasm/type5/constraints.py @@ -1,7 +1,7 @@ from __future__ import annotations import dataclasses -from typing import Any, Callable, Iterable, Protocol, Sequence +from typing import Any, Callable, Iterable, Protocol, Sequence, TypeAlias from ..build.base import BuildBase from ..ourlang import SourceRef @@ -9,13 +9,16 @@ from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt6 from .kindexpr import KindExpr, Star from .record import Record from .typeexpr import ( + AtomicType, TypeApplication, + TypeConstructor, TypeExpr, + TypeLevelNat, TypeVariable, is_concrete, + occurs, replace_variable, ) -from .unify import Action, ActionList, Failure, ReplaceVariable, unify class ExpressionProtocol(Protocol): @@ -28,50 +31,95 @@ class ExpressionProtocol(Protocol): The type to update """ +PolytypeSubsituteMap: TypeAlias = dict[TypeVariable, TypeExpr] + class Context: - __slots__ = ("build", "placeholder_update", ) + __slots__ = ("build", "placeholder_update", "ptst_update", ) build: BuildBase[Any] placeholder_update: dict[TypeVariable, ExpressionProtocol | None] + ptst_update: dict[TypeVariable, tuple[PolytypeSubsituteMap, TypeVariable]] def __init__(self, build: BuildBase[Any]) -> None: self.build = build self.placeholder_update = {} + self.ptst_update = {} def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star(), prefix: str = 'p') -> TypeVariable: res = TypeVariable(kind, f"{prefix}_{len(self.placeholder_update)}") self.placeholder_update[res] = arg return res + def register_polytype_subsitutes(self, tvar: TypeVariable, arg: PolytypeSubsituteMap, orig_var: TypeVariable) -> None: + """ + When `tvar` gets subsituted, also set the result in arg with orig_var as key + + e.g. + + (-) :: Callable[a, a, a] + + def foo() -> u32: + return 2 - 1 + + During typing, we instantiate a into a_3, and get the following constraints: + - u8 ~ p_1 + - u8 ~ p_2 + - Exists NatNum a_3 + - Callable[a_3, a_3, a_3] ~ Callable[p_1, p_2, p_0] + - u8 ~ p_0 + + When we resolve a_3, then on the call to `-`, we should note that a_3 got resolved + to u32. But we need to use `a` as key, since that's what's used on the definition + """ + assert tvar in self.placeholder_update + assert tvar not in self.ptst_update + self.ptst_update[tvar] = (arg, orig_var) + +@dataclasses.dataclass +class Failure: + """ + Both types are already different - cannot be unified. + """ + msg: str + +@dataclasses.dataclass +class ReplaceVariable: + var: TypeVariable + typ: TypeExpr + @dataclasses.dataclass class CheckResult: + # TODO: Refactor this, don't think we use most of the variants _: dataclasses.KW_ONLY done: bool = True - actions: ActionList = dataclasses.field(default_factory=ActionList) + replace: ReplaceVariable | None = None new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list) failures: list[Failure] = dataclasses.field(default_factory=list) def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str: - if not self.done and not self.actions and not self.new_constraints and not self.failures: + if not self.done and not self.replace and not self.new_constraints and not self.failures: return '(skip for now)' - if self.done and not self.actions and not self.new_constraints and not self.failures: + if self.done and not self.replace and not self.new_constraints and not self.failures: return '(ok)' - if self.done and self.actions and not self.new_constraints and not self.failures: - return self.actions.to_str(type_namer) + if self.done and self.replace and not self.new_constraints and not self.failures: + return f'{{{self.replace.var.name} := {type_namer(self.replace.typ)}}}' - if self.done and not self.actions and self.new_constraints and not self.failures: + if self.done and not self.replace and self.new_constraints and not self.failures: return f'(got {len(self.new_constraints)} new constraints)' - if self.done and not self.actions and not self.new_constraints and self.failures: + if self.done and not self.replace and not self.new_constraints and self.failures: return 'ERR: ' + '; '.join(x.msg for x in self.failures) - return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}' + return f'{self.done} {self.replace} {self.new_constraints} {self.failures}' def skip_for_now() -> CheckResult: return CheckResult(done=False) +def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult: + return CheckResult(replace=ReplaceVariable(var, typ)) + def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult: return CheckResult(new_constraints=list(lst)) @@ -94,12 +142,8 @@ class ConstraintBase: def check(self) -> CheckResult: raise NotImplementedError(self) - def apply(self, action: Action) -> None: - if isinstance(action, ReplaceVariable): - self.replace_variable(action.var, action.typ) - return - - raise NotImplementedError(action) + def complexity(self) -> int: + raise NotImplementedError def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None: pass @@ -142,6 +186,9 @@ class FromLiteralInteger(ConstraintBase): def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None: self.type5 = replace_variable(self.type5, var, typ) + def complexity(self) -> int: + return 100 + complexity(self.type5) + def __str__(self) -> str: return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}' @@ -175,8 +222,11 @@ class FromLiteralFloat(ConstraintBase): def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None: self.type5 = replace_variable(self.type5, var, typ) + def complexity(self) -> int: + return 100 + complexity(self.type5) + def __str__(self) -> str: - return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}' + return f'FromLiteralFloat {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}' class FromLiteralBytes(ConstraintBase): __slots__ = ('type5', 'literal', ) @@ -203,32 +253,125 @@ class FromLiteralBytes(ConstraintBase): def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None: self.type5 = replace_variable(self.type5, var, typ) + def complexity(self) -> int: + return 100 + complexity(self.type5) + def __str__(self) -> str: return f'FromLiteralBytes {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}' class UnifyTypesConstraint(ConstraintBase): - __slots__ = ("lft", "rgt",) + __slots__ = ("lft", "rgt", "prefix", ) - def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr) -> None: + def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr, prefix: str | None = None) -> None: super().__init__(ctx, sourceref) self.lft = lft self.rgt = rgt + self.prefix = prefix def check(self) -> CheckResult: - result = unify(self.lft, self.rgt) + lft = self.lft + rgt = self.rgt - if isinstance(result, Failure): - return CheckResult(failures=[result]) + if lft == self.rgt: + return ok() - return CheckResult(actions=result) + if lft.kind != rgt.kind: + return fail("Kind mismatch") + + + + if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType): + return fail("Not the same type") + + if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable): + return replace(rgt, lft) + + if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor): + raise NotImplementedError # Should have been caught by kind check above + + if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication): + return fail("Not the same type" if is_concrete(rgt) else "Type shape mismatch") + + + + if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType): + return replace(lft, rgt) + + if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable): + return replace(lft, rgt) + + if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor): + return replace(lft, rgt) + + if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication): + if occurs(lft, rgt): + return fail("One type occurs in the other") + + return replace(lft, rgt) + + + + if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType): + raise NotImplementedError # Should have been caught by kind check above + + if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable): + return replace(rgt, lft) + + if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor): + return fail("Not the same type constructor") + + if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication): + return fail("Not the same type constructor") + + + + if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType): + return fail("Not the same type" if is_concrete(lft) else "Type shape mismatch") + + if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable): + if occurs(rgt, lft): + return fail("One type occurs in the other") + + return replace(rgt, lft) + + if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor): + return fail("Not the same type constructor") + + if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication): + + ## USABILITY HACK + ## Often, we have two type applications in the same go + ## If so, resolve it in a single step + ## (Helps with debugging function unification) + ## This *should* not affect the actual type unification + ## It's just one less call to UnifyTypesConstraint.check + if isinstance(lft.constructor, TypeApplication) and isinstance(rgt.constructor, TypeApplication): + return new_constraints([ + UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.constructor, rgt.constructor.constructor), + UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.argument, rgt.constructor.argument), + UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument), + ]) + + + return new_constraints([ + UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor, rgt.constructor), + UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument), + ]) + + + raise NotImplementedError(lft, rgt) def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None: self.lft = replace_variable(self.lft, var, typ) self.rgt = replace_variable(self.rgt, var, typ) + def complexity(self) -> int: + return complexity(self.lft) + complexity(self.rgt) + def __str__(self) -> str: - return f"{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}" + prefix = f'{self.prefix} :: ' if self.prefix else '' + return f"{prefix}{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}" class CanBeSubscriptedConstraint(ConstraintBase): __slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', ) @@ -290,6 +433,9 @@ class CanBeSubscriptedConstraint(ConstraintBase): self.container_type5 = replace_variable(self.container_type5, var, typ) self.index_type5 = replace_variable(self.index_type5, var, typ) + def complexity(self) -> int: + return 100 + complexity(self.ret_type5) + complexity(self.container_type5) + complexity(self.index_type5) + def __str__(self) -> str: return f"[] :: t a -> b -> a ~ {self.ctx.build.type5_name(self.container_type5)} -> {self.ctx.build.type5_name(self.index_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}" @@ -333,6 +479,9 @@ class CanAccessStructMemberConstraint(ConstraintBase): self.ret_type5 = replace_variable(self.ret_type5, var, typ) self.struct_type5 = replace_variable(self.struct_type5, var, typ) + def complexity(self) -> int: + return 100 + complexity(self.ret_type5) + complexity(self.struct_type5) + def __str__(self) -> str: st_args = self.ctx.build.type5_is_struct(self.struct_type5) member_dict = dict(st_args or []) @@ -404,6 +553,9 @@ class FromTupleConstraint(ConstraintBase): for x in self.member_type5_list ] + def complexity(self) -> int: + return 100 + complexity(self.ret_type5) + sum(complexity(x) for x in self.member_type5_list) + def __str__(self) -> str: args = ', '.join(self.ctx.build.type5_name(x) for x in self.member_type5_list) return f'FromTuple {self.ctx.build.type5_name(self.ret_type5)} ~ ({args}, )' @@ -450,6 +602,24 @@ class TypeClassInstanceExistsConstraint(ConstraintBase): for x in self.arg_list ] + def complexity(self) -> int: + return 100 + sum(complexity(x) for x in self.arg_list) + def __str__(self) -> str: args = ' '.join(self.ctx.build.type5_name(x) for x in self.arg_list) return f'Exists {self.typeclass} {args}' + +def complexity(expr: TypeExpr) -> int: + if isinstance(expr, AtomicType | TypeLevelNat): + return 1 + + if isinstance(expr, TypeConstructor): + return 2 + + if isinstance(expr, TypeVariable): + return 5 + + if isinstance(expr, TypeApplication): + return complexity(expr.constructor) + complexity(expr.argument) + + raise NotImplementedError(expr) diff --git a/phasm/type5/fromast.py b/phasm/type5/fromast.py index f5c4ae0..f3e62a4 100644 --- a/phasm/type5/fromast.py +++ b/phasm/type5/fromast.py @@ -16,7 +16,7 @@ from .constraints import ( UnifyTypesConstraint, ) from .kindexpr import KindExpr -from .typeexpr import TypeApplication, TypeVariable, instantiate +from .typeexpr import TypeApplication, TypeExpr, TypeVariable, is_concrete ConstraintGenerator = Generator[ConstraintBase, None, None] @@ -90,14 +90,41 @@ def expression_constant(ctx: Context, inp: ourlang.Constant, phft: TypeVariable) raise NotImplementedError(inp) def expression_variable_reference(ctx: Context, inp: ourlang.VariableReference, phft: TypeVariable) -> ConstraintGenerator: - yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft) + yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft, prefix=inp.variable.name) def expression_binary_operator(ctx: Context, inp: ourlang.BinaryOp, phft: TypeVariable) -> ConstraintGenerator: - yield from expression_function_call(ctx, _binary_op_to_function(inp), phft) + yield from _expression_binary_operator_or_function_call( + ctx, + inp.operator, + inp.polytype_substitutions, + [inp.left, inp.right], + inp.sourceref, + f'({inp.operator.name})', + phft, + ) def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: TypeVariable) -> ConstraintGenerator: + yield from _expression_binary_operator_or_function_call( + ctx, + inp.function, + inp.polytype_substitutions, + inp.arguments, + inp.sourceref, + inp.function.name, + phft, + ) + +def _expression_binary_operator_or_function_call( + ctx: Context, + function: ourlang.Function | ourlang.FunctionParam, + polytype_substitutions: dict[TypeVariable, TypeExpr], + arguments: list[ourlang.Expression], + sourceref: ourlang.SourceRef, + function_name: str, + phft: TypeVariable, +) -> ConstraintGenerator: arg_typ_list = [] - for arg in inp.arguments: + for arg in arguments: arg_tv = ctx.make_placeholder(arg) yield from expression(ctx, arg, arg_tv) arg_typ_list.append(arg_tv) @@ -105,34 +132,28 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Type def make_placeholder(x: KindExpr, p: str) -> TypeVariable: return ctx.make_placeholder(kind=x, prefix=p) - ftp5 = inp.function_instance.function.type5 + ftp5 = function.type5 assert ftp5 is not None if isinstance(ftp5, ConstrainedExpr): - ftp5 = instantiate_constrained(ftp5, {}, make_placeholder) + ftp5, phft_lookup = instantiate_constrained(ftp5, make_placeholder) + + for orig_tvar, tvar in phft_lookup.items(): + ctx.register_polytype_subsitutes(tvar, polytype_substitutions, orig_tvar) for type_constraint in ftp5.constraints: if isinstance(type_constraint, TypeClassConstraint): - yield TypeClassInstanceExistsConstraint(ctx, inp.sourceref, type_constraint.cls.name, type_constraint.variables) + yield TypeClassInstanceExistsConstraint(ctx, sourceref, type_constraint.cls.name, type_constraint.variables) continue raise NotImplementedError(type_constraint) ftp5 = ftp5.expr else: - ftp5 = instantiate(ftp5, {}) - - # We need an extra placeholder so that the inp.function_instance gets updated - phft2 = ctx.make_placeholder(inp.function_instance) - yield UnifyTypesConstraint( - ctx, - inp.sourceref, - ftp5, - phft2, - ) + assert is_concrete(ftp5) expr_type = ctx.build.type5_make_function(arg_typ_list + [phft]) - yield UnifyTypesConstraint(ctx, inp.sourceref, phft2, expr_type) + yield UnifyTypesConstraint(ctx, sourceref, ftp5, expr_type, prefix=function_name) def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: TypeVariable) -> ConstraintGenerator: assert inp.function.type5 is not None # Todo: Make not nullable @@ -141,7 +162,7 @@ def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, if isinstance(ftp5, ConstrainedExpr): ftp5 = ftp5.expr - yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft) + yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft, prefix=inp.function.name) def expression_tuple_instantiation(ctx: Context, inp: ourlang.TupleInstantiation, phft: TypeVariable) -> ConstraintGenerator: arg_typ_list = [] @@ -220,7 +241,7 @@ def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement type5 = fun.type5.expr if isinstance(fun.type5, ConstrainedExpr) else fun.type5 yield from expression(ctx, inp.value, phft) - yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft) + yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft, prefix=f'{fun.name} returns') def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf) -> ConstraintGenerator: test_phft = ctx.make_placeholder(inp.test) @@ -267,14 +288,3 @@ def module(ctx: Context, inp: ourlang.Module[Any]) -> ConstraintGenerator: yield from function(ctx, func) # TODO: Generalize? - -def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall: - """ - For typing purposes, a binary operator is just a function call. - - It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)` - """ - assert inp.sourceref is not None # TODO: sourceref required - call = ourlang.FunctionCall(inp.operator, inp.sourceref) - call.arguments = [inp.left, inp.right] - return call diff --git a/phasm/type5/solver.py b/phasm/type5/solver.py index e4a4c1b..e97022f 100644 --- a/phasm/type5/solver.py +++ b/phasm/type5/solver.py @@ -3,8 +3,7 @@ from typing import Any from ..ourlang import Module from .constraints import ConstraintBase, Context from .fromast import phasm_type5_generate_constraints -from .typeexpr import TypeExpr, TypeVariable, replace_variable -from .unify import ReplaceVariable +from .typeexpr import TypeExpr, TypeVariable, is_concrete, replace_variable MAX_RESTACK_COUNT = 100 @@ -31,8 +30,7 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None: print("Validating") new_constraint_list: list[ConstraintBase] = [] - while constraint_list: - constraint = constraint_list.pop(0) + for constraint in sorted(constraint_list, key=lambda x: x.complexity()): result = constraint.check() if verbose: @@ -44,61 +42,41 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None: # Means it checks out and we don't need do anything continue - while result.actions: - action = result.actions.pop(0) + if result.replace is not None: + action_var = result.replace.var + assert action_var not in placeholder_types # When does this happen? - if isinstance(action, ReplaceVariable): - action_var: TypeExpr = action.var - while isinstance(action_var, TypeVariable) and action_var in placeholder_types: - # TODO: Does this still happen? - action_var = placeholder_types[action_var] + action_typ: TypeExpr = result.replace.typ + assert not isinstance(action_typ, TypeVariable) or action_typ not in placeholder_types # When does this happen? - action_typ: TypeExpr = action.typ - while isinstance(action_typ, TypeVariable) and action_typ in placeholder_types: - # TODO: Does this still happen? - action_typ = placeholder_types[action_typ] + assert action_var != action_typ # When does this happen? - # print(inp.build.type5_name(action_var), ':=', inp.build.type5_name(action_typ)) + # Ensure all existing found types are updated + placeholder_types = { + k: replace_variable(v, action_var, action_typ) + for k, v in placeholder_types.items() + } + placeholder_types[action_var] = action_typ - if action_var == action_typ: + for oth_const in new_constraint_list + constraint_list: + if oth_const is constraint and result.done: continue - if not isinstance(action_var, TypeVariable) and isinstance(action_typ, TypeVariable): - action_typ, action_var = action_var, action_typ + old_str = str(oth_const) + oth_const.replace_variable(action_var, action_typ) + new_str = str(oth_const) - if isinstance(action_var, TypeVariable): - # Ensure all existing found types are updated - placeholder_types = { - k: replace_variable(v, action_var, action_typ) - for k, v in placeholder_types.items() - } - placeholder_types[action_var] = action_typ - - for oth_const in new_constraint_list + constraint_list: - if oth_const is constraint and result.done: - continue - - old_str = str(oth_const) - oth_const.replace_variable(action_var, action_typ) - new_str = str(oth_const) - - if verbose and old_str != new_str: - print(f"{oth_const.sourceref!s} => - {old_str!s}") - print(f"{oth_const.sourceref!s} => + {new_str!s}") - continue - - error_list.append((str(constraint.sourceref), str(constraint), "Not the same type", )) - if verbose: - print(f"{constraint.sourceref!s} => ERR: Conflict in applying {action.to_str(inp.build.type5_name)}") - continue - - # Action of unsupported type - raise NotImplementedError(action) + if verbose and old_str != new_str: + print(f"{oth_const.sourceref!s} => - {old_str!s}") + print(f"{oth_const.sourceref!s} => + {new_str!s}") for failure in result.failures: error_list.append((str(constraint.sourceref), str(constraint), failure.msg, )) new_constraint_list.extend(result.new_constraints) + if verbose: + for new_const in result.new_constraints: + print(f"{oth_const.sourceref!s} => + {new_const!s}") if result.done: continue @@ -124,8 +102,11 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None: if expression is None: continue - new_type5 = placeholder_types[placeholder] - while isinstance(new_type5, TypeVariable): - new_type5 = placeholder_types[new_type5] + resolved_type5 = placeholder_types[placeholder] + assert is_concrete(resolved_type5) # When does this happen? + expression.type5 = resolved_type5 - expression.type5 = new_type5 + for placeholder, (ptst_map, orig_tvar) in ctx.ptst_update.items(): + resolved_type5 = placeholder_types[placeholder] + assert is_concrete(resolved_type5) # When does this happen? + ptst_map[orig_tvar] = resolved_type5 diff --git a/phasm/type5/typerouter.py b/phasm/type5/typerouter.py index 1609b36..b2c6464 100644 --- a/phasm/type5/typerouter.py +++ b/phasm/type5/typerouter.py @@ -4,6 +4,7 @@ from .typeexpr import ( TypeApplication, TypeConstructor, TypeExpr, + TypeLevelNat, TypeVariable, ) @@ -21,6 +22,9 @@ class TypeRouter[T]: def when_record(self, typ: Record) -> T: raise NotImplementedError(typ) + def when_type_level_nat(self, typ: TypeLevelNat) -> T: + raise NotImplementedError(typ) + def when_variable(self, typ: TypeVariable) -> T: raise NotImplementedError(typ) @@ -37,6 +41,9 @@ class TypeRouter[T]: if isinstance(typ, TypeConstructor): return self.when_constructor(typ) + if isinstance(typ, TypeLevelNat): + return self.when_type_level_nat(typ) + if isinstance(typ, TypeVariable): return self.when_variable(typ) diff --git a/phasm/type5/unify.py b/phasm/type5/unify.py deleted file mode 100644 index a8530e4..0000000 --- a/phasm/type5/unify.py +++ /dev/null @@ -1,128 +0,0 @@ -from dataclasses import dataclass -from typing import Callable - -from .typeexpr import ( - AtomicType, - TypeApplication, - TypeConstructor, - TypeExpr, - TypeVariable, - is_concrete, - occurs, -) - - -@dataclass -class Failure: - """ - Both types are already different - cannot be unified. - """ - msg: str - -@dataclass -class Action: - def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str: - raise NotImplementedError - -class ActionList(list[Action]): - def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str: - return '{' + ', '.join((x.to_str(type_namer) for x in self)) + '}' - -UnifyResult = Failure | ActionList - -@dataclass -class ReplaceVariable(Action): - var: TypeVariable - typ: TypeExpr - - def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str: - return f'{self.var.name} := {type_namer(self.typ)}' - -def unify(lft: TypeExpr, rgt: TypeExpr) -> UnifyResult: - """ - Be warned: This only matches type variables with other variables or types - - it does not apply substituions nor does it validate if the matching - pairs are correct. - - TODO: Remove this. It should be part of UnifyTypesConstraint - and should just generate new constraints for applications. - """ - if lft == rgt: - return ActionList() - - if lft.kind != rgt.kind: - return Failure("Kind mismatch") - - - - if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType): - return Failure("Not the same type") - - if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable): - return ActionList([ReplaceVariable(rgt, lft)]) - - if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor): - raise NotImplementedError # Should have been caught by kind check above - - if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication): - if is_concrete(rgt): - return Failure("Not the same type") - - return Failure("Type shape mismatch") - - - - if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType): - return unify(rgt, lft) - - if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable): - return ActionList([ReplaceVariable(lft, rgt)]) - - if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor): - return ActionList([ReplaceVariable(lft, rgt)]) - - if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication): - if occurs(lft, rgt): - return Failure("One type occurs in the other") - - return ActionList([ReplaceVariable(lft, rgt)]) - - - - if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType): - return unify(rgt, lft) - - if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable): - return unify(rgt, lft) - - if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor): - return Failure("Not the same type constructor") - - if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication): - return Failure("Not the same type constructor") - - - - if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType): - return unify(rgt, lft) - - if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable): - return unify(rgt, lft) - - if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor): - return unify(rgt, lft) - - if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication): - con_res = unify(lft.constructor, rgt.constructor) - if isinstance(con_res, Failure): - return con_res - - arg_res = unify(lft.argument, rgt.argument) - if isinstance(arg_res, Failure): - return arg_res - - return ActionList(con_res + arg_res) - - - - return Failure('Not implemented') diff --git a/tests/integration/test_lang/generator.md b/tests/integration/test_lang/generator.md index d70f653..a74e83b 100644 --- a/tests/integration/test_lang/generator.md +++ b/tests/integration/test_lang/generator.md @@ -305,8 +305,5 @@ def testEntry() -> i32: ``` ```py -if TYPE_NAME.startswith('tuple_') or TYPE_NAME.startswith('static_array_') or TYPE_NAME.startswith('dynamic_array_'): - expect_type_error('Not the same type constructor') -else: - expect_type_error('Not the same type') +expect_type_error('Not the same type') ``` diff --git a/tests/integration/test_lang/test_function_calls.py b/tests/integration/test_lang/test_function_calls.py index 004e23b..70e2519 100644 --- a/tests/integration/test_lang/test_function_calls.py +++ b/tests/integration/test_lang/test_function_calls.py @@ -1,5 +1,7 @@ import pytest +from phasm.type5.solver import Type5SolverException + from ..helpers import Suite @@ -38,24 +40,22 @@ def test_call_post_defined(): code_py = """ @exported def testEntry() -> i32: - return helper(10, 3) + return helper(13) -def helper(left: i32, right: i32) -> i32: - return left - right +def helper(left: i32) -> i32: + return left """ result = Suite(code_py).run_code() - assert 7 == result.returned_value + assert 13 == result.returned_value @pytest.mark.integration_test -@pytest.mark.skip('FIXME: Type checking') def test_call_invalid_type(): code_py = """ def helper(left: i32) -> i32: return left() """ - result = Suite(code_py).run_code() - - assert 7 == result.returned_value + with pytest.raises(Type5SolverException, match=r'i32 ~ Callable\[i32\]'): + Suite(code_py).run_code() diff --git a/tests/integration/test_lang/test_second_order_functions.py b/tests/integration/test_lang/test_second_order_functions.py index a598ce4..67272d8 100644 --- a/tests/integration/test_lang/test_second_order_functions.py +++ b/tests/integration/test_lang/test_second_order_functions.py @@ -91,8 +91,7 @@ def testEntry() -> i32: return action(double, 13.0) """ - match = r'Callable\[i32, i32\] ~ Callable\[f32, [^]]+\]' - with pytest.raises(Type5SolverException, match=match): + with pytest.raises(Type5SolverException, match='i32 ~ f32'): Suite(code_py).run_code() @pytest.mark.integration_test @@ -109,8 +108,7 @@ def testEntry() -> i32: return action(double, 13) """ - match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[f32, i32\], p_[0-9]+, [^]]+\]' - with pytest.raises(Type5SolverException, match=match): + with pytest.raises(Type5SolverException, match='i32 ~ f32'): Suite(code_py).run_code() @pytest.mark.integration_test @@ -127,14 +125,14 @@ def testEntry() -> f32: return action(double, 13) """ - with pytest.raises(Type5SolverException, match='f32 ~ i32'): + with pytest.raises(Type5SolverException, match='i32 ~ f32'): Suite(code_py).run_code() @pytest.mark.integration_test def test_sof_function_with_wrong_return_type_pass(): code_py = """ def double(left: i32) -> f32: - return convert(left) * 2 + return convert(left) * 2.0 def action(applicable: Callable[i32, i32], left: i32) -> i32: return applicable(left) @@ -144,8 +142,7 @@ def testEntry() -> i32: return action(double, 13) """ - match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[i32, f32\], p_[0-9]+, [^]]+\]' - with pytest.raises(Type5SolverException, match=match): + with pytest.raises(Type5SolverException, match='i32 ~ f32'): Suite(code_py).run_code() @pytest.mark.integration_test @@ -179,12 +176,12 @@ def testEntry() -> i32: return action(double, 13, 14) """ - match = r'Callable\[Callable\[i32, i32, i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]' + match = r'Callable\[i32, i32\] ~ i32' with pytest.raises(Type5SolverException, match=match): Suite(code_py).run_code() @pytest.mark.integration_test -def test_sof_too_many_args_use(): +def test_sof_too_many_args_use_0(): code_py = """ def thirteen() -> i32: return 13 @@ -197,12 +194,30 @@ def testEntry() -> i32: return action(thirteen, 13) """ - match = r'Callable\[i32\] ~ Callable\[i32, p_[0-9]+\]' + match = r'\(\) ~ i32' with pytest.raises(Type5SolverException, match=match): Suite(code_py).run_code(verbose=True) @pytest.mark.integration_test -def test_sof_too_many_args_pass(): +def test_sof_too_many_args_use_1(): + code_py = """ +def thirteen(x: i32) -> i32: + return x + +def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32: + return applicable(left, right) + +@exported +def testEntry() -> i32: + return action(thirteen, 13, 26) +""" + + match = r'i32 ~ Callable\[i32, i32\]' + with pytest.raises(Type5SolverException, match=match): + Suite(code_py).run_code(verbose=True) + +@pytest.mark.integration_test +def test_sof_too_many_args_pass_0(): code_py = """ def double(left: i32) -> i32: return left * 2 @@ -215,6 +230,24 @@ def testEntry() -> i32: return action(double, 13, 14) """ - match = r'Callable\[Callable\[i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]' + match = r'\(\) ~ i32' + with pytest.raises(Type5SolverException, match=match): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_sof_too_many_args_pass_1(): + code_py = """ +def double(left: i32, right: i32) -> i32: + return left * right + +def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32: + return applicable(left) + +@exported +def testEntry() -> i32: + return action(double, 13, 14) +""" + + match = r'i32 ~ Callable\[i32, i32\]' with pytest.raises(Type5SolverException, match=match): Suite(code_py).run_code() diff --git a/tests/integration/test_typeclasses/test_foldable.py b/tests/integration/test_typeclasses/test_foldable.py index d516ae5..476c09a 100644 --- a/tests/integration/test_typeclasses/test_foldable.py +++ b/tests/integration/test_typeclasses/test_foldable.py @@ -168,12 +168,9 @@ def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32: return foldl(x, y, z) """ - match = { - 'i8': 'Type shape mismatch', - 'i8[3]': 'Kind mismatch', - } + match = 'Type shape mismatch' - with pytest.raises(Type5SolverException, match=match[in_typ]): + with pytest.raises(Type5SolverException, match=match): Suite(code_py).run_code() @pytest.mark.integration_test