Compare commits

..

No commits in common. "71691d68e92991c8d9aa2be3652e9beba2734b2b" and "3d504e3d79d39ae62a89c89be5a1bd8304e2b099" have entirely different histories.

16 changed files with 346 additions and 448 deletions

View File

@ -38,6 +38,7 @@ def load(build: BuildBase[Any]) -> None:
build.register_type_class(Sized)
def wasm_dynamic_array_len(g: WasmGenerator, tv_map: dict[str, TypeExpr]) -> None:
print('tv_map', tv_map)
del tv_map
# The length is stored in the first 4 bytes
g.i32.load()

View File

@ -8,7 +8,6 @@ from ..type5.typeexpr import (
TypeApplication,
TypeConstructor,
TypeExpr,
TypeLevelNat,
TypeVariable,
)
from ..type5.typerouter import TypeRouter
@ -103,9 +102,6 @@ class TypeName(BuildTypeRouter[str]):
def when_tuple(self, tp_args: list[TypeExpr]) -> str:
return '(' + ', '.join(map(self, tp_args)) + ', )'
def when_type_level_nat(self, typ: TypeLevelNat) -> str:
return str(typ.value)
def when_variable(self, typ: TypeVariable) -> str:
return typ.name

View File

@ -67,7 +67,7 @@ def expression(inp: ourlang.Expression) -> str:
return str(inp.variable.name)
if isinstance(inp, ourlang.BinaryOp):
return f'{expression(inp.left)} {inp.operator.name} {expression(inp.right)}'
return f'{expression(inp.left)} {inp.operator.function.name} {expression(inp.right)}'
if isinstance(inp, ourlang.FunctionCall):
args = ', '.join(
@ -75,10 +75,10 @@ def expression(inp: ourlang.Expression) -> str:
for arg in inp.arguments
)
if isinstance(inp.function, ourlang.StructConstructor):
return f'{inp.function.struct_type5.name}({args})'
if isinstance(inp.function_instance.function, ourlang.StructConstructor):
return f'{inp.function_instance.function.struct_type5.name}({args})'
return f'{inp.function.name}({args})'
return f'{inp.function_instance.function.name}({args})'
if isinstance(inp, ourlang.FunctionReference):
return str(inp.function.name)

View File

