Compare commits

...

3 Commits

Author SHA1 Message Date
Johan B.W. de Vries
3d6d279408 Implements the IO type constructor and Monad type class 2025-08-24 16:09:23 +02:00
71691d68e9 Merge pull request 'Removes the weird second step unify' (#9) from rework-unify-to-be-a-normal-constraint into master
Reviewed-on: #9
2025-08-24 14:07:36 +00:00
Johan B.W. de Vries
7df9d5af12 Removes the weird second step unify
It is now part of the normal constraints. Added a special
workaround for functions, since otherwise the output is a
bit redundant and quite confusing.

Also, constraints are now processed in order of complexity.
This does not affect type safety. It uses a bit more CPU.
But it makes the output that much easier to read.

Also, removes the weird FunctionInstance hack. Instead,
the more industry standard way of annotation the types
on the function call is used. As always, this requires some
hackyness for Subscriptable.

Also, adds a few comments to the type unification to help
with debugging.

Also, prints out the new constraints that are received.
2025-08-24 16:06:42 +02:00
21 changed files with 641 additions and 354 deletions

View File

@ -45,6 +45,7 @@ class BuildBase[G]:
__slots__ = (
'dynamic_array_type5_constructor',
'function_type5_constructor',
'io_type5_constructor',
'static_array_type5_constructor',
'tuple_type5_constructor_map',
@ -83,6 +84,18 @@ class BuildBase[G]:
See type5_make_function and type5_is_function.
"""
io_type5_constructor: type5typeexpr.TypeConstructor
"""
Constructor for IO.
An IO function is a function that can have side effects.
It can do input or output. Other functions cannot have
side effects, and can only return a value based on the
input given.
See type5_make_io and type5_is_io.
"""
static_array_type5_constructor: type5typeexpr.TypeConstructor
"""
Constructor for arrays of compiled time determined length.
@ -207,6 +220,7 @@ class BuildBase[G]:
self.dynamic_array_type5_constructor = type5typeexpr.TypeConstructor(kind=S >> S, name="dynamic_array")
self.function_type5_constructor = type5typeexpr.TypeConstructor(kind=S >> (S >> S), name="function")
self.io_type5_constructor = type5typeexpr.TypeConstructor(kind=S >> S, name="IO")
self.static_array_type5_constructor = type5typeexpr.TypeConstructor(kind=N >> (S >> S), name='static_array')
self.tuple_type5_constructor_map = {}
@ -344,6 +358,19 @@ class BuildBase[G]:
return my_args + more_args
def type5_make_io(self, arg: type5typeexpr.TypeExpr) -> type5typeexpr.TypeApplication:
return type5typeexpr.TypeApplication(
constructor=self.io_type5_constructor,
argument=arg
)
def type5_is_io(self, typeexpr: type5typeexpr.TypeExpr | type5constrainedexpr.ConstrainedExpr) -> type5typeexpr.TypeExpr | None:
if not isinstance(typeexpr, type5typeexpr.TypeApplication):
return None
if typeexpr.constructor != self.io_type5_constructor:
return None
return typeexpr.argument
def type5_make_tuple(self, args: Sequence[type5typeexpr.TypeExpr]) -> type5typeexpr.TypeApplication:
if not args:
raise TypeError("Tuples must at least one field")

View File

@ -22,6 +22,7 @@ from .typeclasses import (
fractional,
integral,
intnum,
monad,
natnum,
ord,
promotable,
@ -68,6 +69,7 @@ class BuildDefault(BuildBase[Generator]):
integral,
foldable, subscriptable,
sized,
monad,
]
for tc in tc_list:

View File

@ -0,0 +1,26 @@
"""
The Monad type class is defined for type constructors that cause one thing to happen /after/ another.
"""
from __future__ import annotations
from typing import Any
from ...type5.constrainedexpr import ConstrainedExpr
from ...type5.kindexpr import Star
from ...type5.typeexpr import TypeVariable
from ...typeclass import TypeClass, TypeClassConstraint
from ...wasmgenerator import Generator as WasmGenerator
from ..base import BuildBase
def load(build: BuildBase[Any]) -> None:
a = TypeVariable(kind=Star(), name='a')
Monad = TypeClass('Monad', (a, ), methods={}, operators={})
build.register_type_class(Monad)
def wasm(build: BuildBase[WasmGenerator]) -> None:
Monad = build.type_classes['Monad']
build.instance_type_class(Monad, build.io_type5_constructor)

View File

@ -38,7 +38,6 @@ def load(build: BuildBase[Any]) -> None:
build.register_type_class(Sized)
def wasm_dynamic_array_len(g: WasmGenerator, tv_map: dict[str, TypeExpr]) -> None:
print('tv_map', tv_map)
del tv_map
# The length is stored in the first 4 bytes
g.i32.load()

View File

@ -8,6 +8,7 @@ from ..type5.typeexpr import (
TypeApplication,
TypeConstructor,
TypeExpr,
TypeLevelNat,
TypeVariable,
)
from ..type5.typerouter import TypeRouter
@ -35,6 +36,10 @@ class BuildTypeRouter[T](TypeRouter[T]):
if fn_args is not None:
return self.when_function(fn_args)
io_args = self.build.type5_is_io(typ)
if io_args is not None:
return self.when_io(io_args)
sa_args = self.build.type5_is_static_array(typ)
if sa_args is not None:
sa_len, sa_typ = sa_args
@ -58,6 +63,9 @@ class BuildTypeRouter[T](TypeRouter[T]):
def when_function(self, fn_args: list[TypeExpr]) -> T:
raise NotImplementedError
def when_io(self, io_arg: TypeExpr) -> T:
raise NotImplementedError
def when_struct(self, typ: Record) -> T:
raise NotImplementedError
@ -93,6 +101,9 @@ class TypeName(BuildTypeRouter[str]):
def when_function(self, fn_args: list[TypeExpr]) -> str:
return 'Callable[' + ', '.join(map(self, fn_args)) + ']'
def when_io(self, io_arg: TypeExpr) -> str:
return 'IO[' + self(io_arg) + ']'
def when_static_array(self, sa_len: int, sa_typ: TypeExpr) -> str:
return f'{self(sa_typ)}[{sa_len}]'
@ -102,6 +113,9 @@ class TypeName(BuildTypeRouter[str]):
def when_tuple(self, tp_args: list[TypeExpr]) -> str:
return '(' + ', '.join(map(self, tp_args)) + ', )'
def when_type_level_nat(self, typ: TypeLevelNat) -> str:
return str(typ.value)
def when_variable(self, typ: TypeVariable) -> str:
return typ.name

View File

@ -67,7 +67,7 @@ def expression(inp: ourlang.Expression) -> str:
return str(inp.variable.name)
if isinstance(inp, ourlang.BinaryOp):
return f'{expression(inp.left)} {inp.operator.function.name} {expression(inp.right)}'
return f'{expression(inp.left)} {inp.operator.name} {expression(inp.right)}'
if isinstance(inp, ourlang.FunctionCall):
args = ', '.join(
@ -75,10 +75,10 @@ def expression(inp: ourlang.Expression) -> str:
for arg in inp.arguments
)
if isinstance(inp.function_instance.function, ourlang.StructConstructor):
return f'{inp.function_instance.function.struct_type5.name}({args})'
if isinstance(inp.function, ourlang.StructConstructor):
return f'{inp.function.struct_type5.name}({args})'
return f'{inp.function_instance.function.name}({args})'
return f'{inp.function.name}({args})'
if isinstance(inp, ourlang.FunctionReference):
return str(inp.function.name)
@ -124,6 +124,10 @@ def statement(inp: ourlang.Statement) -> Statements:
yield ''
return
if isinstance(inp, ourlang.StatementCall):
yield expression(inp.call)
return
raise NotImplementedError(statement, inp)
def function(mod: ourlang.Module[Any], inp: ourlang.Function) -> str:

View File

@ -11,7 +11,14 @@ from .build.typerouter import BuildTypeRouter
from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types
from .type5.constrainedexpr import ConstrainedExpr
from .type5.typeexpr import AtomicType, TypeApplication, TypeExpr, is_concrete
from .type5.typeexpr import (
AtomicType,
TypeApplication,
TypeExpr,
TypeVariable,
is_concrete,
replace_variable,
)
from .wasm import (
WasmTypeFloat32,
WasmTypeFloat64,
@ -36,6 +43,11 @@ def type5(mod: ourlang.Module[WasmGenerator], inp: TypeExpr) -> wasm.WasmType:
Types are used for example in WebAssembly function parameters
and return types.
"""
io_arg = mod.build.type5_is_io(inp)
if io_arg is not None:
# IO is type constructor that only exists on the typing layer
inp = io_arg
typ_info = mod.build.type_info_map.get(inp.name)
if typ_info is None:
typ_info = mod.build.type_info_constructed
@ -148,32 +160,94 @@ def expression_subscript_tuple(wgn: WasmGenerator, mod: ourlang.Module[WasmGener
expression(wgn, mod, inp.varref)
wgn.add_statement(el_type_info.wasm_load_func, f'offset={offset}')
def expression_subscript_operator(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Subscript) -> None:
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
ftp5 = mod.build.type_classes['Subscriptable'].operators['[]']
fn_args = mod.build.type5_is_function(ftp5)
assert fn_args is not None
t_a = fn_args[0]
assert isinstance(t_a, TypeApplication)
t = t_a.constructor
a = t_a.argument
assert isinstance(t, TypeVariable)
assert isinstance(a, TypeVariable)
assert isinstance(inp.varref.type5, TypeApplication)
t_expr = inp.varref.type5.constructor
a_expr = inp.varref.type5.argument
_expression_binary_operator_or_function_call(
wgn,
mod,
ourlang.BuiltinFunction('[]', ftp5),
{
t: t_expr,
a: a_expr,
},
[inp.varref, inp.index],
inp.type5,
)
def expression_binary_op(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.BinaryOp) -> None:
expression_function_call(wgn, mod, _binary_op_to_function(inp))
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
_expression_binary_operator_or_function_call(
wgn,
mod,
inp.operator,
inp.polytype_substitutions,
[inp.left, inp.right],
inp.type5,
)
def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.FunctionCall) -> None:
for arg in inp.arguments:
assert _is_concrete(inp.type5), TYPE5_ASSERTION_ERROR
_expression_binary_operator_or_function_call(
wgn,
mod,
inp.function,
inp.polytype_substitutions,
inp.arguments,
inp.type5,
)
def _expression_binary_operator_or_function_call(
wgn: WasmGenerator,
mod: ourlang.Module[WasmGenerator],
function: ourlang.Function | ourlang.FunctionParam,
polytype_substitutions: dict[TypeVariable, TypeExpr],
arguments: list[ourlang.Expression],
ret_type5: TypeExpr,
) -> None:
for arg in arguments:
expression(wgn, mod, arg)
if isinstance(inp.function_instance.function, ourlang.BuiltinFunction):
assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR
if isinstance(function, ourlang.BuiltinFunction):
ftp5 = function.type5
if isinstance(ftp5, ConstrainedExpr):
cexpr = ftp5
ftp5 = ftp5.expr
for tvar in cexpr.variables:
ftp5 = replace_variable(ftp5, tvar, polytype_substitutions[tvar])
assert _is_concrete(ftp5), TYPE5_ASSERTION_ERROR
try:
method_type, method_router = mod.build.methods[inp.function_instance.function.name]
method_type, method_router = mod.build.methods[function.name]
except KeyError:
method_type, method_router = mod.build.operators[inp.function_instance.function.name]
method_type, method_router = mod.build.operators[function.name]
impl_lookup = method_router.get((inp.function_instance.type5, ))
assert impl_lookup is not None, (inp.function_instance.function.name, inp.function_instance.type5, )
impl_lookup = method_router.get((ftp5, ))
assert impl_lookup is not None, (function.name, ftp5, )
kwargs, impl = impl_lookup
impl(wgn, kwargs)
return
if isinstance(inp.function_instance.function, ourlang.FunctionParam):
assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR
fn_args = mod.build.type5_is_function(inp.function_instance.type5)
assert fn_args is not None
if isinstance(function, ourlang.FunctionParam):
fn_args = mod.build.type5_is_function(function.type5)
assert fn_args is not None, function.type5
params = [
type5(mod, x)
@ -182,11 +256,15 @@ def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerat
result = params.pop()
wgn.add_statement('local.get', '${}'.format(inp.function_instance.function.name))
wgn.add_statement('local.get', '${}'.format(function.name))
wgn.call_indirect(params=params, result=result)
return
wgn.call(inp.function_instance.function.name)
# TODO: Do similar subsitutions like we do for BuiltinFunction
# when we get user space polymorphic functions
# And then do similar lookup, and ensure we generate code for that variant
wgn.call(function.name)
def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Expression) -> None:
"""
@ -278,22 +356,7 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourl
expression_subscript_tuple(wgn, mod, inp)
return
inp_as_fc = ourlang.FunctionCall(
ourlang.FunctionInstance(
ourlang.BuiltinFunction('[]', mod.build.type_classes['Subscriptable'].operators['[]']),
inp.sourceref,
),
inp.sourceref,
)
inp_as_fc.arguments = [inp.varref, inp.index]
inp_as_fc.function_instance.type5 = mod.build.type5_make_function([
inp.varref.type5,
inp.index.type5,
inp.type5,
])
inp_as_fc.type5 = inp.type5
expression_function_call(wgn, mod, inp_as_fc)
expression_subscript_operator(wgn, mod, inp)
return
if isinstance(inp, ourlang.AccessStructMember):
@ -321,11 +384,11 @@ def statement_return(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun
# Support tail calls
# https://github.com/WebAssembly/tail-call
# These help a lot with some functional programming techniques
if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function_instance.function is fun:
if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function is fun:
for arg in inp.value.arguments:
expression(wgn, mod, arg)
wgn.add_statement('return_call', '${}'.format(inp.value.function_instance.function.name))
wgn.add_statement('return_call', '${}'.format(inp.value.function.name))
return
expression(wgn, mod, inp.value)
@ -346,6 +409,9 @@ def statement_if(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ou
# for stat in inp.else_statements:
# statement(wgn, stat)
def statement_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ourlang.Function, inp: ourlang.StatementCall) -> None:
expression(wgn, mod, inp.call)
def statement(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ourlang.Function, inp: ourlang.Statement) -> None:
"""
Compile: any statement
@ -358,6 +424,10 @@ def statement(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun: ourla
statement_if(wgn, mod, fun, inp)
return
if isinstance(inp, ourlang.StatementCall):
statement_call(wgn, mod, fun, inp)
return
if isinstance(inp, ourlang.StatementPass):
return
@ -595,14 +665,3 @@ def _type5_struct_offset(
result += build.type5_alloc_size_member(memtyp)
raise RuntimeError('Member not found')
def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall:
"""
For compilation purposes, a binary operator is just a function call.
It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)`
"""
assert inp.sourceref is not None # TODO: sourceref required
call = ourlang.FunctionCall(inp.operator, inp.sourceref)
call.arguments = [inp.left, inp.right]
return call

View File

@ -157,52 +157,39 @@ class BinaryOp(Expression):
"""
A binary operator expression within a statement
"""
__slots__ = ('operator', 'left', 'right', )
__slots__ = ('operator', 'polytype_substitutions', 'left', 'right', )
operator: FunctionInstance
operator: Function | FunctionParam
polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr]
left: Expression
right: Expression
def __init__(self, operator: FunctionInstance, left: Expression, right: Expression, sourceref: SourceRef) -> None:
def __init__(self, operator: Function | FunctionParam, left: Expression, right: Expression, sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.operator = operator
self.polytype_substitutions = {}
self.left = left
self.right = right
def __repr__(self) -> str:
return f'BinaryOp({repr(self.operator)}, {repr(self.left)}, {repr(self.right)})'
class FunctionInstance(Expression):
"""
When calling a polymorphic function with concrete arguments, we can generate
code for that specific instance of the function.
"""
__slots__ = ('function', )
function: Union['Function', 'FunctionParam']
def __init__(self, function: Union['Function', 'FunctionParam'], sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.function = function
class FunctionCall(Expression):
"""
A function call expression within a statement
"""
__slots__ = ('function_instance', 'arguments', )
__slots__ = ('function', 'polytype_substitutions', 'arguments', )
function_instance: FunctionInstance
# TODO: FunctionInstance is wrong - we should have
# substitutions: dict[TypeVariable, TypeExpr]
# And it should have the same variables as the polytype (ConstrainedExpr) for function
function: Function | FunctionParam
polytype_substitutions: dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr]
arguments: List[Expression]
def __init__(self, function_instance: FunctionInstance, sourceref: SourceRef) -> None:
def __init__(self, function: Function | FunctionParam, sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.function_instance = function_instance
self.function = function
self.polytype_substitutions = {}
self.arguments = []
class FunctionReference(Expression):
@ -295,6 +282,24 @@ class StatementReturn(Statement):
def __repr__(self) -> str:
return f'StatementReturn({repr(self.value)})'
class StatementCall(Statement):
"""
A function call within a function.
Executing is deferred to the given function until it completes.
"""
__slots__ = ('call')
call: FunctionCall
def __init__(self, call: FunctionCall, sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.call = call
def __repr__(self) -> str:
return f'StatementCall({repr(self.call)})'
class StatementIf(Statement):
"""
An if statement within a function

View File

@ -18,7 +18,6 @@ from .ourlang import (
Expression,
Function,
FunctionCall,
FunctionInstance,
FunctionParam,
FunctionReference,
Module,
@ -26,6 +25,7 @@ from .ourlang import (
ModuleDataBlock,
SourceRef,
Statement,
StatementCall,
StatementIf,
StatementPass,
StatementReturn,
@ -362,6 +362,9 @@ class OurVisitor[G]:
return result
if isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
return StatementCall(self.visit_Module_FunctionDef_Call(module, function, our_locals, node.value), srf(module, node))
if isinstance(node, ast.Pass):
return StatementPass(srf(module, node))
@ -400,7 +403,7 @@ class OurVisitor[G]:
raise NotImplementedError(f'Operator {operator}')
return BinaryOp(
FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)),
BuiltinFunction(operator, module.operators[operator]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right),
srf(module, node),
@ -429,7 +432,7 @@ class OurVisitor[G]:
raise NotImplementedError(f'Operator {operator}')
return BinaryOp(
FunctionInstance(BuiltinFunction(operator, module.operators[operator]), srf(module, node)),
BuiltinFunction(operator, module.operators[operator]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]),
srf(module, node),
@ -485,7 +488,7 @@ class OurVisitor[G]:
raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
def visit_Module_FunctionDef_Call(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.Call) -> FunctionCall:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -506,7 +509,7 @@ class OurVisitor[G]:
func = module.functions[node.func.id]
result = FunctionCall(FunctionInstance(func, srf(module, node)), sourceref=srf(module, node))
result = FunctionCall(func, sourceref=srf(module, node))
result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
for arg_expr in node.args
@ -648,6 +651,15 @@ class OurVisitor[G]:
for e in func_arg_types
])
if isinstance(node.value, ast.Name) and node.value.id == 'IO':
assert isinstance(node.slice, ast.Name) or (isinstance(node.slice, ast.Tuple) and len(node.slice.elts) == 0)
return module.build.type5_make_io(
self.visit_type5(module, node.slice)
)
# TODO: This u32[...] business is messing up the other type constructors
if isinstance(node.slice, ast.Slice):
_raise_static_error(node, 'Must subscript using an index')
@ -673,6 +685,9 @@ class OurVisitor[G]:
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not node.elts:
return module.build.unit_type5
return module.build.type5_make_tuple(
[self.visit_type5(module, elt) for elt in node.elts],
)

View File

@ -42,9 +42,8 @@ class ConstrainedExpr:
def instantiate_constrained(
constrainedexpr: ConstrainedExpr,
known_map: dict[TypeVariable, TypeVariable],
make_variable: Callable[[KindExpr, str], TypeVariable],
) -> ConstrainedExpr:
) -> tuple[ConstrainedExpr, dict[TypeVariable, TypeVariable]]:
"""
Instantiates a type expression and its constraints
"""
@ -61,4 +60,4 @@ def instantiate_constrained(
x.instantiate(known_map)
for x in constrainedexpr.constraints
)
return ConstrainedExpr(constrainedexpr.variables, expr, constraints)
return ConstrainedExpr(constrainedexpr.variables, expr, constraints), known_map

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import dataclasses
from typing import Any, Callable, Iterable, Protocol, Sequence
from typing import Any, Callable, Iterable, Protocol, Sequence, TypeAlias
from ..build.base import BuildBase
from ..ourlang import SourceRef
@ -9,13 +9,16 @@ from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt6
from .kindexpr import KindExpr, Star
from .record import Record
from .typeexpr import (
AtomicType,
TypeApplication,
TypeConstructor,
TypeExpr,
TypeLevelNat,
TypeVariable,
is_concrete,
occurs,
replace_variable,
)
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
class ExpressionProtocol(Protocol):
@ -28,50 +31,95 @@ class ExpressionProtocol(Protocol):
The type to update
"""
PolytypeSubsituteMap: TypeAlias = dict[TypeVariable, TypeExpr]
class Context:
__slots__ = ("build", "placeholder_update", )
__slots__ = ("build", "placeholder_update", "ptst_update", )
build: BuildBase[Any]
placeholder_update: dict[TypeVariable, ExpressionProtocol | None]
ptst_update: dict[TypeVariable, tuple[PolytypeSubsituteMap, TypeVariable]]
def __init__(self, build: BuildBase[Any]) -> None:
self.build = build
self.placeholder_update = {}
self.ptst_update = {}
def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star(), prefix: str = 'p') -> TypeVariable:
res = TypeVariable(kind, f"{prefix}_{len(self.placeholder_update)}")
self.placeholder_update[res] = arg
return res
def register_polytype_subsitutes(self, tvar: TypeVariable, arg: PolytypeSubsituteMap, orig_var: TypeVariable) -> None:
"""
When `tvar` gets subsituted, also set the result in arg with orig_var as key
e.g.
(-) :: Callable[a, a, a]
def foo() -> u32:
return 2 - 1
During typing, we instantiate a into a_3, and get the following constraints:
- u8 ~ p_1
- u8 ~ p_2
- Exists NatNum a_3
- Callable[a_3, a_3, a_3] ~ Callable[p_1, p_2, p_0]
- u8 ~ p_0
When we resolve a_3, then on the call to `-`, we should note that a_3 got resolved
to u32. But we need to use `a` as key, since that's what's used on the definition
"""
assert tvar in self.placeholder_update
assert tvar not in self.ptst_update
self.ptst_update[tvar] = (arg, orig_var)
@dataclasses.dataclass
class Failure:
"""
Both types are already different - cannot be unified.
"""
msg: str
@dataclasses.dataclass
class ReplaceVariable:
var: TypeVariable
typ: TypeExpr
@dataclasses.dataclass
class CheckResult:
# TODO: Refactor this, don't think we use most of the variants
_: dataclasses.KW_ONLY
done: bool = True
actions: ActionList = dataclasses.field(default_factory=ActionList)
replace: ReplaceVariable | None = None
new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
failures: list[Failure] = dataclasses.field(default_factory=list)
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
if not self.done and not self.actions and not self.new_constraints and not self.failures:
if not self.done and not self.replace and not self.new_constraints and not self.failures:
return '(skip for now)'
if self.done and not self.actions and not self.new_constraints and not self.failures:
if self.done and not self.replace and not self.new_constraints and not self.failures:
return '(ok)'
if self.done and self.actions and not self.new_constraints and not self.failures:
return self.actions.to_str(type_namer)
if self.done and self.replace and not self.new_constraints and not self.failures:
return f'{{{self.replace.var.name} := {type_namer(self.replace.typ)}}}'
if self.done and not self.actions and self.new_constraints and not self.failures:
if self.done and not self.replace and self.new_constraints and not self.failures:
return f'(got {len(self.new_constraints)} new constraints)'
if self.done and not self.actions and not self.new_constraints and self.failures:
if self.done and not self.replace and not self.new_constraints and self.failures:
return 'ERR: ' + '; '.join(x.msg for x in self.failures)
return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}'
return f'{self.done} {self.replace} {self.new_constraints} {self.failures}'
def skip_for_now() -> CheckResult:
return CheckResult(done=False)
def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult:
return CheckResult(replace=ReplaceVariable(var, typ))
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
return CheckResult(new_constraints=list(lst))
@ -94,12 +142,8 @@ class ConstraintBase:
def check(self) -> CheckResult:
raise NotImplementedError(self)
def apply(self, action: Action) -> None:
if isinstance(action, ReplaceVariable):
self.replace_variable(action.var, action.typ)
return
raise NotImplementedError(action)
def complexity(self) -> int:
raise NotImplementedError
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
pass
@ -142,6 +186,9 @@ class FromLiteralInteger(ConstraintBase):
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
self.type5 = replace_variable(self.type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.type5)
def __str__(self) -> str:
return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
@ -175,8 +222,11 @@ class FromLiteralFloat(ConstraintBase):
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
self.type5 = replace_variable(self.type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.type5)
def __str__(self) -> str:
return f'FromLiteralInteger {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
return f'FromLiteralFloat {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
class FromLiteralBytes(ConstraintBase):
__slots__ = ('type5', 'literal', )
@ -203,32 +253,125 @@ class FromLiteralBytes(ConstraintBase):
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
self.type5 = replace_variable(self.type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.type5)
def __str__(self) -> str:
return f'FromLiteralBytes {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
class UnifyTypesConstraint(ConstraintBase):
__slots__ = ("lft", "rgt",)
__slots__ = ("lft", "rgt", "prefix", )
def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr) -> None:
def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr, prefix: str | None = None) -> None:
super().__init__(ctx, sourceref)
self.lft = lft
self.rgt = rgt
self.prefix = prefix
def check(self) -> CheckResult:
result = unify(self.lft, self.rgt)
lft = self.lft
rgt = self.rgt
if isinstance(result, Failure):
return CheckResult(failures=[result])
if lft == self.rgt:
return ok()
return CheckResult(actions=result)
if lft.kind != rgt.kind:
return fail("Kind mismatch")
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
return fail("Not the same type")
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
return replace(rgt, lft)
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
raise NotImplementedError # Should have been caught by kind check above
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
return fail("Not the same type" if is_concrete(rgt) else "Type shape mismatch")
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
return replace(lft, rgt)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
return replace(lft, rgt)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
return replace(lft, rgt)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
if occurs(lft, rgt):
return fail("One type occurs in the other")
return replace(lft, rgt)
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
raise NotImplementedError # Should have been caught by kind check above
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
return replace(rgt, lft)
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
return fail("Not the same type constructor")
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
return fail("Not the same type constructor")
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
return fail("Not the same type" if is_concrete(lft) else "Type shape mismatch")
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
if occurs(rgt, lft):
return fail("One type occurs in the other")
return replace(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
return fail("Not the same type constructor")
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
## USABILITY HACK
## Often, we have two type applications in the same go
## If so, resolve it in a single step
## (Helps with debugging function unification)
## This *should* not affect the actual type unification
## It's just one less call to UnifyTypesConstraint.check
if isinstance(lft.constructor, TypeApplication) and isinstance(rgt.constructor, TypeApplication):
return new_constraints([
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.constructor, rgt.constructor.constructor),
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor.argument, rgt.constructor.argument),
UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument),
])
return new_constraints([
UnifyTypesConstraint(self.ctx, self.sourceref, lft.constructor, rgt.constructor),
UnifyTypesConstraint(self.ctx, self.sourceref, lft.argument, rgt.argument),
])
raise NotImplementedError(lft, rgt)
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
self.lft = replace_variable(self.lft, var, typ)
self.rgt = replace_variable(self.rgt, var, typ)
def complexity(self) -> int:
return complexity(self.lft) + complexity(self.rgt)
def __str__(self) -> str:
return f"{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}"
prefix = f'{self.prefix} :: ' if self.prefix else ''
return f"{prefix}{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}"
class CanBeSubscriptedConstraint(ConstraintBase):
__slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )
@ -290,6 +433,9 @@ class CanBeSubscriptedConstraint(ConstraintBase):
self.container_type5 = replace_variable(self.container_type5, var, typ)
self.index_type5 = replace_variable(self.index_type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.ret_type5) + complexity(self.container_type5) + complexity(self.index_type5)
def __str__(self) -> str:
return f"[] :: t a -> b -> a ~ {self.ctx.build.type5_name(self.container_type5)} -> {self.ctx.build.type5_name(self.index_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"
@ -333,6 +479,9 @@ class CanAccessStructMemberConstraint(ConstraintBase):
self.ret_type5 = replace_variable(self.ret_type5, var, typ)
self.struct_type5 = replace_variable(self.struct_type5, var, typ)
def complexity(self) -> int:
return 100 + complexity(self.ret_type5) + complexity(self.struct_type5)
def __str__(self) -> str:
st_args = self.ctx.build.type5_is_struct(self.struct_type5)
member_dict = dict(st_args or [])
@ -404,6 +553,9 @@ class FromTupleConstraint(ConstraintBase):
for x in self.member_type5_list
]
def complexity(self) -> int:
return 100 + complexity(self.ret_type5) + sum(complexity(x) for x in self.member_type5_list)
def __str__(self) -> str:
args = ', '.join(self.ctx.build.type5_name(x) for x in self.member_type5_list)
return f'FromTuple {self.ctx.build.type5_name(self.ret_type5)} ~ ({args}, )'
@ -450,6 +602,24 @@ class TypeClassInstanceExistsConstraint(ConstraintBase):
for x in self.arg_list
]
def complexity(self) -> int:
return 100 + sum(complexity(x) for x in self.arg_list)
def __str__(self) -> str:
args = ' '.join(self.ctx.build.type5_name(x) for x in self.arg_list)
return f'Exists {self.typeclass} {args}'
def complexity(expr: TypeExpr) -> int:
if isinstance(expr, AtomicType | TypeLevelNat):
return 1
if isinstance(expr, TypeConstructor):
return 2
if isinstance(expr, TypeVariable):
return 5
if isinstance(expr, TypeApplication):
return complexity(expr.constructor) + complexity(expr.argument)
raise NotImplementedError(expr)

View File

@ -15,8 +15,8 @@ from .constraints import (
TypeClassInstanceExistsConstraint,
UnifyTypesConstraint,
)
from .kindexpr import KindExpr
from .typeexpr import TypeApplication, TypeVariable, instantiate
from .kindexpr import KindExpr, Star
from .typeexpr import TypeApplication, TypeExpr, TypeVariable, is_concrete
ConstraintGenerator = Generator[ConstraintBase, None, None]
@ -90,14 +90,41 @@ def expression_constant(ctx: Context, inp: ourlang.Constant, phft: TypeVariable)
raise NotImplementedError(inp)
def expression_variable_reference(ctx: Context, inp: ourlang.VariableReference, phft: TypeVariable) -> ConstraintGenerator:
yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft)
yield UnifyTypesConstraint(ctx, inp.sourceref, inp.variable.type5, phft, prefix=inp.variable.name)
def expression_binary_operator(ctx: Context, inp: ourlang.BinaryOp, phft: TypeVariable) -> ConstraintGenerator:
yield from expression_function_call(ctx, _binary_op_to_function(inp), phft)
yield from _expression_binary_operator_or_function_call(
ctx,
inp.operator,
inp.polytype_substitutions,
[inp.left, inp.right],
inp.sourceref,
f'({inp.operator.name})',
phft,
)
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: TypeVariable) -> ConstraintGenerator:
yield from _expression_binary_operator_or_function_call(
ctx,
inp.function,
inp.polytype_substitutions,
inp.arguments,
inp.sourceref,
inp.function.name,
phft,
)
def _expression_binary_operator_or_function_call(
ctx: Context,
function: ourlang.Function | ourlang.FunctionParam,
polytype_substitutions: dict[TypeVariable, TypeExpr],
arguments: list[ourlang.Expression],
sourceref: ourlang.SourceRef,
function_name: str,
phft: TypeVariable,
) -> ConstraintGenerator:
arg_typ_list = []
for arg in inp.arguments:
for arg in arguments:
arg_tv = ctx.make_placeholder(arg)
yield from expression(ctx, arg, arg_tv)
arg_typ_list.append(arg_tv)
@ -105,34 +132,28 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Type
def make_placeholder(x: KindExpr, p: str) -> TypeVariable:
return ctx.make_placeholder(kind=x, prefix=p)
ftp5 = inp.function_instance.function.type5
ftp5 = function.type5
assert ftp5 is not None
if isinstance(ftp5, ConstrainedExpr):
ftp5 = instantiate_constrained(ftp5, {}, make_placeholder)
ftp5, phft_lookup = instantiate_constrained(ftp5, make_placeholder)
for orig_tvar, tvar in phft_lookup.items():
ctx.register_polytype_subsitutes(tvar, polytype_substitutions, orig_tvar)
for type_constraint in ftp5.constraints:
if isinstance(type_constraint, TypeClassConstraint):
yield TypeClassInstanceExistsConstraint(ctx, inp.sourceref, type_constraint.cls.name, type_constraint.variables)
yield TypeClassInstanceExistsConstraint(ctx, sourceref, type_constraint.cls.name, type_constraint.variables)
continue
raise NotImplementedError(type_constraint)
ftp5 = ftp5.expr
else:
ftp5 = instantiate(ftp5, {})
# We need an extra placeholder so that the inp.function_instance gets updated
phft2 = ctx.make_placeholder(inp.function_instance)
yield UnifyTypesConstraint(
ctx,
inp.sourceref,
ftp5,
phft2,
)
assert is_concrete(ftp5)
expr_type = ctx.build.type5_make_function(arg_typ_list + [phft])
yield UnifyTypesConstraint(ctx, inp.sourceref, phft2, expr_type)
yield UnifyTypesConstraint(ctx, sourceref, ftp5, expr_type, prefix=function_name)
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: TypeVariable) -> ConstraintGenerator:
assert inp.function.type5 is not None # Todo: Make not nullable
@ -141,7 +162,7 @@ def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference,
if isinstance(ftp5, ConstrainedExpr):
ftp5 = ftp5.expr
yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft)
yield UnifyTypesConstraint(ctx, inp.sourceref, ftp5, phft, prefix=inp.function.name)
def expression_tuple_instantiation(ctx: Context, inp: ourlang.TupleInstantiation, phft: TypeVariable) -> ConstraintGenerator:
arg_typ_list = []
@ -212,15 +233,16 @@ def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement
if fun.type5 is None:
raise NotImplementedError("Deducing function type - you'll have to annotate it.")
if isinstance(fun.type5, TypeApplication):
args = ctx.build.type5_is_function(fun.type5)
assert args is not None
type5 = args[-1]
else:
type5 = fun.type5.expr if isinstance(fun.type5, ConstrainedExpr) else fun.type5
args = ctx.build.type5_is_function(fun.type5)
assert args is not None
type5 = args[-1]
# This is a hack to allow return statement in pure and non pure functions
if (io_arg := ctx.build.type5_is_io(type5)):
type5 = io_arg
yield from expression(ctx, inp.value, phft)
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft)
yield UnifyTypesConstraint(ctx, inp.sourceref, type5, phft, prefix=f'{fun.name} returns')
def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf) -> ConstraintGenerator:
test_phft = ctx.make_placeholder(inp.test)
@ -235,6 +257,27 @@ def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf)
for stmt in inp.else_statements:
yield from statement(ctx, fun, stmt)
def statement_call(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementCall) -> ConstraintGenerator:
if fun.type5 is None:
raise NotImplementedError("Deducing function type - you'll have to annotate it.")
fn_args = ctx.build.type5_is_function(fun.type5)
assert fn_args is not None
fn_ret = fn_args[-1]
call_phft = ctx.make_placeholder(inp.call)
S = Star()
t_phft = ctx.make_placeholder(kind=S >> S)
a_phft = ctx.make_placeholder(kind=S)
yield from expression_function_call(ctx, inp.call, call_phft)
yield TypeClassInstanceExistsConstraint(ctx, inp.sourceref, 'Monad', [t_phft])
yield UnifyTypesConstraint(ctx, inp.sourceref, TypeApplication(constructor=t_phft, argument=a_phft), fn_ret)
yield UnifyTypesConstraint(ctx, inp.sourceref, TypeApplication(constructor=t_phft, argument=ctx.build.unit_type5), call_phft)
def statement(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement) -> ConstraintGenerator:
if isinstance(inp, ourlang.StatementReturn):
yield from statement_return(ctx, fun, inp)
@ -244,12 +287,18 @@ def statement(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement) -> Co
yield from statement_if(ctx, fun, inp)
return
if isinstance(inp, ourlang.StatementCall):
yield from statement_call(ctx, fun, inp)
return
raise NotImplementedError(inp)
def function(ctx: Context, inp: ourlang.Function) -> ConstraintGenerator:
for stmt in inp.statements:
yield from statement(ctx, inp, stmt)
# TODO: If function is imported or exported, it should be an IO[..] function
def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> ConstraintGenerator:
phft = ctx.make_placeholder(inp.constant)
@ -267,14 +316,3 @@ def module(ctx: Context, inp: ourlang.Module[Any]) -> ConstraintGenerator:
yield from function(ctx, func)
# TODO: Generalize?
def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall:
"""
For typing purposes, a binary operator is just a function call.
It's only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)`
"""
assert inp.sourceref is not None # TODO: sourceref required
call = ourlang.FunctionCall(inp.operator, inp.sourceref)
call.arguments = [inp.left, inp.right]
return call

View File

@ -3,8 +3,7 @@ from typing import Any
from ..ourlang import Module
from .constraints import ConstraintBase, Context
from .fromast import phasm_type5_generate_constraints
from .typeexpr import TypeExpr, TypeVariable, replace_variable
from .unify import ReplaceVariable
from .typeexpr import TypeExpr, TypeVariable, is_concrete, replace_variable
MAX_RESTACK_COUNT = 100
@ -31,8 +30,7 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
print("Validating")
new_constraint_list: list[ConstraintBase] = []
while constraint_list:
constraint = constraint_list.pop(0)
for constraint in sorted(constraint_list, key=lambda x: x.complexity()):
result = constraint.check()
if verbose:
@ -44,61 +42,41 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
# Means it checks out and we don't need do anything
continue
while result.actions:
action = result.actions.pop(0)
if result.replace is not None:
action_var = result.replace.var
assert action_var not in placeholder_types # When does this happen?
if isinstance(action, ReplaceVariable):
action_var: TypeExpr = action.var
while isinstance(action_var, TypeVariable) and action_var in placeholder_types:
# TODO: Does this still happen?
action_var = placeholder_types[action_var]
action_typ: TypeExpr = result.replace.typ
assert not isinstance(action_typ, TypeVariable) or action_typ not in placeholder_types # When does this happen?
action_typ: TypeExpr = action.typ
while isinstance(action_typ, TypeVariable) and action_typ in placeholder_types:
# TODO: Does this still happen?
action_typ = placeholder_types[action_typ]
assert action_var != action_typ # When does this happen?
# print(inp.build.type5_name(action_var), ':=', inp.build.type5_name(action_typ))
# Ensure all existing found types are updated
placeholder_types = {
k: replace_variable(v, action_var, action_typ)
for k, v in placeholder_types.items()
}
placeholder_types[action_var] = action_typ
if action_var == action_typ:
for oth_const in new_constraint_list + constraint_list:
if oth_const is constraint and result.done:
continue
if not isinstance(action_var, TypeVariable) and isinstance(action_typ, TypeVariable):
action_typ, action_var = action_var, action_typ
old_str = str(oth_const)
oth_const.replace_variable(action_var, action_typ)
new_str = str(oth_const)
if isinstance(action_var, TypeVariable):
# Ensure all existing found types are updated
placeholder_types = {
k: replace_variable(v, action_var, action_typ)
for k, v in placeholder_types.items()
}
placeholder_types[action_var] = action_typ
for oth_const in new_constraint_list + constraint_list:
if oth_const is constraint and result.done:
continue
old_str = str(oth_const)
oth_const.replace_variable(action_var, action_typ)
new_str = str(oth_const)
if verbose and old_str != new_str:
print(f"{oth_const.sourceref!s} => - {old_str!s}")
print(f"{oth_const.sourceref!s} => + {new_str!s}")
continue
error_list.append((str(constraint.sourceref), str(constraint), "Not the same type", ))
if verbose:
print(f"{constraint.sourceref!s} => ERR: Conflict in applying {action.to_str(inp.build.type5_name)}")
continue
# Action of unsupported type
raise NotImplementedError(action)
if verbose and old_str != new_str:
print(f"{oth_const.sourceref!s} => - {old_str!s}")
print(f"{oth_const.sourceref!s} => + {new_str!s}")
for failure in result.failures:
error_list.append((str(constraint.sourceref), str(constraint), failure.msg, ))
new_constraint_list.extend(result.new_constraints)
if verbose:
for new_const in result.new_constraints:
print(f"{oth_const.sourceref!s} => + {new_const!s}")
if result.done:
continue
@ -124,8 +102,11 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
if expression is None:
continue
new_type5 = placeholder_types[placeholder]
while isinstance(new_type5, TypeVariable):
new_type5 = placeholder_types[new_type5]
resolved_type5 = placeholder_types[placeholder]
assert is_concrete(resolved_type5) # When does this happen?
expression.type5 = resolved_type5
expression.type5 = new_type5
for placeholder, (ptst_map, orig_tvar) in ctx.ptst_update.items():
resolved_type5 = placeholder_types[placeholder]
assert is_concrete(resolved_type5) # When does this happen?
ptst_map[orig_tvar] = resolved_type5

View File

@ -4,6 +4,7 @@ from .typeexpr import (
TypeApplication,
TypeConstructor,
TypeExpr,
TypeLevelNat,
TypeVariable,
)
@ -21,6 +22,9 @@ class TypeRouter[T]:
def when_record(self, typ: Record) -> T:
raise NotImplementedError(typ)
def when_type_level_nat(self, typ: TypeLevelNat) -> T:
raise NotImplementedError(typ)
def when_variable(self, typ: TypeVariable) -> T:
raise NotImplementedError(typ)
@ -37,6 +41,9 @@ class TypeRouter[T]:
if isinstance(typ, TypeConstructor):
return self.when_constructor(typ)
if isinstance(typ, TypeLevelNat):
return self.when_type_level_nat(typ)
if isinstance(typ, TypeVariable):
return self.when_variable(typ)

View File

@ -1,128 +0,0 @@
from dataclasses import dataclass
from typing import Callable
from .typeexpr import (
AtomicType,
TypeApplication,
TypeConstructor,
TypeExpr,
TypeVariable,
is_concrete,
occurs,
)
@dataclass
class Failure:
"""
Both types are already different - cannot be unified.
"""
msg: str
@dataclass
class Action:
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
raise NotImplementedError
class ActionList(list[Action]):
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return '{' + ', '.join((x.to_str(type_namer) for x in self)) + '}'
UnifyResult = Failure | ActionList
@dataclass
class ReplaceVariable(Action):
var: TypeVariable
typ: TypeExpr
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
return f'{self.var.name} := {type_namer(self.typ)}'
def unify(lft: TypeExpr, rgt: TypeExpr) -> UnifyResult:
"""
Be warned: This only matches type variables with other variables or types
- it does not apply substituions nor does it validate if the matching
pairs are correct.
TODO: Remove this. It should be part of UnifyTypesConstraint
and should just generate new constraints for applications.
"""
if lft == rgt:
return ActionList()
if lft.kind != rgt.kind:
return Failure("Kind mismatch")
if isinstance(lft, AtomicType) and isinstance(rgt, AtomicType):
return Failure("Not the same type")
if isinstance(lft, AtomicType) and isinstance(rgt, TypeVariable):
return ActionList([ReplaceVariable(rgt, lft)])
if isinstance(lft, AtomicType) and isinstance(rgt, TypeConstructor):
raise NotImplementedError # Should have been caught by kind check above
if isinstance(lft, AtomicType) and isinstance(rgt, TypeApplication):
if is_concrete(rgt):
return Failure("Not the same type")
return Failure("Type shape mismatch")
if isinstance(lft, TypeVariable) and isinstance(rgt, AtomicType):
return unify(rgt, lft)
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeVariable):
return ActionList([ReplaceVariable(lft, rgt)])
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeConstructor):
return ActionList([ReplaceVariable(lft, rgt)])
if isinstance(lft, TypeVariable) and isinstance(rgt, TypeApplication):
if occurs(lft, rgt):
return Failure("One type occurs in the other")
return ActionList([ReplaceVariable(lft, rgt)])
if isinstance(lft, TypeConstructor) and isinstance(rgt, AtomicType):
return unify(rgt, lft)
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeVariable):
return unify(rgt, lft)
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeConstructor):
return Failure("Not the same type constructor")
if isinstance(lft, TypeConstructor) and isinstance(rgt, TypeApplication):
return Failure("Not the same type constructor")
if isinstance(lft, TypeApplication) and isinstance(rgt, AtomicType):
return unify(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeVariable):
return unify(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeConstructor):
return unify(rgt, lft)
if isinstance(lft, TypeApplication) and isinstance(rgt, TypeApplication):
con_res = unify(lft.constructor, rgt.constructor)
if isinstance(con_res, Failure):
return con_res
arg_res = unify(lft.argument, rgt.argument)
if isinstance(arg_res, Failure):
return arg_res
return ActionList(con_res + arg_res)
return Failure('Not implemented')

View File

@ -139,6 +139,10 @@ class Extractor(BuildTypeRouter[ExtractorFunc]):
return DynamicArrayExtractor(self.access, self(da_arg))
def when_io(self, io_arg: TypeExpr) -> ExtractorFunc:
# IO is a type only annotation, it is not related to allocation
return self(io_arg)
def when_static_array(self, sa_len: int, sa_typ: TypeExpr) -> ExtractorFunc:
return StaticArrayExtractor(self.access, sa_len, self(sa_typ))

View File

@ -305,8 +305,5 @@ def testEntry() -> i32:
```
```py
if TYPE_NAME.startswith('tuple_') or TYPE_NAME.startswith('static_array_') or TYPE_NAME.startswith('dynamic_array_'):
expect_type_error('Not the same type constructor')
else:
expect_type_error('Not the same type')
expect_type_error('Not the same type')
```

View File

@ -1,5 +1,7 @@
import pytest
from phasm.type5.solver import Type5SolverException
from ..helpers import Suite
@ -38,24 +40,22 @@ def test_call_post_defined():
code_py = """
@exported
def testEntry() -> i32:
return helper(10, 3)
return helper(13)
def helper(left: i32, right: i32) -> i32:
return left - right
def helper(left: i32) -> i32:
return left
"""
result = Suite(code_py).run_code()
assert 7 == result.returned_value
assert 13 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.skip('FIXME: Type checking')
def test_call_invalid_type():
code_py = """
def helper(left: i32) -> i32:
return left()
"""
result = Suite(code_py).run_code()
assert 7 == result.returned_value
with pytest.raises(Type5SolverException, match=r'i32 ~ Callable\[i32\]'):
Suite(code_py).run_code()

View File

@ -91,8 +91,7 @@ def testEntry() -> i32:
return action(double, 13.0)
"""
match = r'Callable\[i32, i32\] ~ Callable\[f32, [^]]+\]'
with pytest.raises(Type5SolverException, match=match):
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@ -109,8 +108,7 @@ def testEntry() -> i32:
return action(double, 13)
"""
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[f32, i32\], p_[0-9]+, [^]]+\]'
with pytest.raises(Type5SolverException, match=match):
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@ -127,14 +125,14 @@ def testEntry() -> f32:
return action(double, 13)
"""
with pytest.raises(Type5SolverException, match='f32 ~ i32'):
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_function_with_wrong_return_type_pass():
code_py = """
def double(left: i32) -> f32:
return convert(left) * 2
return convert(left) * 2.0
def action(applicable: Callable[i32, i32], left: i32) -> i32:
return applicable(left)
@ -144,8 +142,7 @@ def testEntry() -> i32:
return action(double, 13)
"""
match = r'Callable\[Callable\[i32, i32\], i32, i32\] ~ Callable\[Callable\[i32, f32\], p_[0-9]+, [^]]+\]'
with pytest.raises(Type5SolverException, match=match):
with pytest.raises(Type5SolverException, match='i32 ~ f32'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@ -179,12 +176,12 @@ def testEntry() -> i32:
return action(double, 13, 14)
"""
match = r'Callable\[Callable\[i32, i32, i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]'
match = r'Callable\[i32, i32\] ~ i32'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_too_many_args_use():
def test_sof_too_many_args_use_0():
code_py = """
def thirteen() -> i32:
return 13
@ -197,12 +194,30 @@ def testEntry() -> i32:
return action(thirteen, 13)
"""
match = r'Callable\[i32\] ~ Callable\[i32, p_[0-9]+\]'
match = r'\(\) ~ i32'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code(verbose=True)
@pytest.mark.integration_test
def test_sof_too_many_args_pass():
def test_sof_too_many_args_use_1():
code_py = """
def thirteen(x: i32) -> i32:
return x
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
return applicable(left, right)
@exported
def testEntry() -> i32:
return action(thirteen, 13, 26)
"""
match = r'i32 ~ Callable\[i32, i32\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code(verbose=True)
@pytest.mark.integration_test
def test_sof_too_many_args_pass_0():
code_py = """
def double(left: i32) -> i32:
return left * 2
@ -215,6 +230,24 @@ def testEntry() -> i32:
return action(double, 13, 14)
"""
match = r'Callable\[Callable\[i32\], i32, i32, i32\] ~ Callable\[Callable\[i32, i32\], p_[0-9]+, p_[0-9]+, p_[0-9]+\]'
match = r'\(\) ~ i32'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_too_many_args_pass_1():
code_py = """
def double(left: i32, right: i32) -> i32:
return left * right
def action(applicable: Callable[i32, i32], left: i32, right: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13, 14)
"""
match = r'i32 ~ Callable\[i32, i32\]'
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()

View File

@ -168,12 +168,9 @@ def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32:
return foldl(x, y, z)
"""
match = {
'i8': 'Type shape mismatch',
'i8[3]': 'Kind mismatch',
}
match = 'Type shape mismatch'
with pytest.raises(Type5SolverException, match=match[in_typ]):
with pytest.raises(Type5SolverException, match=match):
Suite(code_py).run_code()
@pytest.mark.integration_test

View File

@ -0,0 +1,38 @@
import pytest
from ..helpers import Suite
@pytest.mark.integration_test
def test_io_use_type_class():
code_py = f"""
@exported
def testEntry() -> IO[u32]:
return 4
"""
result = Suite(code_py).run_code()
assert 4 == result.returned_value
@pytest.mark.integration_test
def test_io_call_io_function():
code_py = f"""
@imported
def log(val: u32) -> IO[()]:
pass
@exported
def testEntry() -> IO[u32]:
log(123)
return 4
"""
log_history: list[Any] = []
def my_log(val: int) -> None:
log_history.append(val)
result = Suite(code_py).run_code(imports={
'log': my_log,
})
assert 4 == result.returned_value
assert [123] == log_history