Compare commits
3 Commits
a955d4fc31
...
3d6d279408
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d6d279408 | ||
| 71691d68e9 | |||
|
|
7df9d5af12 |
@ -38,7 +38,6 @@ def load(build: BuildBase[Any]) -> None:
|
||||
build.register_type_class(Sized)
|
||||
|
||||
def wasm_dynamic_array_len(g: WasmGenerator, tv_map: dict[str, TypeExpr]) -> None:
|
||||
print('tv_map', tv_map)
|
||||
del tv_map
|
||||
# The length is stored in the first 4 bytes
|
||||
g.i32.load()
|
||||
|
||||
@ -8,6 +8,7 @@ from ..type5.typeexpr import (
|
||||
TypeApplication,
|
||||
TypeConstructor,
|
||||
TypeExpr,
|
||||
TypeLevelNat,
|
||||
TypeVariable,
|
||||
)
|
||||
from ..type5.typerouter import TypeRouter
|
||||
@ -112,6 +113,9 @@ class TypeName(BuildTypeRouter[str]):
|
||||
def when_tuple(self, tp_args: list[TypeExpr]) -> str:
|
||||
return '(' + ', '.join(map(self, tp_args)) + ', )'
|
||||
|
||||
def when_type_level_nat(self, typ: TypeLevelNat) -> str:
|
||||
return str(typ.value)
|
||||
|
||||
def when_variable(self, typ: TypeVariable) -> str:
|
||||
return typ.name
|
||||
|
||||
|
||||
@ -67,7 +67,7 @@ def expression(inp: ourlang.Expression) -> str:
|
||||
return str(inp.variable.name)
|
||||
|
||||
if isinstance(inp, ourlang.BinaryOp):
|
||||
return f'{expression(inp.left)} {inp.operator.function.name} {expression(inp.right)}'
|
||||
return f'{expression(inp.left)} {inp.operator.name} {expression(inp.right)}'
|
||||
|
||||
if isinstance(inp, ourlang.FunctionCall):
|
||||
args = ', '.join(
|
||||
@ -75,10 +75,10 @@ def expression(inp: ourlang.Expression) -> str:
|
||||
for arg in inp.arguments
|
||||
)
|
||||
|
||||
if isinstance(inp.function_instance.function, ourlang.StructConstructor):
|
||||
return f'{inp.function_instance.function.struct_type5.name}({args})'
|
||||
if isinstance(inp.function, ourlang.StructConstructor):
|
||||
return f'{inp.function.struct_type5.name}({args})'
|
||||
|
||||
return f'{inp.function_instance.function.name}({args})'
|
||||
return f'{inp.function.name}({args})'
|
||||
|
||||
if isinstance(inp, ourlang.FunctionReference):
|
||||
return str(inp.function.name)
|
||||
|
||||
@ -11,7 +11,14 @@ from .build.typerouter import BuildTypeRouter
|
||||
from .stdlib import alloc as stdlib_alloc
|
||||
from .stdlib import types as stdlib_types
|
||||
from .type5.constrainedexpr import ConstrainedExpr
|
||||
from .type5.typeexpr import AtomicType, TypeApplication, TypeExpr, is_concrete
|
||||
from .type5.typeexpr import (
|
||||
AtomicType,
|
||||
TypeApplication,
|
||||
TypeExpr,
|
||||
TypeVariable,
|
||||
is_concrete,
|
||||
replace_variable,
|
||||
)
|
||||
from .wasm import (
|
||||
WasmTypeFloat32,
|
||||
WasmTypeFloat64,
|
||||
@ -153,32 +160,94 @@ def expression_subscript_tuple(wgn: WasmGenerator, mod: ourlang.Module[WasmGener
|
||||
expression(wgn, mod, inp.varref)
|
||||
wgn.add_statement(el_type_info.wasm_load_func, f'offset={offset}')
|
||||
|
||||
def expression_subscript_operator(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Subscript) -> None:
|
||||
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
ftp5 = mod.build.type_classes['Subscriptable'].operators['[]']
|
||||
fn_args = mod.build.type5_is_function(ftp5)
|
||||
assert fn_args is not None
|
||||
t_a = fn_args[0]
|
||||
assert isinstance(t_a, TypeApplication)
|
||||
t = t_a.constructor
|
||||
a = t_a.argument
|
||||
|
||||
assert isinstance(t, TypeVariable)
|
||||
assert isinstance(a, TypeVariable)
|
||||
|
||||
assert isinstance(inp.varref.type5, TypeApplication)
|
||||
t_expr = inp.varref.type5.constructor
|
||||
a_expr = inp.varref.type5.argument
|
||||
|
||||
_expression_binary_operator_or_function_call(
|
||||
wgn,
|
||||
mod,
|
||||
ourlang.BuiltinFunction('[]', ftp5),
|
||||
{
|
||||
t: t_expr,
|
||||
a: a_expr,
|
||||
},
|
||||
[inp.varref, inp.index],
|
||||
inp.type5,
|
||||
)
|
||||
|
||||
def expression_binary_op(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.BinaryOp) -> None:
|
||||
expression_function_call(wgn, mod, _binary_op_to_function(inp))
|
||||
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
_expression_binary_operator_or_function_call(
|
||||
wgn,
|
||||
mod,
|
||||
inp.operator,
|
||||
inp.polytype_substitutions,
|
||||
[inp.left, inp.right],
|
||||
inp.type5,
|
||||
)
|
||||
|
||||
def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.FunctionCall) -> None:
|
||||
for arg in inp.arguments:
|
||||
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
_expression_binary_operator_or_function_call(
|
||||
wgn,
|
||||
mod,
|
||||
inp.function,
|
||||
inp.polytype_substitutions,
|
||||
inp.arguments,
|
||||
inp.type5,
|
||||
)
|
||||
|
||||
def _expression_binary_operator_or_function_call(
|
||||
wgn: WasmGenerator,
|
||||
mod: ourlang.Module[WasmGenerator],
|
||||
function: ourlang.Function | ourlang.FunctionParam,
|
||||
polytype_substitutions: dict[TypeVariable, TypeExpr],
|
||||
arguments: list[ourlang.Expression],
|
||||
ret_type5: TypeExpr,
|
||||
) -> None:
|
||||
for arg in arguments:
|
||||
expression(wgn, mod, arg)
|
||||
|
||||
if isinstance(inp.function_instance.function, ourlang.BuiltinFunction):
|
||||
assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR
|
||||
if isinstance(function, ourlang.BuiltinFunction):
|
||||
ftp5 = function.type5
|
||||
if isinstance(ftp5, ConstrainedExpr):
|
||||
cexpr = ftp5
|
||||
ftp5 = ftp5.expr
|
||||
for tvar in cexpr.variables:
|
||||
ftp5 = replace_variable(ftp5, tvar, polytype_substitutions[tvar])
|
||||
assert _is_concrete(ftp5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
try:
|
||||
method_type, method_router = mod.build.methods[inp.function_instance.function.name]
|
||||
method_type, method_router = mod.build.methods[function.name]
|
||||
except KeyError:
|
||||
method_type, method_router = mod.build.operators[inp.function_instance.function.name]
|
||||
method_type, method_router = mod.build.operators[function.name]
|
||||
|
||||
impl_lookup = method_router.get((inp.function_instance.type5, ))
|
||||
assert impl_lookup is not None, (inp.function_instance.function.name, inp.function_instance.type5, )
|
||||
impl_lookup = method_router.get((ftp5, ))
|
||||
assert impl_lookup is not None, (function.name, ftp5, )
|
||||
kwargs, impl = impl_lookup
|
||||
impl(wgn, kwargs)
|
||||
return
|
||||
|
||||
if isinstance(inp.function_instance.function, ourlang.FunctionParam):
|
||||
assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR
|
||||
|
||||
fn_args = mod.build.type5_is_function(inp.function_instance.type5)
|
||||
assert fn_args is not None
|
||||
if isinstance(function, ourlang.FunctionParam):
|
||||
fn_args = mod.build.type5_is_function(function.type5)
|
||||
assert fn_args is not None, function.type5
|
||||
|
||||
params = [
|
||||
type5(mod, x)
|
||||
@ -187,11 +256,15 @@ def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerat
|
||||
|
||||
result = params.pop()
|
||||
|
||||
wgn.add_statement('local.get', '${}'.format(inp.function_instance.function.name))
|
||||
wgn.add_statement('local.get', '${}'.format(function.name))
|
||||
wgn.call_indirect(params=params, result=result)
|
||||
return
|
||||
|
||||
wgn.call(inp.function_instance.function.name)
|
||||
# TODO: Do similar subsitutions like we do for BuiltinFunction
|
||||
# when we get user space polymorphic functions
|
||||
# And then do similar lookup, and ensure we generate code for that variant
|
||||
|
||||
wgn.call(function.name)
|
||||
|
||||
def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Expression) -> None:
|
||||
"""
|
||||
@ -283,22 +356,7 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourl
|
||||
expression_subscript_tuple(wgn, mod, inp)
|
||||
return
|
||||
|
||||
inp_as_fc = ourlang.FunctionCall(
|
||||
ourlang.FunctionInstance(
|
||||
ourlang.BuiltinFunction('[]', mod.build.type_classes['Subscriptable'].operators['[]']),
|
||||
inp.sourceref,
|
||||
),
|
||||
inp.sourceref,
|
||||
)
|
||||
inp_as_fc.arguments = [inp.varref, inp.index]
|
||||
inp_as_fc.function_instance.type5 = mod.build.type5_make_function([
|
||||
inp.varref.type5,
|
||||
inp.index.type5,
|
||||
inp.type5,
|
||||
])
|
||||
inp_as_fc.type5 = inp.type5
|
||||
|
||||
expression_function_call(wgn, mod, inp_as_fc)
|
||||
expression_subscript_operator(wgn, mod, inp)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.AccessStructMember):
|
||||
@ -326,11 +384,11 @@ def statement_return(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun
|
||||
# Support tail calls
|
||||
# https://github.com/WebAssembly/tail-call
|
||||
# These help a lot with some functional programming techniques
|
||||
if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function_instance.function is fun:
|
||||
if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function is fun:
|
||||
for arg in inp.value.arguments:
|
||||
expression(wgn, mod, arg)
|
||||
|
||||
wgn.add_statement('return_call', '${}'.format(inp.value.function_instance.function.name))
|
||||
wgn.add_statement('return_call', '${}'.format(inp.value.function.name))
|
||||
return
|
||||
|
||||
expression(wgn, mod, inp.value)
|
||||
@ -607,14 +665,3 @@ def _type5_struct_offset(
|
||||
result += build.type5_alloc_size_member(memtyp)
|
||||
|
||||
raise RuntimeError('Member not found')
|
||||
|
||||
def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall:
|
||||
"""
|
||||
For compilation purposes, a binary operator is just a function call.
|
||||
|
||||
It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)`
|
||||
"""
|
||||
assert inp.sourceref is not None # TODO: sourceref required
|
||||
call = ourlang.FunctionCall(inp.operator, inp.sourceref)
|
||||
call.arguments = [inp.left, inp.right]
|
||||
return call
|
||||
|
||||
@ -157,52 +157,39 @@ class BinaryOp(Expression):
|
||||
"""
|
||||
A binary operator expression within a statement
|
||||
"""
|
||||
__slots__ = ('operator', 'left', 'right', )
|
||||
__slots__ = ('operator', 'polytype_substitutions', 'left', 'right', )
|
||||
|
||||
operator: FunctionInstance
|
||||
operator: Function | FunctionParam
|
||||
polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr]
|
||||
left: Expression
|
||||
right: Expression
|
||||
|
||||
def __init__(self, operator: FunctionInstance, left: Expression, right: Expression, sourceref: SourceRef) -> None:
|
||||
def __init__(self, operator: Function | FunctionParam, left: Expression, right: Expression, sourceref: SourceRef) -> None:
|
||||
super().__init__(sourceref=sourceref)
|
||||
|
||||
self.operator = operator
|
||||
self.polytype_substitutions = {}
|
||||
self.left = left
|
||||
self.right = right
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'BinaryOp({repr(self.operator)}, {repr(self.left)}, {repr(self.right)})'
|
||||
|
||||
class FunctionInstance(Expression):
|
||||
"""
|
||||
When calling a polymorphic function with concrete arguments, we can generate
|
||||
code for that specific instance of the function.
|
||||
"""
|
||||
__slots__ = ('function', )
|
||||
|
||||
function: Union['Function', 'FunctionParam']
|
||||
|
||||
def __init__(self, function: Union['Function', 'FunctionParam'], sourceref: SourceRef) -> None:
|
||||
super().__init__(sourceref=sourceref)
|
||||
|
||||
self.function = function
|
||||
|
||||
class FunctionCall(Expression):
|
||||
"""
|
||||
A function call expression within a statement
|
||||
"""
|
||||
__slots__ = ('function_instance', 'arguments', )
|
||||
__slots__ = ('function', 'polytype_substitutions', 'arguments', )
|
||||
|
||||
function_instance: FunctionInstance
|
||||
# TODO: FunctionInstance is wrong - we should have
|
||||
# substitutions: dict[TypeVariable, TypeExpr]
|
||||
# And it should have the same variables as the polytype (ConstrainedExpr) for function
|
||||
function: Function | FunctionParam
|
||||
polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr]
|
||||
arguments: List[Expression]
|
||||
|
||||
def __init__(self, function_instance: FunctionInstance, sourceref: SourceRef) -> None:
|
||||
def __init__(self, function: Function | FunctionParam, sourceref: SourceRef) -> None:
|
||||
super().__init__(sourceref=sourceref)
|
||||
|
||||
self.function_instance = function_instance
|
||||
self.function = function
|
||||
self.polytype_substitutions = {}
|
||||
self.arguments = []
|
||||
|
||||
class FunctionReference(Expression):
|
||||
|
||||
@ -18,7 +18,6 @@ from .ourlang import (
|
||||
Expression,
|
||||
Function,
|
||||
FunctionCall,
|
||||
FunctionInstance,
|
||||
FunctionParam,
|
||||
FunctionReference,
|
||||
Module,
|
||||
@ -404,7 +403,7 @@ class OurVisitor[G]:
|
||||
raise NotImplementedError(f'Operator {operator}')
|
||||
|
||||
return BinaryOp(
|
||||
FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)),
|
||||
BuiltinFunction(operator, module.operators[operator]),
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right),
|
||||
srf(module, node),
|
||||
@ -433,7 +432,7 @@ class OurVisitor[G]:
|
||||
raise NotImplementedError(f'Operator {operator}')
|
||||
|
||||
return BinaryOp(
|
||||
FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)),
|
||||
BuiltinFunction(operator, module.operators[operator]),
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]),
|
||||
srf(module, node),
|
||||
@ -510,7 +509,7 @@ class OurVisitor[G]:
|
||||
|
||||
func = module.functions[node.func.id]
|
||||
|
||||
result = FunctionCall(FunctionInstance(func, srf(module, node)), sourceref=srf(module, node))
|
||||
result = FunctionCall(func, sourceref=srf(module, node))
|
||||
result.arguments.extend(
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
|
||||
for arg_expr in node.args
|
||||
|
||||
@ -42,9 +42,8 @@ class ConstrainedExpr:
|
||||
|
||||
def instantiate_constrained(
|
||||
constrainedexpr: ConstrainedExpr,
|
||||
known_map: dict[TypeVariable, TypeVariable],
|
||||
make_variable: Callable[[KindExpr, str], TypeVariable],
|
||||
) -> ConstrainedExpr:
|
||||
) -> tuple[ConstrainedExpr, dict[TypeVariable, TypeVariable]]:
|
||||
"""
|
||||
Instantiates a type expression and its constraints
|
||||
"""
|
||||
@ -61,4 +60,4 @@ def instantiate_constrained(
|
||||
x.instantiate(known_map)
|
||||
for x in constrainedexpr.constraints
|
||||
)
|
||||
return ConstrainedExpr(constrainedexpr.variables, expr, constraints)
|
||||
return ConstrainedExpr(constrainedexpr.variables, expr, constraints), known_map
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Callable, Iterable, Protocol, Sequence
|
||||
from typing import Any, Callable, Iterable, Protocol, Sequence, TypeAlias
|
||||
|
||||
from ..build.base import BuildBase
|
||||
from ..ourlang import SourceRef
|
||||
@ -9,13 +9,16 @@ from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt6
|
||||
from .kindexpr import KindExpr, Star
|
||||
from .record import Record
|
||||
from .typeexpr import (
|
||||
AtomicType,
|
||||
TypeApplication,
|
||||
TypeConstructor,
|
||||
TypeExpr,
|
||||
TypeLevelNat,
|
||||
TypeVariable,
|
||||
is_concrete,
|
||||
occurs,
|
||||
replace_variable,
|
||||
)
|
||||
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
|
||||
|
||||
|
||||
class ExpressionProtocol(Protocol):
|
||||
@ -28,50 +31,95 @@ class ExpressionProtocol(Protocol):
|
||||
The type to update
|
||||
"""
|
||||
|
||||
PolytypeSubsituteMap: TypeAlias = dict[TypeVariable, TypeExpr]
|
||||
|
||||
class Context:
|
||||
__slots__ = ("build", "placeholder_update", )
|
||||
__slots__ = ("build", "placeholder_update", "ptst_update", )
|
||||
|
||||
build: BuildBase[Any]
|
||||
placeholder_update: dict[TypeVariable, ExpressionProtocol | None]
|
||||
ptst_update: dict[TypeVariable, tuple[PolytypeSubsituteMap, TypeVariable]]
|
||||
|
||||
def __init__(self, build: BuildBase[Any]) -> None:
|
||||
self.build = build
|
||||
self.placeholder_update = {}
|
||||
self.ptst_update = {}
|
||||
|
||||
def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star(), prefix: str = 'p') -> TypeVariable:
|
||||
res = TypeVariable(kind, f"{prefix}_{len(self.placeholder_update)}")
|
||||
self.placeholder_update[res] = arg
|
||||
return res
|
||||
|
||||
def register_polytype_subsitutes(self, tvar: TypeVariable, arg: PolytypeSubsituteMap, orig_var: TypeVariable) -> None:
|
||||
"""
|
||||
When `tvar` gets subsituted, also set the result in arg with orig_var as key
|
||||
|
||||
e.g.
|
||||
|
||||
(-) :: Callable[a, a, a]
|
||||
|
||||
def foo() -> u32:
|
||||
return 2 - 1
|
||||
|
||||
During typing, we instantiate a into a_3, and get the following constraints:
|
||||
- u8 ~ p_1
|
||||
- u8 ~ p_2
|
||||
- Exists NatNum a_3
|
||||
- Callable[a_3, a_3, a_3] ~ Callable[p_1, p_2, p_0]
|
||||
- u8 ~ p_0
|
||||
|
||||
When we resolve a_3, then on the call to `-`, we should note that a_3 got resolved
|
||||
to u32. But we need to use `a` as key, since that's what's used on the definition
|
||||
"""
|
||||
assert tvar in self.placeholder_update
|
||||
assert tvar not in self.ptst_update
|
||||
self.ptst_update[tvar] = (arg, orig_var)
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Failure:
|
||||
"""
|
||||
Both types are already different - cannot be unified.
|
||||
"""
|
||||
msg: str
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ReplaceVariable:
|
||||
var: TypeVariable
|
||||
typ: TypeExpr
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CheckResult:
|
||||
# TODO: Refactor this, don't think we use most of the variants
|
||||
_: dataclasses.KW_ONLY
|
||||
done: bool = True
|
||||
actions: ActionList = dataclasses.field(default_factory=ActionList)
|
||||
replace: ReplaceVariable | None = None
|
||||
new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
|
||||
failures: list[Failure] = dataclasses.field(default_factory=list)
|
||||
|
||||
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
||||
if not self.done and not self.actions and not self.new_constraints and not self.failures:
|
||||
if not self.done and not self.replace and not self.new_constraints and not self.failures:
|
||||
return '(skip for now)'
|
||||
|
||||
if self.done and not self.actions and not self.new_constraints and not self.failures:
|
||||
if self.done and not self.replace and not self.new_constraints and not self.failures:
|
||||
return '(ok)'
|
||||
|
||||
if self.done and self.actions and not self.new_constraints and not self.failures:
|
||||
return self.actions.to_str(type_namer)
|
||||
if self.done and self.replace and not self.new_constraints and not self.failures:
|
||||
return f'{{{self.replace.var.name} := {type_namer(self.replace.typ)}}}'
|
||||
|
||||
if self.done and not self.actions and self.new_constraints and not self.failures:
|
||||
if self.done and not self.replace and self.new_constraints and not self.failures:
|
||||
return f'(got {len(self.new_constraints)} new constraints)'
|
||||
|
||||
if self.done and not self.actions and not self.new_constraints and self.failures:
|
||||
if self.done and not self.replace and not self.new_constraints and self.failures:
|
||||
return 'ERR: ' + '; '.join(x.msg for x in self.failures)
|
||||
|
||||
return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}'
|
||||
return f'{self.done} {self.replace} {self.new_constraints} {self.failures}'
|
||||
|
||||
def skip_for_now() -> CheckResult:
|
||||
return CheckResult(done=False)
|
||||
|
||||
def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult:
|
||||
return CheckResult(replace=ReplaceVariable(var, typ))
|
||||
|
||||
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
|
||||
return CheckResult(new_constraints=list(lst))
|
||||
|
||||
@ -94,12 +142,8 @@ class ConstraintBase:
|
||||
def check(self) -> CheckResult:
|
||||
raise NotImplementedError(self)
|
||||
|
||||
def apply(self, action: Action) -> None:
|
||||
if isinstance(action, ReplaceVariable):
|
||||
self.replace_variable(action.var, action.typ)
|
||||
return
|
||||
|
||||
raise NotImplementedError(action)
|
||||
def complexity(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
pass
|
||||
@ -142,6 +186,9 @@ class FromLiteralInteger(ConstraintBase):
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
self.type5 = replace_variable(self.type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
|
||||
|
||||
@ -175,8 +222,11 @@ class FromLiteralFloat(ConstraintBase):
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
self.type5 = replace_variable(self.type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
|
||||
return f'FromLiteralFloat {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
|
||||
|
||||
class FromLiteralBytes(ConstraintBase):
|
||||
__slots__ = ('type5', 'literal', )
|
||||
@ -203,32 +253,125 @@ class FromLiteralBytes(ConstraintBase):
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
self.type5 = replace_variable(self.type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'FromLiteralBytes {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
|
||||
|
||||
class UnifyTypesConstraint(ConstraintBase):
|
||||
__slots__ = ("lft", "rgt",)
|
||||
__slots__ = ("lft", "rgt", "prefix", )
|
||||
|
||||
def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr) -> None:
|
||||
def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr, prefix: str | None = None) -> None:
|
||||
super().__init__(ctx, sourceref)
|
||||
|
||||
self.lft = lft
|
||||
self.rgt = rgt
|
||||
self.prefix = prefix
|
||||
|
||||
def check(self) -> CheckResult:
|
||||
result = unify(self.lft, self.rgt)
|
||||
lft = self.lft
|
||||
rgt = self.rgt
|
||||
|
||||
if isinstance(result, Failure):
|
||||
return CheckResult(failures=[result])
|
||||
if lft == self.rgt:
|
||||
return ok()
|
||||
|
||||
return CheckResult(actions=result)
|
||||
if lft.kind != rgt.kind:
|
||||
return fail("Kind mismatch")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
|
||||
return fail("Not the same type")
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
|
||||
return replace(rgt, lft)
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
|
||||
raise NotImplementedError # Should have been caught by kind check above
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
|
||||
return fail("Not the same type" if is_concrete(rgt) else "Type shape mismatch")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
|
||||
return replace(lft, rgt)
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
|
||||
return replace(lft, rgt)
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
|
||||
return replace(lft, rgt)
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
|
||||
if occurs(lft, rgt):
|
||||
return fail("One type occurs in the other")
|
||||
|
||||
return replace(lft, rgt)
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
|
||||
raise NotImplementedError # Should have been caught by kind check above
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
|
||||
return replace(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
|
||||
return fail("Not the same type constructor")
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
|
||||
return fail("Not the same type constructor")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
|
||||
return fail("Not the same type" if is_concrete(lft) else "Type shape mismatch")
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
|
||||
if occurs(rgt, lft):
|
||||
return fail("One type occurs in the other")
|
||||
|
||||
return replace(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
|
||||
return fail("Not the same type constructor")
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
|
||||
|
||||
## USABILITY HACK
|
||||
## Often, we have two type applications in the same go
|
||||
## If so, resolve it in a single step
|
||||
## (Helps with debugging function unification)
|
||||
## This *should* not affect the actual type unification
|
||||
## It's just one less call to UnifyTypesConstraint.check
|
||||
if isinstance(lft.constructor, TypeApplication) and isinstance(rgt.constructor, TypeApplication):
|
||||
return new_constraints([
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.constructor, rgt.constructor.constructor),
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.argument, rgt.constructor.argument),
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument),
|
||||
])
|
||||
|
||||
|
||||
return new_constraints([
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor, rgt.constructor),
|
||||
UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument),
|
||||
])
|
||||
|
||||
|
||||
raise NotImplementedError(lft, rgt)
|
||||
|
||||
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
||||
self.lft = replace_variable(self.lft, var, typ)
|
||||
self.rgt = replace_variable(self.rgt, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return complexity(self.lft) + complexity(self.rgt)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}"
|
||||
prefix = f'{self.prefix} :: ' if self.prefix else ''
|
||||
return f"{prefix}{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}"
|
||||
|
||||
class CanBeSubscriptedConstraint(ConstraintBase):
|
||||
__slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )
|
||||
@ -290,6 +433,9 @@ class CanBeSubscriptedConstraint(ConstraintBase):
|
||||
self.container_type5 = replace_variable(self.container_type5, var, typ)
|
||||
self.index_type5 = replace_variable(self.index_type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.ret_type5) + complexity(self.container_type5) + complexity(self.index_type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"[] :: t a -> b -> a ~ {self.ctx.build.type5_name(self.container_type5)} -> {self.ctx.build.type5_name(self.index_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"
|
||||
|
||||
@ -333,6 +479,9 @@ class CanAccessStructMemberConstraint(ConstraintBase):
|
||||
self.ret_type5 = replace_variable(self.ret_type5, var, typ)
|
||||
self.struct_type5 = replace_variable(self.struct_type5, var, typ)
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.ret_type5) + complexity(self.struct_type5)
|
||||
|
||||
def __str__(self) -> str:
|
||||
st_args = self.ctx.build.type5_is_struct(self.struct_type5)
|
||||
member_dict = dict(st_args or [])
|
||||
@ -404,6 +553,9 @@ class FromTupleConstraint(ConstraintBase):
|
||||
for x in self.member_type5_list
|
||||
]
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + complexity(self.ret_type5) + sum(complexity(x) for x in self.member_type5_list)
|
||||
|
||||
def __str__(self) -> str:
|
||||
args = ', '.join(self.ctx.build.type5_name(x) for x in self.member_type5_list)
|
||||
return f'FromTuple {self.ctx.build.type5_name(self.ret_type5)} ~ ({args}, )'
|
||||
@ -450,6 +602,24 @@ class TypeClassInstanceExistsConstraint(ConstraintBase):
|
||||
for x in self.arg_list
|
||||
]
|
||||
|
||||
def complexity(self) -> int:
|
||||
return 100 + sum(complexity(x) for x in self.arg_list)
|
||||
|
||||
def __str__(self) -> str:
|
||||
args = ' '.join(self.ctx.build.type5_name(x) for x in self.arg_list)
|
||||
return f'Exists {self.typeclass} {args}'
|
||||
|
||||
def complexity(expr: TypeExpr) -> int:
|
||||
if isinstance(expr, AtomicType | TypeLevelNat):
|
||||
return 1
|
||||
|
||||
if isinstance(expr, TypeConstructor):
|
||||
return 2
|
||||
|
||||
if isinstance(expr, TypeVariable):
|
||||
return 5
|
||||
|
||||
if isinstance(expr, TypeApplication):
|
||||
return complexity(expr.constructor) + complexity(expr.argument)
|
||||
|
||||
raise NotImplementedError(expr)
|
||||
|
||||
@ -16,7 +16,7 @@ from .constraints import (
|
||||
UnifyTypesConstraint,
|
||||
)
|
||||
from .kindexpr import KindExpr, Star
|
||||
from .typeexpr import TypeApplication, TypeVariable, instantiate
|
||||
from .typeexpr import TypeApplication, TypeExpr, TypeVariable, is_concrete
|
||||
|
||||
ConstraintGenerator = Generator[ConstraintBase, None, None]
|
||||
|
||||
@ -90,14 +90,41 @@ def expression_constant(ctx: Context, inp: ourlang.Constant, phft: TypeVariable)
|
||||
raise NotImplementedError(inp)
|
||||
|
||||
def expression_variable_reference(ctx: Context, inp: ourlang.VariableReference, phft: TypeVariable) -> ConstraintGenerator:
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft)
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft, prefix=inp.variable.name)
|
||||
|
||||
def expression_binary_operator(ctx: Context, inp: ourlang.BinaryOp, phft: TypeVariable) -> ConstraintGenerator:
|
||||
yield from expression_function_call(ctx, _binary_op_to_function(inp), phft)
|
||||
yield from _expression_binary_operator_or_function_call(
|
||||
ctx,
|
||||
inp.operator,
|
||||
inp.polytype_substitutions,
|
||||
[inp.left, inp.right],
|
||||
inp.sourceref,
|
||||
f'({inp.operator.name})',
|
||||
phft,
|
||||
)
|
||||
|
||||
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: TypeVariable) -> ConstraintGenerator:
|
||||
yield from _expression_binary_operator_or_function_call(
|
||||
ctx,
|
||||
inp.function,
|
||||
inp.polytype_substitutions,
|
||||
inp.arguments,
|
||||
inp.sourceref,
|
||||
inp.function.name,
|
||||
phft,
|
||||
)
|
||||
|
||||
def _expression_binary_operator_or_function_call(
|
||||
ctx: Context,
|
||||
function: ourlang.Function | ourlang.FunctionParam,
|
||||
polytype_substitutions: dict[TypeVariable, TypeExpr],
|
||||
arguments: list[ourlang.Expression],
|
||||
sourceref: ourlang.SourceRef,
|
||||
function_name: str,
|
||||
phft: TypeVariable,
|
||||
) -> ConstraintGenerator:
|
||||
arg_typ_list = []
|
||||
for arg in inp.arguments:
|
||||
for arg in arguments:
|
||||
arg_tv = ctx.make_placeholder(arg)
|
||||
yield from expression(ctx, arg, arg_tv)
|
||||
arg_typ_list.append(arg_tv)
|
||||
@ -105,34 +132,28 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Type
|
||||
def make_placeholder(x: KindExpr, p: str) -> TypeVariable:
|
||||
return ctx.make_placeholder(kind=x, prefix=p)
|
||||
|
||||
ftp5 = inp.function_instance.function.type5
|
||||
ftp5 = function.type5
|
||||
assert ftp5 is not None
|
||||
if isinstance(ftp5, ConstrainedExpr):
|
||||
ftp5 = instantiate_constrained(ftp5, {}, make_placeholder)
|
||||
ftp5, phft_lookup = instantiate_constrained(ftp5, make_placeholder)
|
||||
|
||||
for orig_tvar, tvar in phft_lookup.items():
|
||||
ctx.register_polytype_subsitutes(tvar, polytype_substitutions, orig_tvar)
|
||||
|
||||
for type_constraint in ftp5.constraints:
|
||||
if isinstance(type_constraint, TypeClassConstraint):
|
||||
yield TypeClassInstanceExistsConstraint(ctx, inp.sourceref, type_constraint.cls.name, type_constraint.variables)
|
||||
yield TypeClassInstanceExistsConstraint(ctx, sourceref, type_constraint.cls.name, type_constraint.variables)
|
||||
continue
|
||||
|
||||
raise NotImplementedError(type_constraint)
|
||||
|
||||
ftp5 = ftp5.expr
|
||||
else:
|
||||
ftp5 = instantiate(ftp5, {})
|
||||
|
||||
# We need an extra placeholder so that the inp.function_instance gets updated
|
||||
phft2 = ctx.make_placeholder(inp.function_instance)
|
||||
yield UnifyTypesConstraint(
|
||||
ctx,
|
||||
inp.sourceref,
|
||||
ftp5,
|
||||
phft2,
|
||||
)
|
||||
assert is_concrete(ftp5)
|
||||
|
||||
expr_type = ctx.build.type5_make_function(arg_typ_list + [phft])
|
||||
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, phft2, expr_type)
|
||||
yield UnifyTypesConstraint(ctx, sourceref, ftp5, expr_type, prefix=function_name)
|
||||
|
||||
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: TypeVariable) -> ConstraintGenerator:
|
||||
assert inp.function.type5 is not None # Todo: Make not nullable
|
||||
@ -141,7 +162,7 @@ def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference,
|
||||
if isinstance(ftp5, ConstrainedExpr):
|
||||
ftp5 = ftp5.expr
|
||||
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft)
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft, prefix=inp.function.name)
|
||||
|
||||
def expression_tuple_instantiation(ctx: Context, inp: ourlang.TupleInstantiation, phft: TypeVariable) -> ConstraintGenerator:
|
||||
arg_typ_list = []
|
||||
@ -221,7 +242,7 @@ def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement
|
||||
type5 = io_arg
|
||||
|
||||
yield from expression(ctx, inp.value, phft)
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft)
|
||||
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft, prefix=f'{fun.name} returns')
|
||||
|
||||
def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf) -> ConstraintGenerator:
|
||||
test_phft = ctx.make_placeholder(inp.test)
|
||||
@ -295,14 +316,3 @@ def module(ctx: Context, inp: ourlang.Module[Any]) -> ConstraintGenerator:
|
||||
yield from function(ctx, func)
|
||||
|
||||
# TODO: Generalize?
|
||||
|
||||
def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall:
|
||||
"""
|
||||
For typing purposes, a binary operator is just a function call.
|
||||
|
||||
It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)`
|
||||
"""
|
||||
assert inp.sourceref is not None # TODO: sourceref required
|
||||
call = ourlang.FunctionCall(inp.operator, inp.sourceref)
|
||||
call.arguments = [inp.left, inp.right]
|
||||
return call
|
||||
|
||||
@ -3,8 +3,7 @@ from typing import Any
|
||||
from ..ourlang import Module
|
||||
from .constraints import ConstraintBase, Context
|
||||
from .fromast import phasm_type5_generate_constraints
|
||||
from .typeexpr import TypeExpr, TypeVariable, replace_variable
|
||||
from .unify import ReplaceVariable
|
||||
from .typeexpr import TypeExpr, TypeVariable, is_concrete, replace_variable
|
||||
|
||||
MAX_RESTACK_COUNT = 100
|
||||
|
||||
@ -31,8 +30,7 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
|
||||
print("Validating")
|
||||
|
||||
new_constraint_list: list[ConstraintBase] = []
|
||||
while constraint_list:
|
||||
constraint = constraint_list.pop(0)
|
||||
for constraint in sorted(constraint_list, key=lambda x: x.complexity()):
|
||||
result = constraint.check()
|
||||
|
||||
if verbose:
|
||||
@ -44,29 +42,15 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
|
||||
# Means it checks out and we don't need do anything
|
||||
continue
|
||||
|
||||
while result.actions:
|
||||
action = result.actions.pop(0)
|
||||
if result.replace is not None:
|
||||
action_var = result.replace.var
|
||||
assert action_var not in placeholder_types # When does this happen?
|
||||
|
||||
if isinstance(action, ReplaceVariable):
|
||||
action_var: TypeExpr = action.var
|
||||
while isinstance(action_var, TypeVariable) and action_var in placeholder_types:
|
||||
# TODO: Does this still happen?
|
||||
action_var = placeholder_types[action_var]
|
||||
action_typ: TypeExpr = result.replace.typ
|
||||
assert not isinstance(action_typ, TypeVariable) or action_typ not in placeholder_types # When does this happen?
|
||||
|
||||
action_typ: TypeExpr = action.typ
|
||||
while isinstance(action_typ, TypeVariable) and action_typ in placeholder_types:
|
||||
# TODO: Does this still happen?
|
||||
action_typ = placeholder_types[action_typ]
|
||||
assert action_var != action_typ # When does this happen?
|
||||
|
||||
# print(inp.build.type5_name(action_var), ':=', inp.build.type5_name(action_typ))
|
||||
|
||||
if action_var == action_typ:
|
||||
continue
|
||||
|
||||
if not isinstance(action_var, TypeVariable) and isinstance(action_typ, TypeVariable):
|
||||
action_typ, action_var = action_var, action_typ
|
||||
|
||||
if isinstance(action_var, TypeVariable):
|
||||
# Ensure all existing found types are updated
|
||||
placeholder_types = {
|
||||
k: replace_variable(v, action_var, action_typ)
|
||||
@ -85,20 +69,14 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
|
||||
if verbose and old_str != new_str:
|
||||
print(f"{oth_const.sourceref!s} => - {old_str!s}")
|
||||
print(f"{oth_const.sourceref!s} => + {new_str!s}")
|
||||
continue
|
||||
|
||||
error_list.append((str(constraint.sourceref), str(constraint), "Not the same type", ))
|
||||
if verbose:
|
||||
print(f"{constraint.sourceref!s} => ERR: Conflict in applying {action.to_str(inp.build.type5_name)}")
|
||||
continue
|
||||
|
||||
# Action of unsupported type
|
||||
raise NotImplementedError(action)
|
||||
|
||||
for failure in result.failures:
|
||||
error_list.append((str(constraint.sourceref), str(constraint), failure.msg, ))
|
||||
|
||||
new_constraint_list.extend(result.new_constraints)
|
||||
if verbose:
|
||||
for new_const in result.new_constraints:
|
||||
print(f"{oth_const.sourceref!s} => + {new_const!s}")
|
||||
|
||||
if result.done:
|
||||
continue
|
||||
@ -124,8 +102,11 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
|
||||
if expression is None:
|
||||
continue
|
||||
|
||||
new_type5 = placeholder_types[placeholder]
|
||||
while isinstance(new_type5, TypeVariable):
|
||||
new_type5 = placeholder_types[new_type5]
|
||||
resolved_type5 = placeholder_types[placeholder]
|
||||
assert is_concrete(resolved_type5) # When does this happen?
|
||||
expression.type5 = resolved_type5
|
||||
|
||||
expression.type5 = new_type5
|
||||
for placeholder, (ptst_map, orig_tvar) in ctx.ptst_update.items():
|
||||
resolved_type5 = placeholder_types[placeholder]
|
||||
assert is_concrete(resolved_type5) # When does this happen?
|
||||
ptst_map[orig_tvar] = resolved_type5
|
||||
|
||||
@ -4,6 +4,7 @@ from .typeexpr import (
|
||||
TypeApplication,
|
||||
TypeConstructor,
|
||||
TypeExpr,
|
||||
TypeLevelNat,
|
||||
TypeVariable,
|
||||
)
|
||||
|
||||
@ -21,6 +22,9 @@ class TypeRouter[T]:
|
||||
def when_record(self, typ: Record) -> T:
|
||||
raise NotImplementedError(typ)
|
||||
|
||||
def when_type_level_nat(self, typ: TypeLevelNat) -> T:
|
||||
raise NotImplementedError(typ)
|
||||
|
||||
def when_variable(self, typ: TypeVariable) -> T:
|
||||
raise NotImplementedError(typ)
|
||||
|
||||
@ -37,6 +41,9 @@ class TypeRouter[T]:
|
||||
if isinstance(typ, TypeConstructor):
|
||||
return self.when_constructor(typ)
|
||||
|
||||
if isinstance(typ, TypeLevelNat):
|
||||
return self.when_type_level_nat(typ)
|
||||
|
||||
if isinstance(typ, TypeVariable):
|
||||
return self.when_variable(typ)
|
||||
|
||||
|
||||
@ -1,128 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable
|
||||
|
||||
from .typeexpr import (
|
||||
AtomicType,
|
||||
TypeApplication,
|
||||
TypeConstructor,
|
||||
TypeExpr,
|
||||
TypeVariable,
|
||||
is_concrete,
|
||||
occurs,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Failure:
|
||||
"""
|
||||
Both types are already different - cannot be unified.
|
||||
"""
|
||||
msg: str
|
||||
|
||||
@dataclass
|
||||
class Action:
|
||||
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
class ActionList(list[Action]):
|
||||
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
||||
return '{' + ', '.join((x.to_str(type_namer) for x in self)) + '}'
|
||||
|
||||
UnifyResult = Failure | ActionList
|
||||
|
||||
@dataclass
|
||||
class ReplaceVariable(Action):
|
||||
var: TypeVariable
|
||||
typ: TypeExpr
|
||||
|
||||
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
||||
return f'{self.var.name} := {type_namer(self.typ)}'
|
||||
|
||||
def unify(lft: TypeExpr, rgt: TypeExpr) -> UnifyResult:
|
||||
"""
|
||||
Be warned: This only matches type variables with other variables or types
|
||||
- it does not apply substituions nor does it validate if the matching
|
||||
pairs are correct.
|
||||
|
||||
TODO: Remove this. It should be part of UnifyTypesConstraint
|
||||
and should just generate new constraints for applications.
|
||||
"""
|
||||
if lft == rgt:
|
||||
return ActionList()
|
||||
|
||||
if lft.kind != rgt.kind:
|
||||
return Failure("Kind mismatch")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
|
||||
return Failure("Not the same type")
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
|
||||
return ActionList([ReplaceVariable(rgt, lft)])
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
|
||||
raise NotImplementedError # Should have been caught by kind check above
|
||||
|
||||
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
|
||||
if is_concrete(rgt):
|
||||
return Failure("Not the same type")
|
||||
|
||||
return Failure("Type shape mismatch")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
|
||||
return ActionList([ReplaceVariable(lft, rgt)])
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
|
||||
return ActionList([ReplaceVariable(lft, rgt)])
|
||||
|
||||
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
|
||||
if occurs(lft, rgt):
|
||||
return Failure("One type occurs in the other")
|
||||
|
||||
return ActionList([ReplaceVariable(lft, rgt)])
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
|
||||
return Failure("Not the same type constructor")
|
||||
|
||||
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
|
||||
return Failure("Not the same type constructor")
|
||||
|
||||
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
|
||||
return unify(rgt, lft)
|
||||
|
||||
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
|
||||
con_res = unify(lft.constructor, rgt.constructor)
|
||||
if isinstance(con_res, Failure):
|
||||
return con_res
|
||||
|
||||
arg_res = unify(lft.argument, rgt.argument)
|
||||
if isinstance(arg_res, Failure):
|
||||
return arg_res
|
||||
|
||||
return ActionList(con_res + arg_res)
|
||||
|
||||
|
||||
|
||||
return Failure('Not implemented')
|
||||
@ -305,8 +305,5 @@ def testEntry() -> i32:
|
||||
```
|
||||
|
||||
```py
|
||||
if TYPE_NAME.startswith('tuple_') or TYPE_NAME.startswith('static_array_') or TYPE_NAME.startswith('dynamic_array_'):
|
||||
expect_type_error('Not the same type constructor')
|
||||
else:
|
||||
expect_type_error('Not the same type')
|
||||
```
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from phasm.type5.solver import Type5SolverException
|
||||
|
||||
from ..helpers import Suite
|
||||
|
||||
|
||||
@ -38,24 +40,22 @@ def test_call_post_defined():
|
||||
code_py = """
|
||||
@exported
|
||||
def testEntry() -> i32:
|
||||
return helper(10, 3)
|
||||
return helper(13)
|
||||
|
||||
def helper(left: i32, right: i32) -> i32:
|
||||
return left - right
|
||||
def helper(left: i32) -> i32:
|
||||
return left
|
||||
"""
|
||||
|
||||
result = Suite(code_py).run_code()
|
||||
|
||||
assert 7 == result.returned_value
|
||||
assert 13 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@pytest.mark.skip('FIXME: Type checking')
|
||||
def test_call_invalid_type():
|
||||
code_py = """
|
||||
def helper(left: i32) -> i32:
|
||||
return left()
|
||||
"""
|
||||
|
||||
result = Suite(code_py).run_code()
|
||||
|
||||
assert 7 == result.returned_value
|
||||
with pytest.raises(Type5SolverException, match=r'i32 ~ Callable\[i32\]'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@ -91,8 +91,7 @@ def testEntry() -> i32:
|
||||
return action(double, 13.0)
|
||||
"""
|
||||
|
||||
match = r'Callable\[i32, i32\] ~ Callable\[f32, [^]]+\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@ -109,8 +108,7 @@ def testEntry() -> i32:
|
||||
return action(double, 13)
|
||||
"""
|
||||
|
||||
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[f32, i32\], p_[0-9]+, [^]]+\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@ -127,14 +125,14 @@ def testEntry() -> f32:
|
||||
return action(double, 13)
|
||||
"""
|
||||
|
||||
with pytest.raises(Type5SolverException, match='f32 ~ i32'):
|
||||
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_function_with_wrong_return_type_pass():
|
||||
code_py = """
|
||||
def double(left: i32) -> f32:
|
||||
return convert(left) * 2
|
||||
return convert(left) * 2.0
|
||||
|
||||
def action(applicable: Callable[i32, i32], left: i32) -> i32:
|
||||
return applicable(left)
|
||||
@ -144,8 +142,7 @@ def testEntry() -> i32:
|
||||
return action(double, 13)
|
||||
"""
|
||||
|
||||
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[i32, f32\], p_[0-9]+, [^]]+\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@ -179,12 +176,12 @@ def testEntry() -> i32:
|
||||
return action(double, 13, 14)
|
||||
"""
|
||||
|
||||
match = r'Callable\[Callable\[i32, i32, i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]'
|
||||
match = r'Callable\[i32, i32\] ~ i32'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_too_many_args_use():
|
||||
def test_sof_too_many_args_use_0():
|
||||
code_py = """
|
||||
def thirteen() -> i32:
|
||||
return 13
|
||||
@ -197,12 +194,30 @@ def testEntry() -> i32:
|
||||
return action(thirteen, 13)
|
||||
"""
|
||||
|
||||
match = r'Callable\[i32\] ~ Callable\[i32, p_[0-9]+\]'
|
||||
match = r'\(\) ~ i32'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code(verbose=True)
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_too_many_args_pass():
|
||||
def test_sof_too_many_args_use_1():
|
||||
code_py = """
|
||||
def thirteen(x: i32) -> i32:
|
||||
return x
|
||||
|
||||
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
|
||||
return applicable(left, right)
|
||||
|
||||
@exported
|
||||
def testEntry() -> i32:
|
||||
return action(thirteen, 13, 26)
|
||||
"""
|
||||
|
||||
match = r'i32 ~ Callable\[i32, i32\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code(verbose=True)
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_too_many_args_pass_0():
|
||||
code_py = """
|
||||
def double(left: i32) -> i32:
|
||||
return left * 2
|
||||
@ -215,6 +230,24 @@ def testEntry() -> i32:
|
||||
return action(double, 13, 14)
|
||||
"""
|
||||
|
||||
match = r'Callable\[Callable\[i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]'
|
||||
match = r'\(\) ~ i32'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_too_many_args_pass_1():
|
||||
code_py = """
|
||||
def double(left: i32, right: i32) -> i32:
|
||||
return left * right
|
||||
|
||||
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
|
||||
return applicable(left)
|
||||
|
||||
@exported
|
||||
def testEntry() -> i32:
|
||||
return action(double, 13, 14)
|
||||
"""
|
||||
|
||||
match = r'i32 ~ Callable\[i32, i32\]'
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@ -168,12 +168,9 @@ def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32:
|
||||
return foldl(x, y, z)
|
||||
"""
|
||||
|
||||
match = {
|
||||
'i8': 'Type shape mismatch',
|
||||
'i8[3]': 'Kind mismatch',
|
||||
}
|
||||
match = 'Type shape mismatch'
|
||||
|
||||
with pytest.raises(Type5SolverException, match=match[in_typ]):
|
||||
with pytest.raises(Type5SolverException, match=match):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user