@ -11,14 +11,7 @@ from .build.typerouter import BuildTypeRouter
from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types
from .type5.constrainedexpr import ConstrainedExpr
from .type5.typeexpr import (
AtomicType,
TypeApplication,
TypeExpr,
TypeVariable,
is_concrete,
replace_variable,
)
from .type5.typeexpr import AtomicType, TypeApplication, TypeExpr, is_concrete
from .wasm import (
WasmTypeFloat32,
WasmTypeFloat64,
@ -155,94 +148,32 @@ def expression_subscript_tuple(wgn: WasmGenerator, mod: ourlang.Module[WasmGener
expression(wgn, mod, inp.varref)
wgn.add_statement(el_type_info.wasm_load_func, f'offset={offset}')
def expression_subscript_operator(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Subscript) -> None:
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
ftp5 = mod.build.type_classes['Subscriptable'].operators['[]']
fn_args = mod.build.type5_is_function(ftp5)
assert fn_args is not None
t_a = fn_args[0]
assert isinstance(t_a, TypeApplication)
t = t_a.constructor
a = t_a.argument
assert isinstance(t, TypeVariable)
assert isinstance(a, TypeVariable)
assert isinstance(inp.varref.type5, TypeApplication)
t_expr = inp.varref.type5.constructor
a_expr = inp.varref.type5.argument
_expression_binary_operator_or_function_call(
wgn,
mod,
ourlang.BuiltinFunction('[]', ftp5),
{
t: t_expr,
a: a_expr,
},
[inp.varref, inp.index],
inp.type5,
)
def expression_binary_op(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.BinaryOp) -> None:
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
_expression_binary_operator_or_function_call(
wgn,
mod,
inp.operator,
inp.polytype_substitutions,
[inp.left, inp.right],
inp.type5,
)
expression_function_call(wgn, mod, _binary_op_to_function(inp))
def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.FunctionCall) -> None:
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
_expression_binary_operator_or_function_call(
wgn,
mod,
inp.function,
inp.polytype_substitutions,
inp.arguments,
inp.type5,
)
def _expression_binary_operator_or_function_call(
wgn: WasmGenerator,
mod: ourlang.Module[WasmGenerator],
function: ourlang.Function | ourlang.FunctionParam,
polytype_substitutions: dict[TypeVariable, TypeExpr],
arguments: list[ourlang.Expression],
ret_type5: TypeExpr,
) -> None:
for arg in arguments:
for arg in inp.arguments:
expression(wgn, mod, arg)
if isinstance(function, ourlang.BuiltinFunction):
ftp5 = function.type5
if isinstance(ftp5, ConstrainedExpr):
cexpr = ftp5
ftp5 = ftp5.expr
for tvar in cexpr.variables:
ftp5 = replace_variable(ftp5, tvar, polytype_substitutions[tvar])
assert _is_concrete(ftp5), TYPE5_ASSERTION_ERROR
if isinstance(inp.function_instance.function, ourlang.BuiltinFunction):
assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR
try:
method_type, method_router = mod.build.methods[function.name]
method_type, method_router = mod.build.methods[inp.function_instance.function.name]
except KeyError:
method_type, method_router = mod.build.operators[function.name]
method_type, method_router = mod.build.operators[inp.function_instance.function.name]
impl_lookup = method_router.get((ftp5, ))
assert impl_lookup is not None, (function.name, ftp5, )
impl_lookup = method_router.get((inp.function_instance.type5, ))
assert impl_lookup is not None, (inp.function_instance.function.name, inp.function_instance.type5, )
kwargs, impl = impl_lookup
impl(wgn, kwargs)
return
if isinstance(function, ourlang.FunctionParam):
fn_args = mod.build.type5_is_function(function.type5)
assert fn_args is not None, function.type5
if isinstance(inp.function_instance.function, ourlang.FunctionParam):
assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR
fn_args = mod.build.type5_is_function(inp.function_instance.type5)
assert fn_args is not None
params = [
type5(mod, x)
@ -251,15 +182,11 @@ def _expression_binary_operator_or_function_call(
result = params.pop()
wgn.add_statement('local.get', '${}'.format(function.name))
wgn.add_statement('local.get', '${}'.format(inp.function_instance.function.name))
wgn.call_indirect(params=params, result=result)
return
# TODO: Do similar subsitutions like we do for BuiltinFunction
# when we get user space polymorphic functions
# And then do similar lookup, and ensure we generate code for that variant
wgn.call(function.name)
wgn.call(inp.function_instance.function.name)
def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Expression) -> None:
"""
@ -351,7 +278,22 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourl
expression_subscript_tuple(wgn, mod, inp)
return
expression_subscript_operator(wgn, mod, inp)
inp_as_fc = ourlang.FunctionCall(
ourlang.FunctionInstance(
ourlang.BuiltinFunction('[]', mod.build.type_classes['Subscriptable'].operators['[]']),
inp.sourceref,
),
inp.sourceref,
)
inp_as_fc.arguments = [inp.varref, inp.index]
inp_as_fc.function_instance.type5 = mod.build.type5_make_function([
inp.varref.type5,
inp.index.type5,
inp.type5,
])
inp_as_fc.type5 = inp.type5
expression_function_call(wgn, mod, inp_as_fc)
return
if isinstance(inp, ourlang.AccessStructMember):
@ -379,11 +321,11 @@ def statement_return(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun
# Support tail calls
# https://github.com/WebAssembly/tail-call
# These help a lot with some functional programming techniques
if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function is fun:
if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function_instance.function is fun:
for arg in inp.value.arguments:
expression(wgn, mod, arg)
wgn.add_statement('return_call', '${}'.format(inp.value.function.name))
wgn.add_statement('return_call', '${}'.format(inp.value.function_instance.function.name))
return
expression(wgn, mod, inp.value)
@ -653,3 +595,14 @@ def _type5_struct_offset(
result += build.type5_alloc_size_member(memtyp)
raise RuntimeError('Member not found')
def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall:
"""
For compilation purposes, a binary operator is just a function call.
It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)`
"""
assert inp.sourceref is not None # TODO: sourceref required
call = ourlang.FunctionCall(inp.operator, inp.sourceref)
call.arguments = [inp.left, inp.right]
return call

View File

@ -157,39 +157,52 @@ class BinaryOp(Expression):
"""
A binary operator expression within a statement
"""
__slots__ = ('operator', 'polytype_substitutions', 'left', 'right', )
__slots__ = ('operator', 'left', 'right', )
operator: Function | FunctionParam
polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr]
operator: FunctionInstance
left: Expression
right: Expression
def __init__(self, operator: Function | FunctionParam, left: Expression, right: Expression, sourceref: SourceRef) -> None:
def __init__(self, operator: FunctionInstance, left: Expression, right: Expression, sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.operator = operator
self.polytype_substitutions = {}
self.left = left
self.right = right
def __repr__(self) -> str:
return f'BinaryOp({repr(self.operator)}, {repr(self.left)}, {repr(self.right)})'
class FunctionInstance(Expression):
"""
When calling a polymorphic function with concrete arguments, we can generate
code for that specific instance of the function.
"""
__slots__ = ('function', )
function: Union['Function', 'FunctionParam']
def __init__(self, function: Union['Function', 'FunctionParam'], sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.function = function
class FunctionCall(Expression):
"""
A function call expression within a statement
"""
__slots__ = ('function', 'polytype_substitutions', 'arguments', )
__slots__ = ('function_instance', 'arguments', )
function: Function | FunctionParam
polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr]
function_instance: FunctionInstance
# TODO: FunctionInstance is wrong - we should have
# substitutions: dict[TypeVariable, TypeExpr]
# And it should have the same variables as the polytype (ConstrainedExpr) for function
arguments: List[Expression]
def __init__(self, function: Function | FunctionParam, sourceref: SourceRef) -> None:
def __init__(self, function_instance: FunctionInstance, sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.function = function
self.polytype_substitutions = {}
self.function_instance = function_instance
self.arguments = []
class FunctionReference(Expression):

View File

@ -18,6 +18,7 @@ from .ourlang import (
Expression,
Function,
FunctionCall,
FunctionInstance,
FunctionParam,
FunctionReference,
Module,
@ -399,7 +400,7 @@ class OurVisitor[G]:
raise NotImplementedError(f'Operator {operator}')
return BinaryOp(
BuiltinFunction(operator, module.operators[operator]),
FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right),
srf(module, node),
@ -428,7 +429,7 @@ class OurVisitor[G]:
raise NotImplementedError(f'Operator {operator}')
return BinaryOp(
BuiltinFunction(operator, module.operators[operator]),
FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]),
srf(module, node),
@ -505,7 +506,7 @@ class OurVisitor[G]:
func = module.functions[node.func.id]
result = FunctionCall(func, sourceref=srf(module, node))
result = FunctionCall(FunctionInstance(func, srf(module, node)), sourceref=srf(module, node))
result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
for arg_expr in node.args

View File

@ -42,8 +42,9 @@ class ConstrainedExpr:
def instantiate_constrained(
constrainedexpr: ConstrainedExpr,
known_map: dict[TypeVariable, TypeVariable],
make_variable: Callable[[KindExpr, str], TypeVariable],
) -> tuple[ConstrainedExpr, dict[TypeVariable, TypeVariable]]:
) -> ConstrainedExpr:
"""
Instantiates a type expression and its constraints
"""
@ -60,4 +61,4 @@ def instantiate_constrained(
x.instantiate(known_map)
for x in constrainedexpr.constraints
)
return ConstrainedExpr(constrainedexpr.variables, expr, constraints), known_map
return ConstrainedExpr(constrainedexpr.variables, expr, constraints)

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import dataclasses
from typing import Any, Callable, Iterable, Protocol, Sequence, TypeAlias
from typing import Any, Callable, Iterable, Protocol, Sequence
from ..build.base import BuildBase
from ..ourlang import SourceRef
@ -9,16 +9,13 @@ from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt6
from .kindexpr import KindExpr, Star
from .record import Record
from .typeexpr import (
AtomicType,
TypeApplication,
TypeConstructor,
TypeExpr,
TypeLevelNat,
TypeVariable,
is_concrete,
occurs,
replace_variable,
)
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
class ExpressionProtocol(Protocol):
@ -31,95 +28,50 @@ class ExpressionProtocol(Protocol):
The type to update
"""
PolytypeSubsituteMap: TypeAlias = dict[TypeVariable, TypeExpr]
class Context:
__slots__ = ("build", "placeholder_update", "ptst_update", )
__slots__ = ("build", "placeholder_update", )
build: BuildBase[Any]
placeholder_update: dict[TypeVariable, ExpressionProtocol | None]
ptst_update: dict[TypeVariable, tuple[PolytypeSubsituteMap, TypeVariable]]
def __init__(self, build: BuildBase[Any]) -> None:
self.build = build
self.placeholder_update = {}
self.ptst_update = {}
def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star(), prefix: str = 'p') -> TypeVariable:
res = TypeVariable(kind, f"{prefix}_{len(self.placeholder_update)}")
self.placeholder_update[res] = arg
return res
def register_polytype_subsitutes(self, tvar: TypeVariable, arg: PolytypeSubsituteMap, orig_var: TypeVariable) -> None:
"""
When `tvar` gets subsituted, also set the result in arg with orig_var as key
e.g.
(-) :: Callable[a, a, a]
def foo() -> u32:
return 2 - 1
During typing, we instantiate a into a_3, and get the following constraints:
- u8 ~ p_1
- u8 ~ p_2
- Exists NatNum a_3
- Callable[a_3, a_3, a_3] ~ Callable[p_1, p_2, p_0]
- u8 ~ p_0
When we resolve a_3, then on the call to `-`, we should note that a_3 got resolved
to u32. But we need to use `a` as key, since that's what's used on the definition
"""
assert tvar in self.placeholder_update
assert tvar not in self.ptst_update
self.ptst_update[tvar] = (arg, orig_var)
@dataclasses.dataclass
class Failure:
"""
Both types are already different - cannot be unified.
"""
msg: str
@dataclasses.dataclass
class ReplaceVariable:
var: TypeVariable
typ: TypeExpr
@dataclasses.dataclass
class CheckResult:
# TODO: Refactor this, don't think we use most of the variants
_: dataclasses.KW_ONLY
done: bool = True
replace: ReplaceVariable | None = None
actions: ActionList = dataclasses.field(default_factory=ActionList)
new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
failures: list[Failure] = dataclasses.field(default_factory=list)
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
if not self.done and not self.replace and not self.new_constraints and not self.failures:
if not self.done and not self.actions and not self.new_constraints and not self.failures:
return '(skip for now)'
if self.done and not self.replace and not self.new_constraints and not self.failures:
if self.done and not self.actions and not self.new_constraints and not self.failures:
return '(ok)'
if self.done and self.replace and not self.new_constraints and not self.failures:
return f'{{{self.replace.var.name} := {type_namer(self.replace.typ)}}}'
if self.done and self.actions and not self.new_constraints and not self.failures:
return self.actions.to_str(type_namer)
if self.done and not self.replace and self.new_constraints and not self.failures:
if self.done and not self.actions and self.new_constraints and not self.failures:
return f'(got {len(self.new_constraints)} new constraints)'
if self.done and not self.replace and not self.new_constraints and self.failures:
if self.done and not self.actions and not self.new_constraints and self.failures:
return 'ERR: ' + '; '.join(x.msg for x in self.failures)
return f'{self.done} {self.replace} {self.new_constraints} {self.failures}'
return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}'
def skip_for_now() -> CheckResult:
return CheckResult(done=False)
def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult:
return CheckResult(replace=ReplaceVariable(var, typ))
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
return CheckResult(new_constraints=list(lst))
@ -142,8 +94,12 @@ class ConstraintBase:
def check(self) -> CheckResult:
raise NotImplementedError(self)
def complexity(self) -> int:
raise NotImplementedError
def apply(self, action: Action) -> None:
if isinstance(action, ReplaceVariable):
self.replace_variable(action.var, action.typ)
return
raise NotImplementedError(action)
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
pass
@ -186,9 +142,6 @@ class FromLiteralInteger(ConstraintBase):
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
self.type5 = replace_variable(self.type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.type5)
def __str__(self) -> str:
return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
@ -222,11 +175,8 @@ class FromLiteralFloat(ConstraintBase):
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
self.type5 = replace_variable(self.type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.type5)
def __str__(self) -> str:
return f'FromLiteralFloat {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
class FromLiteralBytes(ConstraintBase):
__slots__ = ('type5', 'literal', )
@ -253,125 +203,32 @@ class FromLiteralBytes(ConstraintBase):
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
self.type5 = replace_variable(self.type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.type5)
def __str__(self) -> str:
return f'FromLiteralBytes {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
class UnifyTypesConstraint(ConstraintBase):
__slots__ = ("lft", "rgt", "prefix", )
__slots__ = ("lft", "rgt",)
def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr, prefix: str | None = None) -> None:
def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr) -> None:
super().__init__(ctx, sourceref)
self.lft = lft
self.rgt = rgt
self.prefix = prefix
def check(self) -> CheckResult:
lft = self.lft
rgt = self.rgt
result = unify(self.lft, self.rgt)
if lft == self.rgt:
return ok()
if isinstance(result, Failure):
return CheckResult(failures=[result])
if lft.kind != rgt.kind:
return fail("Kind mismatch")
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
return fail("Not the same type")
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
return replace(rgt, lft)
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
raise NotImplementedError # Should have been caught by kind check above
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
return fail("Not the same type" if is_concrete(rgt) else "Type shape mismatch")
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
return replace(lft, rgt)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
return replace(lft, rgt)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
return replace(lft, rgt)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
if occurs(lft, rgt):
return fail("One type occurs in the other")
return replace(lft, rgt)
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
raise NotImplementedError # Should have been caught by kind check above
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
return replace(rgt, lft)
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
return fail("Not the same type constructor")
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
return fail("Not the same type constructor")
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
return fail("Not the same type" if is_concrete(lft) else "Type shape mismatch")
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
if occurs(rgt, lft):
return fail("One type occurs in the other")
return replace(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
return fail("Not the same type constructor")
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
## USABILITY HACK
## Often, we have two type applications in the same go
## If so, resolve it in a single step
## (Helps with debugging function unification)
## This *should* not affect the actual type unification
## It's just one less call to UnifyTypesConstraint.check
if isinstance(lft.constructor, TypeApplication) and isinstance(rgt.constructor, TypeApplication):
return new_constraints([
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.constructor, rgt.constructor.constructor),
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.argument, rgt.constructor.argument),
UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument),
])
return new_constraints([
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor, rgt.constructor),
UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument),
])
raise NotImplementedError(lft, rgt)
return CheckResult(actions=result)
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
self.lft = replace_variable(self.lft, var, typ)
self.rgt = replace_variable(self.rgt, var, typ)
def complexity(self) -> int:
return complexity(self.lft) + complexity(self.rgt)
def __str__(self) -> str:
prefix = f'{self.prefix} :: ' if self.prefix else ''
return f"{prefix}{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}"
return f"{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}"
class CanBeSubscriptedConstraint(ConstraintBase):
__slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )
@ -433,9 +290,6 @@ class CanBeSubscriptedConstraint(ConstraintBase):
self.container_type5 = replace_variable(self.container_type5, var, typ)
self.index_type5 = replace_variable(self.index_type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.ret_type5) + complexity(self.container_type5) + complexity(self.index_type5)
def __str__(self) -> str:
return f"[] :: t a -> b -> a ~ {self.ctx.build.type5_name(self.container_type5)} -> {self.ctx.build.type5_name(self.index_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"
@ -479,9 +333,6 @@ class CanAccessStructMemberConstraint(ConstraintBase):
self.ret_type5 = replace_variable(self.ret_type5, var, typ)
self.struct_type5 = replace_variable(self.struct_type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.ret_type5) + complexity(self.struct_type5)
def __str__(self) -> str:
st_args = self.ctx.build.type5_is_struct(self.struct_type5)
member_dict = dict(st_args or [])
@ -553,9 +404,6 @@ class FromTupleConstraint(ConstraintBase):
for x in self.member_type5_list
]
def complexity(self) -> int:
return 100 + complexity(self.ret_type5) + sum(complexity(x) for x in self.member_type5_list)
def __str__(self) -> str:
args = ', '.join(self.ctx.build.type5_name(x) for x in self.member_type5_list)
return f'FromTuple {self.ctx.build.type5_name(self.ret_type5)} ~ ({args}, )'
@ -602,24 +450,6 @@ class TypeClassInstanceExistsConstraint(ConstraintBase):
for x in self.arg_list
]
def complexity(self) -> int:
return 100 + sum(complexity(x) for x in self.arg_list)
def __str__(self) -> str:
args = ' '.join(self.ctx.build.type5_name(x) for x in self.arg_list)
return f'Exists {self.typeclass} {args}'
def complexity(expr: TypeExpr) -> int:
if isinstance(expr, AtomicType | TypeLevelNat):
return 1
if isinstance(expr, TypeConstructor):
return 2
if isinstance(expr, TypeVariable):
return 5
if isinstance(expr, TypeApplication):
return complexity(expr.constructor) + complexity(expr.argument)
raise NotImplementedError(expr)

View File

@ -16,7 +16,7 @@ from .constraints import (
UnifyTypesConstraint,
)
from .kindexpr import KindExpr
from .typeexpr import TypeApplication, TypeExpr, TypeVariable, is_concrete
from .typeexpr import TypeApplication, TypeVariable, instantiate
ConstraintGenerator = Generator[ConstraintBase, None, None]
@ -90,41 +90,14 @@ def expression_constant(ctx: Context, inp: ourlang.Constant, phft: TypeVariable)
raise NotImplementedError(inp)
def expression_variable_reference(ctx: Context, inp: ourlang.VariableReference, phft: TypeVariable) -> ConstraintGenerator:
yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft, prefix=inp.variable.name)
yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft)
def expression_binary_operator(ctx: Context, inp: ourlang.BinaryOp, phft: TypeVariable) -> ConstraintGenerator:
yield from _expression_binary_operator_or_function_call(
ctx,
inp.operator,
inp.polytype_substitutions,
[inp.left, inp.right],
inp.sourceref,
f'({inp.operator.name})',
phft,
)
yield from expression_function_call(ctx, _binary_op_to_function(inp), phft)
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: TypeVariable) -> ConstraintGenerator:
yield from _expression_binary_operator_or_function_call(
ctx,
inp.function,
inp.polytype_substitutions,
inp.arguments,
inp.sourceref,
inp.function.name,
phft,
)
def _expression_binary_operator_or_function_call(
ctx: Context,
function: ourlang.Function | ourlang.FunctionParam,
polytype_substitutions: dict[TypeVariable, TypeExpr],
arguments: list[ourlang.Expression],
sourceref: ourlang.SourceRef,
function_name: str,
phft: TypeVariable,
) -> ConstraintGenerator:
arg_typ_list = []
for arg in arguments:
for arg in inp.arguments:
arg_tv = ctx.make_placeholder(arg)
yield from expression(ctx, arg, arg_tv)
arg_typ_list.append(arg_tv)
@ -132,28 +105,34 @@ def _expression_binary_operator_or_function_call(
def make_placeholder(x: KindExpr, p: str) -> TypeVariable:
return ctx.make_placeholder(kind=x, prefix=p)
ftp5 = function.type5
ftp5 = inp.function_instance.function.type5
assert ftp5 is not None
if isinstance(ftp5, ConstrainedExpr):
ftp5, phft_lookup = instantiate_constrained(ftp5, make_placeholder)
for orig_tvar, tvar in phft_lookup.items():
ctx.register_polytype_subsitutes(tvar, polytype_substitutions, orig_tvar)
ftp5 = instantiate_constrained(ftp5, {}, make_placeholder)
for type_constraint in ftp5.constraints:
if isinstance(type_constraint, TypeClassConstraint):
yield TypeClassInstanceExistsConstraint(ctx, sourceref, type_constraint.cls.name, type_constraint.variables)
yield TypeClassInstanceExistsConstraint(ctx, inp.sourceref, type_constraint.cls.name, type_constraint.variables)
continue
raise NotImplementedError(type_constraint)
ftp5 = ftp5.expr
else:
assert is_concrete(ftp5)
ftp5 = instantiate(ftp5, {})
# We need an extra placeholder so that the inp.function_instance gets updated
phft2 = ctx.make_placeholder(inp.function_instance)
yield UnifyTypesConstraint(
ctx,
inp.sourceref,
ftp5,
phft2,
)
expr_type = ctx.build.type5_make_function(arg_typ_list + [phft])
yield UnifyTypesConstraint(ctx, sourceref, ftp5, expr_type, prefix=function_name)
yield UnifyTypesConstraint(ctx, inp.sourceref, phft2, expr_type)
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: TypeVariable) -> ConstraintGenerator:
assert inp.function.type5 is not None # Todo: Make not nullable
@ -162,7 +141,7 @@ def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference,
if isinstance(ftp5, ConstrainedExpr):
ftp5 = ftp5.expr
yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft, prefix=inp.function.name)
yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft)
def expression_tuple_instantiation(ctx: Context, inp: ourlang.TupleInstantiation, phft: TypeVariable) -> ConstraintGenerator:
arg_typ_list = []
@ -241,7 +220,7 @@ def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement
type5 = fun.type5.expr if isinstance(fun.type5, ConstrainedExpr) else fun.type5
yield from expression(ctx, inp.value, phft)
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft, prefix=f'{fun.name} returns')
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft)
def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf) -> ConstraintGenerator:
test_phft = ctx.make_placeholder(inp.test)
@ -288,3 +267,14 @@ def module(ctx: Context, inp: ourlang.Module[Any]) -> ConstraintGenerator:
yield from function(ctx, func)
# TODO: Generalize?
def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall:
"""
For typing purposes, a binary operator is just a function call.
It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)`
"""
assert inp.sourceref is not None # TODO: sourceref required
call = ourlang.FunctionCall(inp.operator, inp.sourceref)
call.arguments = [inp.left, inp.right]
return call

View File

@ -3,7 +3,8 @@ from typing import Any
from ..ourlang import Module
from .constraints import ConstraintBase, Context
from .fromast import phasm_type5_generate_constraints
from .typeexpr import TypeExpr, TypeVariable, is_concrete, replace_variable
from .typeexpr import TypeExpr, TypeVariable, replace_variable
from .unify import ReplaceVariable
MAX_RESTACK_COUNT = 100
@ -30,7 +31,8 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
print("Validating")
new_constraint_list: list[ConstraintBase] = []
for constraint in sorted(constraint_list, key=lambda x: x.complexity()):
while constraint_list:
constraint = constraint_list.pop(0)
result = constraint.check()
if verbose:
@ -42,15 +44,29 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
# Means it checks out and we don't need do anything
continue
if result.replace is not None:
action_var = result.replace.var
assert action_var not in placeholder_types # When does this happen?
while result.actions:
action = result.actions.pop(0)
action_typ: TypeExpr = result.replace.typ
assert not isinstance(action_typ, TypeVariable) or action_typ not in placeholder_types # When does this happen?
if isinstance(action, ReplaceVariable):
action_var: TypeExpr = action.var
while isinstance(action_var, TypeVariable) and action_var in placeholder_types:
# TODO: Does this still happen?
action_var = placeholder_types[action_var]
assert action_var != action_typ # When does this happen?
action_typ: TypeExpr = action.typ
while isinstance(action_typ, TypeVariable) and action_typ in placeholder_types:
# TODO: Does this still happen?
action_typ = placeholder_types[action_typ]
# print(inp.build.type5_name(action_var), ':=', inp.build.type5_name(action_typ))
if action_var == action_typ:
continue
if not isinstance(action_var, TypeVariable) and isinstance(action_typ, TypeVariable):
action_typ, action_var = action_var, action_typ
if isinstance(action_var, TypeVariable):
# Ensure all existing found types are updated
placeholder_types = {
k: replace_variable(v, action_var, action_typ)
@ -69,14 +85,20 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
if verbose and old_str != new_str:
print(f"{oth_const.sourceref!s} => - {old_str!s}")
print(f"{oth_const.sourceref!s} => + {new_str!s}")
continue
error_list.append((str(constraint.sourceref), str(constraint), "Not the same type", ))
if verbose:
print(f"{constraint.sourceref!s} => ERR: Conflict in applying {action.to_str(inp.build.type5_name)}")
continue
# Action of unsupported type
raise NotImplementedError(action)
for failure in result.failures:
error_list.append((str(constraint.sourceref), str(constraint), failure.msg, ))
new_constraint_list.extend(result.new_constraints)
if verbose:
for new_const in result.new_constraints:
print(f"{oth_const.sourceref!s} => + {new_const!s}")
if result.done:
continue
@ -102,11 +124,8 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
if expression is None:
continue
resolved_type5 = placeholder_types[placeholder]
assert is_concrete(resolved_type5) # When does this happen?
expression.type5 = resolved_type5
new_type5 = placeholder_types[placeholder]
while isinstance(new_type5, TypeVariable):
new_type5 = placeholder_types[new_type5]
for placeholder, (ptst_map, orig_tvar) in ctx.ptst_update.items():
resolved_type5 = placeholder_types[placeholder]
assert is_concrete(resolved_type5) # When does this happen?
ptst_map[orig_tvar] = resolved_type5
expression.type5 = new_type5

View File

@ -4,7 +4,6 @@ from .typeexpr import (
TypeApplication,
TypeConstructor,
TypeExpr,
TypeLevelNat,
TypeVariable,
)
@ -22,9 +21,6 @@ class TypeRouter[T]:
def when_record(self, typ: Record) -> T:
raise NotImplementedError(typ)
def when_type_level_nat(self, typ: TypeLevelNat) -> T:
raise NotImplementedError(typ)
def when_variable(self, typ: TypeVariable) -> T:
raise NotImplementedError(typ)
@ -41,9 +37,6 @@ class TypeRouter[T]:
if isinstance(typ, TypeConstructor):
return self.when_constructor(typ)
if isinstance(typ, TypeLevelNat):
return self.when_type_level_nat(typ)
if isinstance(typ, TypeVariable):
return self.when_variable(typ)

128
phasm/type5/unify.py Normal file
View File

@ -0,0 +1,128 @@
from dataclasses import dataclass
from typing import Callable
from .typeexpr import (
AtomicType,
TypeApplication,
TypeConstructor,
TypeExpr,
TypeVariable,
is_concrete,
occurs,
)
@dataclass
class Failure:
"""
Both types are already different - cannot be unified.
"""
msg: str
@dataclass
class Action:
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
raise NotImplementedError
class ActionList(list[Action]):
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return '{' + ', '.join((x.to_str(type_namer) for x in self)) + '}'
UnifyResult = Failure | ActionList
@dataclass
class ReplaceVariable(Action):
var: TypeVariable
typ: TypeExpr
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return f'{self.var.name} := {type_namer(self.typ)}'
def unify(lft: TypeExpr, rgt: TypeExpr) -> UnifyResult:
"""
Be warned: This only matches type variables with other variables or types
- it does not apply substituions nor does it validate if the matching
pairs are correct.
TODO: Remove this. It should be part of UnifyTypesConstraint
and should just generate new constraints for applications.
"""
if lft == rgt:
return ActionList()
if lft.kind != rgt.kind:
return Failure("Kind mismatch")
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
return Failure("Not the same type")
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
return ActionList([ReplaceVariable(rgt, lft)])
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
raise NotImplementedError # Should have been caught by kind check above
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
if is_concrete(rgt):
return Failure("Not the same type")
return Failure("Type shape mismatch")
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
return unify(rgt, lft)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
return ActionList([ReplaceVariable(lft, rgt)])
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
return ActionList([ReplaceVariable(lft, rgt)])
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
if occurs(lft, rgt):
return Failure("One type occurs in the other")
return ActionList([ReplaceVariable(lft, rgt)])
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
return unify(rgt, lft)
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
return unify(rgt, lft)
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
return Failure("Not the same type constructor")
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
return Failure("Not the same type constructor")
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
return unify(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
return unify(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
return unify(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
con_res = unify(lft.constructor, rgt.constructor)
if isinstance(con_res, Failure):
return con_res
arg_res = unify(lft.argument, rgt.argument)
if isinstance(arg_res, Failure):
return arg_res
return ActionList(con_res + arg_res)
return Failure('Not implemented')

View File

@ -305,5 +305,8 @@ def testEntry() -> i32:
```
```py
expect_type_error('Not the same type')
if TYPE_NAME.startswith('tuple_') or TYPE_NAME.startswith('static_array_') or TYPE_NAME.startswith('dynamic_array_'):
expect_type_error('Not the same type constructor')
else:
expect_type_error('Not the same type')
```

View File

@ -1,7 +1,5 @@
import pytest
from phasm.type5.solver import Type5SolverException
from ..helpers import Suite
@ -40,22 +38,24 @@ def test_call_post_defined():
code_py = """
@exported
def testEntry() -> i32:
return helper(13)
return helper(10, 3)
def helper(left: i32) -> i32:
return left
def helper(left: i32, right: i32) -> i32:
return left - right
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
assert 7 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('FIXME: Type checking')
def test_call_invalid_type():
code_py = """
def helper(left: i32) -> i32:
return left()
"""
with pytest.raises(Type5SolverException, match=r'i32 ~ Callable\[i32\]'):
Suite(code_py).run_code()
result = Suite(code_py).run_code()
assert 7 == result.returned_value

View File

@ -91,7 +91,8 @@ def testEntry() -> i32:
return action(double, 13.0)
"""
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
match = r'Callable\[i32, i32\] ~ Callable\[f32, [^]]+\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()
@pytest.mark.integration_test
@ -108,7 +109,8 @@ def testEntry() -> i32:
return action(double, 13)
"""
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[f32, i32\], p_[0-9]+, [^]]+\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()
@pytest.mark.integration_test
@ -125,14 +127,14 @@ def testEntry() -> f32:
return action(double, 13)
"""
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
with pytest.raises(Type5SolverException, match='f32 ~ i32'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_function_with_wrong_return_type_pass():
code_py = """
def double(left: i32) -> f32:
return convert(left) * 2.0
return convert(left) * 2
def action(applicable: Callable[i32, i32], left: i32) -> i32:
return applicable(left)
@ -142,7 +144,8 @@ def testEntry() -> i32:
return action(double, 13)
"""
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[i32, f32\], p_[0-9]+, [^]]+\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()
@pytest.mark.integration_test
@ -176,12 +179,12 @@ def testEntry() -> i32:
return action(double, 13, 14)
"""
match = r'Callable\[i32, i32\] ~ i32'
match = r'Callable\[Callable\[i32, i32, i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_too_many_args_use_0():
def test_sof_too_many_args_use():
code_py = """
def thirteen() -> i32:
return 13
@ -194,30 +197,12 @@ def testEntry() -> i32:
return action(thirteen, 13)
"""
match = r'\(\) ~ i32'
match = r'Callable\[i32\] ~ Callable\[i32, p_[0-9]+\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code(verbose=True)
@pytest.mark.integration_test
def test_sof_too_many_args_use_1():
code_py = """
def thirteen(x: i32) -> i32:
return x
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
return applicable(left, right)
@exported
def testEntry() -> i32:
return action(thirteen, 13, 26)
"""
match = r'i32 ~ Callable\[i32, i32\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code(verbose=True)
@pytest.mark.integration_test
def test_sof_too_many_args_pass_0():
def test_sof_too_many_args_pass():
code_py = """
def double(left: i32) -> i32:
return left * 2
@ -230,24 +215,6 @@ def testEntry() -> i32:
return action(double, 13, 14)
"""
match = r'\(\) ~ i32'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_too_many_args_pass_1():
code_py = """
def double(left: i32, right: i32) -> i32:
return left * right
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13, 14)
"""
match = r'i32 ~ Callable\[i32, i32\]'
match = r'Callable\[Callable\[i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()

View File

@ -168,9 +168,12 @@ def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32:
return foldl(x, y, z)
"""
match = 'Type shape mismatch'
match = {
'i8': 'Type shape mismatch',
'i8[3]': 'Kind mismatch',
}
with pytest.raises(Type5SolverException, match=match):
with pytest.raises(Type5SolverException, match=match[in_typ]):
Suite(code_py).run_code()
@pytest.mark.integration_test