Removes the weird second step unify
It is now part of the normal constraints. Added a special workaround for functions, since otherwise the output is a bit redundant and quite confusing. Also, constraints are now processed in order of complexity. This does not affect type safety. It uses a bit more CPU. But it makes the output that much easier to read. Also, removes the weird FunctionInstance hack. Instead, the more industry standard way of annotation the types on the function call is used. As always, this requires some hackyness for Subscriptable. Also, adds a few comments to the type unification to help with debugging. Also, prints out the new constraints that are received.
This commit is contained in:
parent
3d504e3d79
commit
7df9d5af12
@ -38,7 +38,6 @@ def load(build: BuildBase[Any]) -> None:
|
||||
build.register_type_class(Sized)
|
||||
|
||||
def wasm_dynamic_array_len(g: WasmGenerator, tv_map: dict[str, TypeExpr]) -> None:
|
||||
print('tv_map', tv_map)
|
||||
del tv_map
|
||||
# The length is stored in the first 4 bytes
|
||||
g.i32.load()
|
||||
|
||||
@ -8,6 +8,7 @@ from ..type5.typeexpr import (
|
||||
TypeApplication,
|
||||
TypeConstructor,
|
||||
TypeExpr,
|
||||
TypeLevelNat,
|
||||
TypeVariable,
|
||||
)
|
||||
from ..type5.typerouter import TypeRouter
|
||||
@ -102,6 +103,9 @@ class TypeName(BuildTypeRouter[str]):
|
||||
def when_tuple(self, tp_args: list[TypeExpr]) -> str:
|
||||
return '(' + ', '.join(map(self, tp_args)) + ', )'
|
||||
|
||||
def when_type_level_nat(self, typ: TypeLevelNat) -> str:
|
||||
return str(typ.value)
|
||||
|
||||
def when_variable(self, typ: TypeVariable) -> str:
|
||||
return typ.name
|
||||
|
||||
|
||||
@ -67,7 +67,7 @@ def expression(inp: ourlang.Expression) -> str:
|
||||
return str(inp.variable.name)
|
||||
|
||||
if isinstance(inp, ourlang.BinaryOp):
|
||||
return f'{expression(inp.left)} {inp.operator.function.name} {expression(inp.right)}'
|
||||
return f'{expression(inp.left)} {inp.operator.name} {expression(inp.right)}'
|
||||
|
||||
if isinstance(inp, ourlang.FunctionCall):
|
||||
args = ', '.join(
|
||||
@ -75,10 +75,10 @@ def expression(inp: ourlang.Expression) -> str:
|
||||
for arg in inp.arguments
|
||||
)
|
||||
|
||||
if isinstance(inp.function_instance.function, ourlang.StructConstructor):
|
||||
return f'{inp.function_instance.function.struct_type5.name}({args})'
|
||||
if isinstance(inp.function, ourlang.StructConstructor):
|
||||
return f'{inp.function.struct_type5.name}({args})'
|
||||
|
||||
return f'{inp.function_instance.function.name}({args})'
|
||||
return f'{inp.function.name}({args})'
|
||||
|
||||
if isinstance(inp, ourlang.FunctionReference):
|
||||
return str(inp.function.name)
|
||||
|
||||
@ -11,7 +11,14 @@ from .build.typerouter import BuildTypeRouter
|
||||
from .stdlib import alloc as stdlib_alloc
|
||||
from .stdlib import types as stdlib_types
|
||||
from .type5.constrainedexpr import ConstrainedExpr
|
||||
from .type5.typeexpr import AtomicType, TypeApplication, TypeExpr, is_concrete
|
||||
from .type5.typeexpr import (
|
||||
AtomicType,
|
||||
TypeApplication,
|
||||
TypeExpr,
|
||||
TypeVariable,
|
||||
is_concrete,
|
||||
replace_variable,
|
||||
)
|
||||
from .wasm import (
|
||||
WasmTypeFloat32,
|
||||
WasmTypeFloat64,
|
||||
@ -148,32 +155,94 @@ def expression_subscript_tuple(wgn: WasmGenerator, mod: ourlang.Module[WasmGener
|
||||
expression(wgn, mod, inp.varref)
|
||||
wgn.add_statement(el_type_info.wasm_load_func, f'offset={offset}')
|
||||
|
||||
def expression_subscript_operator(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Subscript) -> None:
|
||||
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
ftp5 = mod.build.type_classes['Subscriptable'].operators['[]']
|
||||
fn_args = mod.build.type5_is_function(ftp5)
|
||||
assert fn_args is not None
|
||||
t_a = fn_args[0]
|
||||
assert isinstance(t_a, TypeApplication)
|
||||
t = t_a.constructor
|
||||
a = t_a.argument
|
||||
|
||||
assert isinstance(t, TypeVariable)
|
||||
assert isinstance(a, TypeVariable)
|
||||
|
||||
assert isinstance(inp.varref.type5, TypeApplication)
|
||||
t_expr = inp.varref.type5.constructor
|
||||
a_expr = inp.varref.type5.argument
|
||||
|
||||
_expression_binary_operator_or_function_call(
|
||||
wgn,
|
||||
mod,
|
||||
ourlang.BuiltinFunction('[]', ftp5),
|
||||
{
|
||||
t: t_expr,
|
||||
a: a_expr,
|
||||
},
|
||||
[inp.varref, inp.index],
|
||||
inp.type5,
|
||||
)
|
||||
|
||||
def expression_binary_op(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.BinaryOp) -> None:
|
||||
expression_function_call(wgn, mod, _binary_op_to_function(inp))
|
||||
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
_expression_binary_operator_or_function_call(
|
||||
wgn,
|
||||
mod,
|
||||
inp.operator,
|
||||
inp.polytype_substitutions,
|
||||
[inp.left, inp.right],
|
||||
inp.type5,
|
||||
)
|
||||
|
||||
def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.FunctionCall) -> None:
|
||||
for arg in inp.arguments:
|
||||
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
_expression_binary_operator_or_function_call(
|
||||
wgn,
|
||||
mod,
|
||||
inp.function,
|
||||
inp.polytype_substitutions,
|
||||
inp.arguments,
|
||||
inp.type5,
|
||||
)
|
||||
|
||||
def _expression_binary_operator_or_function_call(
|
||||
wgn: WasmGenerator,
|
||||
mod: ourlang.Module[WasmGenerator],
|
||||
function: ourlang.Function | ourlang.FunctionParam,
|
||||
polytype_substitutions: dict[TypeVariable, TypeExpr],
|
||||
arguments: list[ourlang.Expression],
|
||||
ret_type5: TypeExpr,
|
||||
) -> None:
|
||||
for arg in arguments:
|
||||
expression(wgn, mod, arg)
|
||||
|
||||
if isinstance(inp.function_instance.function, ourlang.BuiltinFunction):
|
||||
assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR
|
||||
if isinstance(function, ourlang.BuiltinFunction):
|
||||
ftp5 = function.type5
|
||||
if isinstance(ftp5, ConstrainedExpr):
|
||||
cexpr = ftp5
|
||||
ftp5 = ftp5.expr
|
||||
for tvar in cexpr.variables:
|
||||
ftp5 = replace_variable(ftp5, tvar, polytype_substitutions[tvar])
|
||||
assert _is_concrete(ftp5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
try:
|
||||
method_type, method_router = mod.build.methods[inp.function_instance.function.name]
|
||||
method_type, method_router = mod.build.methods[function.name]
|
||||
except KeyError:
|
||||
method_type, method_router = mod.build.operators[inp.function_instance.function.name]
|
||||
method_type, method_router = mod.build.operators[function.name]
|
||||
|
||||
impl_lookup = method_router.get((inp.function_instance.type5, ))
|
||||
assert impl_lookup is not None, (inp.function_instance.function.name, inp.function_instance.type5, )
|
||||
impl_lookup = method_router.get((ftp5, ))
|
||||
assert impl_lookup is not None, (function.name, ftp5, )
|
||||
kwargs, impl = impl_lookup
|
||||
impl(wgn, kwargs)
|
||||
return
|
||||
|
||||
if isinstance(inp.function_instance.function, ourlang.FunctionParam):
|
||||
assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
fn_args = mod.build.type5_is_function(inp.function_instance.type5)
|
||||
assert fn_args is not None
|
||||
if isinstance(function, ourlang.FunctionParam):
|
||||
fn_args = mod.build.type5_is_function(function.type5)
|
||||
assert fn_args is not None, function.type5
|
||||
|
||||
params = [
|
||||
type5(mod, x)
|
||||
@ -182,11 +251,15 @@ def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerat
|
||||
|
||||
result = params.pop()
|
||||
|
||||
wgn.add_statement('local.get', '${}'.format(inp.function_instance.function.name))
|
||||
wgn.add_statement('local.get', '${}'.format(function.name))
|
||||
wgn.call_indirect(params=params, result=result)
|
||||
return
|
||||
|
||||
wgn.call(inp.function_instance.function.name)
|
||||
# TODO: Do similar subsitutions like we do for BuiltinFunction
|
||||
# when we get user space polymorphic functions
|
||||
# And then do similar lookup, and ensure we generate code for that variant
|
||||
|
||||
wgn.call(function.name)
|
||||
|
||||
def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Expression) -> None:
|
||||
"""
|
||||
@ -278,22 +351,7 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourl
|
||||
expression_subscript_tuple(wgn, mod, inp)
|
||||
return
|
||||
|
||||
inp_as_fc = ourlang.FunctionCall(
|
||||
ourlang.FunctionInstance(
|
||||
ourlang.BuiltinFunction('[]', mod.build.type_classes['Subscriptable'].operators['[]']),
|
||||
inp.sourceref,
|
||||
),
|
||||
inp.sourceref,
|
||||
)
|
||||
inp_as_fc.arguments = [inp.varref, inp.index]
|
||||
inp_as_fc.function_instance.type5 = mod.build.type5_make_function([
|
||||
inp.varref.type5,
|
||||
inp.index.type5,
|
||||
inp.type5,
|
||||
])
|
||||
inp_as_fc.type5 = inp.type5
|
||||
|
||||
expression_function_call(wgn, mod, inp_as_fc)
|
||||
expression_subscript_operator(wgn, mod, inp)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.AccessStructMember):
|
||||
@ -321,11 +379,11 @@ def statement_return(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun
|
||||
# Support tail calls
|
||||
# https://github.com/WebAssembly/tail-call
|
||||
# These help a lot with some functional programming techniques
|
||||
if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function_instance.function is fun:
|
||||
if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function is fun:
|
||||
for arg in inp.value.arguments:
|
||||
expression(wgn, mod, arg)
|
||||
|
||||
wgn.add_statement('return_call', '${}'.format(inp.value.function_instance.function.name))
|
||||
wgn.add_statement('return_call', '${}'.format(inp.value.function.name))
|
||||
return
|
||||
|
||||
expression(wgn, mod, inp.value)
|
||||
@ -595,14 +653,3 @@ def _type5_struct_offset(
|
||||
result += build.type5_alloc_size_member(memtyp)
|
||||
|
||||
raise RuntimeError('Member not found')
|
||||
|
||||
def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall:
|
||||
"""
|
||||
For compilation purposes, a binary operator is just a function call.
|
||||
|
||||
It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)`
|
||||
"""
|
||||
assert inp.sourceref is not None # TODO: sourceref required
|
||||
call = ourlang.FunctionCall(inp.operator, inp.sourceref)
|
||||
call.arguments = [inp.left, inp.right]
|
||||
return call
|
||||
|
||||
@ -157,52 +157,39 @@ class BinaryOp(Expression):
|
||||
"""
|
||||
A binary operator expression within a statement
|
||||
"""
|
||||
__slots__ = ('operator', 'left', 'right', )
|
||||
__slots__ = ('operator', 'polytype_substitutions', 'left', 'right', )
|
||||
|
||||
operator: FunctionInstance
|
||||
operator: Function | FunctionParam
|
||||
polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr]
|
||||
left: Expression
|
||||
right: Expression
|
||||
|
||||
def __init__(self, operator: FunctionInstance, left: Expression, right: Expression, sourceref: SourceRef) -> None:
|
||||
def __init__(self, operator: Function | FunctionParam, left: Expression, right: Expression, sourceref: SourceRef) -> None:
|
||||
super().__init__(sourceref=sourceref)
|
||||
|
||||
self.operator = operator
|
||||
self.polytype_substitutions = {}
|
||||
self.left = left
|
||||
self.right = right
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'BinaryOp({repr(self.operator)}, {repr(self.left)}, {repr(self.right)})'
|
||||
|
||||
class FunctionInstance(Expression):
|
||||
"""
|
||||
When calling a polymorphic function with concrete arguments, we can generate
|
||||
code for that specific instance of the function.
|
||||
"""
|
||||
__slots__ = ('function', )
|
||||
|
||||
function: Union['Function', 'FunctionParam']
|
||||
|
||||
def __init__(self, function: Union['Function', 'FunctionParam'], sourceref: SourceRef) -> None:
|
||||
super().__init__(sourceref=sourceref)
|
||||
|
||||
self.function = function
|
||||
|
||||
class FunctionCall(Expression):
|
||||
"""
|
||||
A function call expression within a statement
|
||||
"""
|
||||
__slots__ = ('function_instance', 'arguments', )
|
||||
__slots__ = ('function', 'polytype_substitutions', 'arguments', )
|
||||
|
||||
function_instance: FunctionInstance
|
||||
# TODO: FunctionInstance is wrong - we should have
|
||||
# substitutions: dict[TypeVariable, TypeExpr]
|
||||
# And it should have the same variables as the polytype (ConstrainedExpr) for function
|
||||
function: Function | FunctionParam
|
||||
polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr]
|
||||
arguments: List[Expression]
|
||||
|
||||
def __init__(self, function_instance: FunctionInstance, sourceref: SourceRef) -> None:
|
||||
def __init__(self, function: Function | FunctionParam, sourceref: SourceRef) -> None:
|
||||
super().__init__(sourceref=sourceref)
|
||||
|
||||
self.function_instance = function_instance
|
||||
self.function = function
|
||||
self.polytype_substitutions = {}
|
||||
self.arguments = []
|
||||
|
||||
class FunctionReference(Expression):
|
||||
|
||||
@ -18,7 +18,6 @@ from .ourlang import (
|
||||
Expression,
|
||||
Function,
|
||||
FunctionCall,
|
||||
FunctionInstance,
|
||||
FunctionParam,
|
||||
FunctionReference,
|
||||
Module,
|
||||
@ -400,7 +399,7 @@ class OurVisitor[G]:
|
||||
raise NotImplementedError(f'Operator {operator}')
|
||||
|
||||
return BinaryOp(
|
||||
FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)),
|
||||
BuiltinFunction(operator, module.operators[operator]),
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right),
|
||||
srf(module, node),
|
||||
@ -429,7 +428,7 @@ class OurVisitor[G]:
|
||||
raise NotImplementedError(f'Operator {operator}')
|
||||
|
||||
return BinaryOp(
|
||||
FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)),
|
||||
BuiltinFunction(operator, module.operators[operator]),
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]),
|
||||
srf(module, node),
|
||||
@ -506,7 +505,7 @@ class OurVisitor[G]:
|
||||
|
||||
func = module.functions[node.func.id]
|
||||
|
||||
result = FunctionCall(FunctionInstance(func, srf(module, node)), sourceref=srf(module, node))
|
||||
result = FunctionCall(func, sourceref=srf(module, node))
|
||||
result.arguments.extend(
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
|
||||
for arg_expr in node.args
|
||||
|
||||
@ -42,9 +42,8 @@ class ConstrainedExpr:
|
||||
|
||||
def instantiate_constrained(
|
||||
constrainedexpr: ConstrainedExpr,
|
||||
known_map: dict[TypeVariable, TypeVariable],
|
||||
make_variable: Callable[[KindExpr, str], TypeVariable],
|
||||
) -> ConstrainedExpr:
|
||||
) -> tuple[ConstrainedExpr, dict[TypeVariable, TypeVariable]]:
|
||||
"""
|
||||
Instantiates a type expression and its constraints
|
||||
"""
|
||||
@ -61,4 +60,4 @@ def instantiate_constrained(
|
||||
x.instantiate(known_map)
|
||||
for x in constrainedexpr.constraints
|
||||
)
|
||||
return ConstrainedExpr(constrainedexpr.variables, expr, constraints)
|
||||
return ConstrainedExpr(constrainedexpr.variables, expr, constraints), known_map
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Iterable, Protocol, Sequence
|
||||
from typing import Any, Callable, Iterable, Protocol, Sequence, TypeAlias
|
||||
|
||||
from ..build.base import BuildBase
|
||||
from ..ourlang import SourceRef
|
||||
@ -9,13 +9,16 @@ from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt6
|
||||
from .kindexpr import KindExpr, Star
|
||||
from .record import Record
|
||||
from .typeexpr import (
|
||||
AtomicType,
|
||||
TypeApplication,
|
||||
TypeConstructor,
|
||||
TypeExpr,
|
||||
TypeLevelNat,
|
||||
TypeVariable,
|
||||
is_concrete,
|
||||
occurs,
|
||||
replace_variable,
|
||||
)
|
||||
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
|
||||
|
||||
|
||||
class ExpressionProtocol(Protocol):
|
||||
@ -28,50 +31,95 @@ class ExpressionProtocol(Protocol):
|
||||
The type to update
|
||||
"""
|
||||
|
||||
PolytypeSubsituteMap: TypeAlias = dict[TypeVariable, TypeExpr]
|
||||
|
||||
class Context:
|
||||
__slots__ = ("build", "placeholder_update", )
|
||||
__slots__ = ("build", "placeholder_update", "ptst_update", )
|
||||
|
||||
build: BuildBase[Any]
|
||||
placeholder_update: dict[TypeVariable, ExpressionProtocol | None]
|
||||
ptst_update: dict[TypeVariable, tuple[PolytypeSubsituteMap, TypeVariable]]
|
||||
|
||||
def __init__(self, build: BuildBase[Any]) -> None:
|
||||
self.build = build
|
||||
self.placeholder_update = {}
|
||||
self.ptst_update = {}
|
||||
|
||||
def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star(), prefix: str = 'p') -> TypeVariable:
|
||||
res = TypeVariable(kind, f"{prefix}_{len(self.placeholder_update)}")
|
||||
self.placeholder_update[res] = arg
|
||||
return res
|
||||
|
||||
def register_polytype_subsitutes(self, tvar: TypeVariable, arg: PolytypeSubsituteMap, orig_var: TypeVariable) -> None:
|
||||
"""
|
||||
When `tvar` gets subsituted, also set the result in arg with orig_var as key
|
||||
|
||||
e.g.
|
||||
|
||||
(-) :: Callable[a, a, a]
|
||||
|
||||
def foo() -> u32:
|
||||
return 2 - 1
|
||||
|
||||
During typing, we instantiate a into a_3, and get the following constraints:
|
||||
- u8 ~ p_1
|
||||
- u8 ~ p_2
|
||||
- Exists NatNum a_3
|
||||
- Callable[a_3, a_3, a_3] ~ Callable[p_1, p_2, p_0]
|
||||
- u8 ~ p_0
|
||||
|
||||
When we resolve a_3, then on the call to `-`, we should note that a_3 got resolved
|
||||
to u32. But we need to use `a` as key, since that's what's used on the definition
|
||||
"""
|
||||
assert tvar in self.placeholder_update
|
||||
assert tvar not in self.ptst_update
|
||||
self.ptst_update[tvar] = (arg, orig_var)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Failure:
|
||||
"""
|
||||
Both types are already different - cannot be unified.
|
||||
"""
|
||||
msg: str
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ReplaceVariable:
|
||||
var: TypeVariable
|
||||
typ: TypeExpr
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CheckResult:
|
||||
# TODO: Refactor this, don't think we use most of the variants
|
||||
_: dataclasses.KW_ONLY
|
||||
done: bool = True
|
||||
actions: ActionList = dataclasses.field(default_factory=ActionList)
|
||||
replace: ReplaceVariable | None = None
|
||||
new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
|
||||
failures: list[Failure] = dataclasses.field(default_factory=list)
|
||||
|
||||
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
||||
if not self.done and not self.actions and not self.new_constraints and not self.failures:
|
||||
if not self.done and not self.replace and not self.new_constraints and not self.failures:
|
||||
return '(skip for now)'
|
||||
|
||||
if self.done and not self.actions and not self.new_constraints and not self.failures:
|
||||
if self.done and not self.replace and not self.new_constraints and not self.failures:
|
||||
return '(ok)'
|
||||
|
||||
if self.done and self.actions and not self.new_constraints and not self.failures:
|
||||
return self.actions.to_str(type_namer)
|
||||
if self.done and self.replace and not self.new_constraints and not self.failures:
|
||||
return f'{{{self.replace.var.name} := {type_namer(self.replace.typ)}}}'
|
||||
|
||||
if self.done and not self.actions and self.new_constraints and not self.failures:
|
||||
if self.done and not self.replace and self.new_constraints and not self.failures:
|
||||
return f'(got {len(self.new_constraints)} new constraints)'
|
||||
|
||||
if self.done and not self.actions and not self.new_constraints and self.failures:
|
||||
if self.done and not self.replace and not self.new_constraints and self.failures:
|
||||
return 'ERR: ' + '; '.join(x.msg for x in self.failures)
|
||||
|
||||
return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}'
|
||||
return f'{self.done} {self.replace} {self.new_constraints} {self.failures}'
|
||||
|
||||
def skip_for_now() -> CheckResult:
|
||||
return CheckResult(done=False)
|
||||
|
||||
def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult:
|
||||
return CheckResult(replace=ReplaceVariable(var, typ))
|
||||
|
||||
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
|
||||
return CheckResult(new_constraints=list(lst))
|
||||
|
||||
@ -94,12 +142,8 @@ class ConstraintBase:
|
||||
def check(self) -> CheckResult:
|
||||
raise NotImplementedError(self)
|
||||
|
||||
def apply(self, action: Action) -> None:
|
||||
if isinstance(action, ReplaceVariable):
|
||||
self.replace_variable(action.var, action.typ)
|
||||
return
|
||||
|
||||
raise NotImplementedError(action)
|
||||
def complexity(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
pass
|
||||
@ -142,6 +186,9 @@ class FromLiteralInteger(ConstraintBase):
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
self.type5 = replace_variable(self.type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
|
||||
|
||||
@ -175,8 +222,11 @@ class FromLiteralFloat(ConstraintBase):
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
self.type5 = replace_variable(self.type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
|
||||
return f'FromLiteralFloat {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
|
||||
|
||||
class FromLiteralBytes(ConstraintBase):
|
||||
__slots__ = ('type5', 'literal', )
|
||||
@ -203,32 +253,125 @@ class FromLiteralBytes(ConstraintBase):
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
self.type5 = replace_variable(self.type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'FromLiteralBytes {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
|
||||
|
||||
class UnifyTypesConstraint(ConstraintBase):
|
||||
__slots__ = ("lft", "rgt",)
|
||||
__slots__ = ("lft", "rgt", "prefix", )
|
||||
|
||||
def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr) -> None:
|
||||
def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr, prefix: str | None = None) -> None:
|
||||
super().__init__(ctx, sourceref)
|
||||
|
||||
self.lft = lft
|
||||
self.rgt = rgt
|
||||
self.prefix = prefix
|
||||
|
||||
def check(self) -> CheckResult:
|
||||
result = unify(self.lft, self.rgt)
|
||||
lft = self.lft
|
||||
rgt = self.rgt
|
||||
|
||||
if isinstance(result, Failure):
|
||||
return CheckResult(failures=[result])
|
||||
if lft == self.rgt:
|
||||
return ok()
|
||||
|
||||
return CheckResult(actions=result)
|
||||
if lft.kind != rgt.kind:
|
||||
return fail("Kind mismatch")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
|
||||
return fail("Not the same type")
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
|
||||
return replace(rgt, lft)
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
|
||||
raise NotImplementedError # Should have been caught by kind check above
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
|
||||
return fail("Not the same type" if is_concrete(rgt) else "Type shape mismatch")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
|
||||
return replace(lft, rgt)
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
|
||||
return replace(lft, rgt)
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
|
||||
return replace(lft, rgt)
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
|
||||
if occurs(lft, rgt):
|
||||
return fail("One type occurs in the other")
|
||||
|
||||
return replace(lft, rgt)
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
|
||||
raise NotImplementedError # Should have been caught by kind check above
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
|
||||
return replace(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
|
||||
return fail("Not the same type constructor")
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
|
||||
return fail("Not the same type constructor")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
|
||||
return fail("Not the same type" if is_concrete(lft) else "Type shape mismatch")
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
|
||||
if occurs(rgt, lft):
|
||||
return fail("One type occurs in the other")
|
||||
|
||||
return replace(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
|
||||
return fail("Not the same type constructor")
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
|
||||
|
||||
## USABILITY HACK
|
||||
## Often, we have two type applications in the same go
|
||||
## If so, resolve it in a single step
|
||||
## (Helps with debugging function unification)
|
||||
## This *should* not affect the actual type unification
|
||||
## It's just one less call to UnifyTypesConstraint.check
|
||||
if isinstance(lft.constructor, TypeApplication) and isinstance(rgt.constructor, TypeApplication):
|
||||
return new_constraints([
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.constructor, rgt.constructor.constructor),
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.argument, rgt.constructor.argument),
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument),
|
||||
])
|
||||
|
||||
|
||||
return new_constraints([
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor, rgt.constructor),
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument),
|
||||
])
|
||||
|
||||
|
||||
raise NotImplementedError(lft, rgt)
|
||||
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
self.lft = replace_variable(self.lft, var, typ)
|
||||
self.rgt = replace_variable(self.rgt, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return complexity(self.lft) + complexity(self.rgt)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}"
|
||||
prefix = f'{self.prefix} :: ' if self.prefix else ''
|
||||
return f"{prefix}{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}"
|
||||
|
||||
class CanBeSubscriptedConstraint(ConstraintBase):
|
||||
__slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )
|
||||
@ -290,6 +433,9 @@ class CanBeSubscriptedConstraint(ConstraintBase):
|
||||
self.container_type5 = replace_variable(self.container_type5, var, typ)
|
||||
self.index_type5 = replace_variable(self.index_type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.ret_type5) + complexity(self.container_type5) + complexity(self.index_type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"[] :: t a -> b -> a ~ {self.ctx.build.type5_name(self.container_type5)} -> {self.ctx.build.type5_name(self.index_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"
|
||||
|
||||
@ -333,6 +479,9 @@ class CanAccessStructMemberConstraint(ConstraintBase):
|
||||
self.ret_type5 = replace_variable(self.ret_type5, var, typ)
|
||||
self.struct_type5 = replace_variable(self.struct_type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.ret_type5) + complexity(self.struct_type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
st_args = self.ctx.build.type5_is_struct(self.struct_type5)
|
||||
member_dict = dict(st_args or [])
|
||||
@ -404,6 +553,9 @@ class FromTupleConstraint(ConstraintBase):
|
||||
for x in self.member_type5_list
|
||||
]
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.ret_type5) + sum(complexity(x) for x in self.member_type5_list)
|
||||
|
||||
def __str__(self) -> str:
|
||||
args = ', '.join(self.ctx.build.type5_name(x) for x in self.member_type5_list)
|
||||
return f'FromTuple {self.ctx.build.type5_name(self.ret_type5)} ~ ({args}, )'
|
||||
@ -450,6 +602,24 @@ class TypeClassInstanceExistsConstraint(ConstraintBase):
|
||||
for x in self.arg_list
|
||||
]
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + sum(complexity(x) for x in self.arg_list)
|
||||
|
||||
def __str__(self) -> str:
|
||||
args = ' '.join(self.ctx.build.type5_name(x) for x in self.arg_list)
|
||||
return f'Exists {self.typeclass} {args}'
|
||||
|
||||
def complexity(expr: TypeExpr) -> int:
|
||||
if isinstance(expr, AtomicType | TypeLevelNat):
|
||||
return 1
|
||||
|
||||
if isinstance(expr, TypeConstructor):
|
||||
return 2
|
||||
|
||||
if isinstance(expr, TypeVariable):
|
||||
return 5
|
||||
|
||||
if isinstance(expr, TypeApplication):
|
||||
return complexity(expr.constructor) + complexity(expr.argument)
|
||||
|
||||
raise NotImplementedError(expr)
|
||||
|
||||
@ -16,7 +16,7 @@ from .constraints import (
|
||||
UnifyTypesConstraint,
|
||||
)
|
||||
from .kindexpr import KindExpr
|
||||
from .typeexpr import TypeApplication, TypeVariable, instantiate
|
||||
from .typeexpr import TypeApplication, TypeExpr, TypeVariable, is_concrete
|
||||
|
||||
ConstraintGenerator = Generator[ConstraintBase, None, None]
|
||||
|
||||
@ -90,14 +90,41 @@ def expression_constant(ctx: Context, inp: ourlang.Constant, phft: TypeVariable)
|
||||
raise NotImplementedError(inp)
|
||||
|
||||
def expression_variable_reference(ctx: Context, inp: ourlang.VariableReference, phft: TypeVariable) -> ConstraintGenerator:
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft)
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft, prefix=inp.variable.name)
|
||||
|
||||
def expression_binary_operator(ctx: Context, inp: ourlang.BinaryOp, phft: TypeVariable) -> ConstraintGenerator:
|
||||
yield from expression_function_call(ctx, _binary_op_to_function(inp), phft)
|
||||
yield from _expression_binary_operator_or_function_call(
|
||||
ctx,
|
||||
inp.operator,
|
||||
inp.polytype_substitutions,
|
||||
[inp.left, inp.right],
|
||||
inp.sourceref,
|
||||
f'({inp.operator.name})',
|
||||
phft,
|
||||
)
|
||||
|
||||
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: TypeVariable) -> ConstraintGenerator:
|
||||
yield from _expression_binary_operator_or_function_call(
|
||||
ctx,
|
||||
inp.function,
|
||||
inp.polytype_substitutions,
|
||||
inp.arguments,
|
||||
inp.sourceref,
|
||||
inp.function.name,
|
||||
phft,
|
||||
)
|
||||
|
||||
def _expression_binary_operator_or_function_call(
|
||||
ctx: Context,
|
||||
function: ourlang.Function | ourlang.FunctionParam,
|
||||
polytype_substitutions: dict[TypeVariable, TypeExpr],
|
||||
arguments: list[ourlang.Expression],
|
||||
sourceref: ourlang.SourceRef,
|
||||
function_name: str,
|
||||
phft: TypeVariable,
|
||||
) -> ConstraintGenerator:
|
||||
arg_typ_list = []
|
||||
for arg in inp.arguments:
|
||||
for arg in arguments:
|
||||
arg_tv = ctx.make_placeholder(arg)
|
||||
yield from expression(ctx, arg, arg_tv)
|
||||
arg_typ_list.append(arg_tv)
|
||||
@ -105,34 +132,28 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Type
|
||||
def make_placeholder(x: KindExpr, p: str) -> TypeVariable:
|
||||
return ctx.make_placeholder(kind=x, prefix=p)
|
||||
|
||||
ftp5 = inp.function_instance.function.type5
|
||||
ftp5 = function.type5
|
||||
assert ftp5 is not None
|
||||
if isinstance(ftp5, ConstrainedExpr):
|
||||
ftp5 = instantiate_constrained(ftp5, {}, make_placeholder)
|
||||
ftp5, phft_lookup = instantiate_constrained(ftp5, make_placeholder)
|
||||
|
||||
for orig_tvar, tvar in phft_lookup.items():
|
||||
ctx.register_polytype_subsitutes(tvar, polytype_substitutions, orig_tvar)
|
||||
|
||||
for type_constraint in ftp5.constraints:
|
||||
if isinstance(type_constraint, TypeClassConstraint):
|
||||
yield TypeClassInstanceExistsConstraint(ctx, inp.sourceref, type_constraint.cls.name, type_constraint.variables)
|
||||
yield TypeClassInstanceExistsConstraint(ctx, sourceref, type_constraint.cls.name, type_constraint.variables)
|
||||
continue
|
||||
|
||||
raise NotImplementedError(type_constraint)
|
||||
|
||||
ftp5 = ftp5.expr
|
||||
else:
|
||||
ftp5 = instantiate(ftp5, {})
|
||||
|
||||
# We need an extra placeholder so that the inp.function_instance gets updated
|
||||
phft2 = ctx.make_placeholder(inp.function_instance)
|
||||
yield UnifyTypesConstraint(
|
||||
ctx,
|
||||
inp.sourceref,
|
||||
ftp5,
|
||||
phft2,
|
||||
)
|
||||
assert is_concrete(ftp5)
|
||||
|
||||
expr_type = ctx.build.type5_make_function(arg_typ_list + [phft])
|
||||
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, phft2, expr_type)
|
||||
yield UnifyTypesConstraint(ctx, sourceref, ftp5, expr_type, prefix=function_name)
|
||||
|
||||
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: TypeVariable) -> ConstraintGenerator:
|
||||
assert inp.function.type5 is not None # Todo: Make not nullable
|
||||
@ -141,7 +162,7 @@ def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference,
|
||||
if isinstance(ftp5, ConstrainedExpr):
|
||||
ftp5 = ftp5.expr
|
||||
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft)
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft, prefix=inp.function.name)
|
||||
|
||||
def expression_tuple_instantiation(ctx: Context, inp: ourlang.TupleInstantiation, phft: TypeVariable) -> ConstraintGenerator:
|
||||
arg_typ_list = []
|
||||
@ -220,7 +241,7 @@ def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement
|
||||
type5 = fun.type5.expr if isinstance(fun.type5, ConstrainedExpr) else fun.type5
|
||||
|
||||
yield from expression(ctx, inp.value, phft)
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft)
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft, prefix=f'{fun.name} returns')
|
||||
|
||||
def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf) -> ConstraintGenerator:
|
||||
test_phft = ctx.make_placeholder(inp.test)
|
||||
@ -267,14 +288,3 @@ def module(ctx: Context, inp: ourlang.Module[Any]) -> ConstraintGenerator:
|
||||
yield from function(ctx, func)
|
||||
|
||||
# TODO: Generalize?
|
||||
|
||||
def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall:
|
||||
"""
|
||||
For typing purposes, a binary operator is just a function call.
|
||||
|
||||
It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)`
|
||||
"""
|
||||
assert inp.sourceref is not None # TODO: sourceref required
|
||||
call = ourlang.FunctionCall(inp.operator, inp.sourceref)
|
||||
call.arguments = [inp.left, inp.right]
|
||||
return call
|
||||
|
||||
@ -3,8 +3,7 @@ from typing import Any
|
||||
from ..ourlang import Module
|
||||
from .constraints import ConstraintBase, Context
|
||||
from .fromast import phasm_type5_generate_constraints
|
||||
from .typeexpr import TypeExpr, TypeVariable, replace_variable
|
||||
from .unify import ReplaceVariable
|
||||
from .typeexpr import TypeExpr, TypeVariable, is_concrete, replace_variable
|
||||
|
||||
MAX_RESTACK_COUNT = 100
|
||||
|
||||
@ -31,8 +30,7 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
|
||||
print("Validating")
|
||||
|
||||
new_constraint_list: list[ConstraintBase] = []
|
||||
while constraint_list:
|
||||
constraint = constraint_list.pop(0)
|
||||
for constraint in sorted(constraint_list, key=lambda x: x.complexity()):
|
||||
result = constraint.check()
|
||||
|
||||
if verbose:
|
||||
@ -44,29 +42,15 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
|
||||
# Means it checks out and we don't need do anything
|
||||
continue
|
||||
|
||||
while result.actions:
|
||||
action = result.actions.pop(0)
|
||||
if result.replace is not None:
|
||||
action_var = result.replace.var
|
||||
assert action_var not in placeholder_types # When does this happen?
|
||||
|
||||
if isinstance(action, ReplaceVariable):
|
||||
action_var: TypeExpr = action.var
|
||||
while isinstance(action_var, TypeVariable) and action_var in placeholder_types:
|
||||
# TODO: Does this still happen?
|
||||
action_var = placeholder_types[action_var]
|
||||
action_typ: TypeExpr = result.replace.typ
|
||||
assert not isinstance(action_typ, TypeVariable) or action_typ not in placeholder_types # When does this happen?
|
||||
|
||||
action_typ: TypeExpr = action.typ
|
||||
while isinstance(action_typ, TypeVariable) and action_typ in placeholder_types:
|
||||
# TODO: Does this still happen?
|
||||
action_typ = placeholder_types[action_typ]
|
||||
assert action_var != action_typ # When does this happen?
|
||||
|
||||
# print(inp.build.type5_name(action_var), ':=', inp.build.type5_name(action_typ))
|
||||
|
||||
if action_var == action_typ:
|
||||
continue
|
||||
|
||||
if not isinstance(action_var, TypeVariable) and isinstance(action_typ, TypeVariable):
|
||||
action_typ, action_var = action_var, action_typ
|
||||
|
||||
if isinstance(action_var, TypeVariable):
|
||||
# Ensure all existing found types are updated
|
||||
placeholder_types = {
|
||||
k: replace_variable(v, action_var, action_typ)
|
||||
@ -85,20 +69,14 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
|
||||
if verbose and old_str != new_str:
|
||||
print(f"{oth_const.sourceref!s} => - {old_str!s}")
|
||||
print(f"{oth_const.sourceref!s} => + {new_str!s}")
|
||||
continue
|
||||
|
||||
error_list.append((str(constraint.sourceref), str(constraint), "Not the same type", ))
|
||||
if verbose:
|
||||
print(f"{constraint.sourceref!s} => ERR: Conflict in applying {action.to_str(inp.build.type5_name)}")
|
||||
continue
|
||||
|
||||
# Action of unsupported type
|
||||
raise NotImplementedError(action)
|
||||
|
||||
for failure in result.failures:
|
||||
error_list.append((str(constraint.sourceref), str(constraint), failure.msg, ))
|
||||
|
||||
new_constraint_list.extend(result.new_constraints)
|
||||
if verbose:
|
||||
for new_const in result.new_constraints:
|
||||
print(f"{oth_const.sourceref!s} => + {new_const!s}")
|
||||
|
||||
if result.done:
|
||||
continue
|
||||
@ -124,8 +102,11 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
|
||||
if expression is None:
|
||||
continue
|
||||
|
||||
new_type5 = placeholder_types[placeholder]
|
||||
while isinstance(new_type5, TypeVariable):
|
||||
new_type5 = placeholder_types[new_type5]
|
||||
resolved_type5 = placeholder_types[placeholder]
|
||||
assert is_concrete(resolved_type5) # When does this happen?
|
||||
expression.type5 = resolved_type5
|
||||
|
||||
expression.type5 = new_type5
|
||||
for placeholder, (ptst_map, orig_tvar) in ctx.ptst_update.items():
|
||||
resolved_type5 = placeholder_types[placeholder]
|
||||
assert is_concrete(resolved_type5) # When does this happen?
|
||||
ptst_map[orig_tvar] = resolved_type5
|
||||
|
||||
@ -4,6 +4,7 @@ from .typeexpr import (
|
||||
TypeApplication,
|
||||
TypeConstructor,
|
||||
TypeExpr,
|
||||
TypeLevelNat,
|
||||
TypeVariable,
|
||||
)
|
||||
|
||||
@ -21,6 +22,9 @@ class TypeRouter[T]:
|
||||
def when_record(self, typ: Record) -> T:
|
||||
raise NotImplementedError(typ)
|
||||
|
||||
def when_type_level_nat(self, typ: TypeLevelNat) -> T:
|
||||
raise NotImplementedError(typ)
|
||||
|
||||
def when_variable(self, typ: TypeVariable) -> T:
|
||||
raise NotImplementedError(typ)
|
||||
|
||||
@ -37,6 +41,9 @@ class TypeRouter[T]:
|
||||
if isinstance(typ, TypeConstructor):
|
||||
return self.when_constructor(typ)
|
||||
|
||||
if isinstance(typ, TypeLevelNat):
|
||||
return self.when_type_level_nat(typ)
|
||||
|
||||
if isinstance(typ, TypeVariable):
|
||||
return self.when_variable(typ)
|
||||
|
||||
|
||||
@ -1,128 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
from .typeexpr import (
|
||||
AtomicType,
|
||||
TypeApplication,
|
||||
TypeConstructor,
|
||||
TypeExpr,
|
||||
TypeVariable,
|
||||
is_concrete,
|
||||
occurs,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Failure:
|
||||
"""
|
||||
Both types are already different - cannot be unified.
|
||||
"""
|
||||
msg: str
|
||||
|
||||
@dataclass
|
||||
class Action:
|
||||
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
class ActionList(list[Action]):
|
||||
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
||||
return '{' + ', '.join((x.to_str(type_namer) for x in self)) + '}'
|
||||
|
||||
UnifyResult = Failure | ActionList
|
||||
|
||||
@dataclass
|
||||
class ReplaceVariable(Action):
|
||||
var: TypeVariable
|
||||
typ: TypeExpr
|
||||
|
||||
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
||||
return f'{self.var.name} := {type_namer(self.typ)}'
|
||||
|
||||
def unify(lft: TypeExpr, rgt: TypeExpr) -> UnifyResult:
|
||||
"""
|
||||
Be warned: This only matches type variables with other variables or types
|
||||
- it does not apply substituions nor does it validate if the matching
|
||||
pairs are correct.
|
||||
|
||||
TODO: Remove this. It should be part of UnifyTypesConstraint
|
||||
and should just generate new constraints for applications.
|
||||
"""
|
||||
if lft == rgt:
|
||||
return ActionList()
|
||||
|
||||
if lft.kind != rgt.kind:
|
||||
return Failure("Kind mismatch")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
|
||||
return Failure("Not the same type")
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
|
||||
return ActionList([ReplaceVariable(rgt, lft)])
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
|
||||
raise NotImplementedError # Should have been caught by kind check above
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
|
||||
if is_concrete(rgt):
|
||||
return Failure("Not the same type")
|
||||
|
||||
return Failure("Type shape mismatch")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
|
||||
return ActionList([ReplaceVariable(lft, rgt)])
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
|
||||
return ActionList([ReplaceVariable(lft, rgt)])
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
|
||||
if occurs(lft, rgt):
|
||||
return Failure("One type occurs in the other")
|
||||
|
||||
return ActionList([ReplaceVariable(lft, rgt)])
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
|
||||
return Failure("Not the same type constructor")
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
|
||||
return Failure("Not the same type constructor")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
|
||||
con_res = unify(lft.constructor, rgt.constructor)
|
||||
if isinstance(con_res, Failure):
|
||||
return con_res
|
||||
|
||||
arg_res = unify(lft.argument, rgt.argument)
|
||||
if isinstance(arg_res, Failure):
|
||||
return arg_res
|
||||
|
||||
return ActionList(con_res + arg_res)
|
||||
|
||||
|
||||
|
||||
return Failure('Not implemented')
|
||||
@ -305,8 +305,5 @@ def testEntry() -> i32:
|
||||
```
|
||||
|
||||
```py
|
||||
if TYPE_NAME.startswith('tuple_') or TYPE_NAME.startswith('static_array_') or TYPE_NAME.startswith('dynamic_array_'):
|
||||
expect_type_error('Not the same type constructor')
|
||||
else:
|
||||
expect_type_error('Not the same type')
|
||||
```
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from phasm.type5.solver import Type5SolverException
|
||||
|
||||
from ..helpers import Suite
|
||||
|
||||
|
||||
@ -38,24 +40,22 @@ def test_call_post_defined():
|
||||
code_py = """
|
||||
@exported
|
||||
def testEntry() -> i32:
|
||||
return helper(10, 3)
|
||||
return helper(13)
|
||||
|
||||
def helper(left: i32, right: i32) -> i32:
|
||||
return left - right
|
||||
def helper(left: i32) -> i32:
|
||||
return left
|
||||
"""
|
||||
|
||||
result = Suite(code_py).run_code()
|
||||
|
||||
assert 7 == result.returned_value
|
||||
assert 13 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@pytest.mark.skip('FIXME: Type checking')
|
||||
def test_call_invalid_type():
|
||||
code_py = """
|
||||
def helper(left: i32) -> i32:
|
||||
return left()
|
||||
"""
|
||||
|
||||
result = Suite(code_py).run_code()
|
||||
|
||||
assert 7 == result.returned_value
|
||||
with pytest.raises(Type5SolverException, match=r'i32 ~ Callable\[i32\]'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@ -91,8 +91,7 @@ def testEntry() -> i32:
|
||||
return action(double, 13.0)
|
||||
"""
|
||||
|
||||
match = r'Callable\[i32, i32\] ~ Callable\[f32, [^]]+\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@ -109,8 +108,7 @@ def testEntry() -> i32:
|
||||
return action(double, 13)
|
||||
"""
|
||||
|
||||
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[f32, i32\], p_[0-9]+, [^]]+\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@ -127,14 +125,14 @@ def testEntry() -> f32:
|
||||
return action(double, 13)
|
||||
"""
|
||||
|
||||
with pytest.raises(Type5SolverException, match='f32 ~ i32'):
|
||||
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_function_with_wrong_return_type_pass():
|
||||
code_py = """
|
||||
def double(left: i32) -> f32:
|
||||
return convert(left) * 2
|
||||
return convert(left) * 2.0
|
||||
|
||||
def action(applicable: Callable[i32, i32], left: i32) -> i32:
|
||||
return applicable(left)
|
||||
@ -144,8 +142,7 @@ def testEntry() -> i32:
|
||||
return action(double, 13)
|
||||
"""
|
||||
|
||||
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[i32, f32\], p_[0-9]+, [^]]+\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@ -179,12 +176,12 @@ def testEntry() -> i32:
|
||||
return action(double, 13, 14)
|
||||
"""
|
||||
|
||||
match = r'Callable\[Callable\[i32, i32, i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]'
|
||||
match = r'Callable\[i32, i32\] ~ i32'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_too_many_args_use():
|
||||
def test_sof_too_many_args_use_0():
|
||||
code_py = """
|
||||
def thirteen() -> i32:
|
||||
return 13
|
||||
@ -197,12 +194,30 @@ def testEntry() -> i32:
|
||||
return action(thirteen, 13)
|
||||
"""
|
||||
|
||||
match = r'Callable\[i32\] ~ Callable\[i32, p_[0-9]+\]'
|
||||
match = r'\(\) ~ i32'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code(verbose=True)
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_too_many_args_pass():
|
||||
def test_sof_too_many_args_use_1():
|
||||
code_py = """
|
||||
def thirteen(x: i32) -> i32:
|
||||
return x
|
||||
|
||||
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
|
||||
return applicable(left, right)
|
||||
|
||||
@exported
|
||||
def testEntry() -> i32:
|
||||
return action(thirteen, 13, 26)
|
||||
"""
|
||||
|
||||
match = r'i32 ~ Callable\[i32, i32\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code(verbose=True)
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_too_many_args_pass_0():
|
||||
code_py = """
|
||||
def double(left: i32) -> i32:
|
||||
return left * 2
|
||||
@ -215,6 +230,24 @@ def testEntry() -> i32:
|
||||
return action(double, 13, 14)
|
||||
"""
|
||||
|
||||
match = r'Callable\[Callable\[i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]'
|
||||
match = r'\(\) ~ i32'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_too_many_args_pass_1():
|
||||
code_py = """
|
||||
def double(left: i32, right: i32) -> i32:
|
||||
return left * right
|
||||
|
||||
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
|
||||
return applicable(left)
|
||||
|
||||
@exported
|
||||
def testEntry() -> i32:
|
||||
return action(double, 13, 14)
|
||||
"""
|
||||
|
||||
match = r'i32 ~ Callable\[i32, i32\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@ -168,12 +168,9 @@ def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32:
|
||||
return foldl(x, y, z)
|
||||
"""
|
||||
|
||||
match = {
|
||||
'i8': 'Type shape mismatch',
|
||||
'i8[3]': 'Kind mismatch',
|
||||
}
|
||||
match = 'Type shape mismatch'
|
||||
|
||||
with pytest.raises(Type5SolverException, match=match[in_typ]):
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user