phasm/phasm/type5/fromast.py
2025-08-23 15:33:23 +02:00

309 lines
11 KiB
Python

from typing import Any, Generator
from .. import ourlang
from ..typeclass import TypeClassConstraint
from .constrainedexpr import ConstrainedExpr, instantiate_constrained
from .constraints import (
CanAccessStructMemberConstraint,
CanBeSubscriptedConstraint,
ConstraintBase,
Context,
FromLiteralBytes,
FromLiteralFloat,
FromLiteralInteger,
FromTupleConstraint,
TypeClassInstanceExistsConstraint,
UnifyTypesConstraint,
)
from .kindexpr import KindExpr, Star
from .typeexpr import TypeApplication, TypeVariable, instantiate
# Return annotation shared by the generator functions in this module: a generator
# that yields ConstraintBase objects, accepts no sent values and returns nothing.
ConstraintGenerator = Generator[ConstraintBase, None, None]
def phasm_type5_generate_constraints(ctx: Context, inp: ourlang.Module[Any]) -> list[ConstraintBase]:
    """
    Collect every constraint generated for the given module into a list.

    The `module` generator is drained completely before returning, so all
    constraints have been built by the time the caller receives the list.
    """
    return list(module(ctx, inp))
def expression_constant_primitive(ctx: Context, inp: ourlang.ConstantPrimitive, phft: TypeVariable) -> ConstraintGenerator:
    """
    Yield the literal constraint for a primitive constant.

    An `int` value yields a FromLiteralInteger constraint and a `float` value
    yields a FromLiteralFloat constraint, both against the placeholder `phft`.

    Raises NotImplementedError for a value of any other type.
    """
    literal = inp.value

    # The int check comes first, exactly as in the dispatch order callers rely on.
    if isinstance(literal, int):
        yield FromLiteralInteger(ctx, inp.sourceref, phft, literal)
    elif isinstance(literal, float):
        yield FromLiteralFloat(ctx, inp.sourceref, phft, literal)
    else:
        raise NotImplementedError(literal)
def expression_constant_bytes(ctx: Context, inp: ourlang.ConstantBytes, phft: TypeVariable) -> ConstraintGenerator:
    """
    Yield a single FromLiteralBytes constraint for a bytes constant,
    built from the constant's value and the placeholder `phft`.
    """
    bytes_constraint = FromLiteralBytes(ctx, inp.sourceref, phft, inp.value)
    yield bytes_constraint
def expression_constant_tuple(ctx: Context, inp: ourlang.ConstantTuple, phft: TypeVariable) -> ConstraintGenerator:
    """
    Generate the constraints for a constant tuple.

    A placeholder is created for every element up front; then each element's
    own constraints are generated, and finally a FromTupleConstraint relates
    `phft` to the list of element placeholders.
    """
    elements = inp.value

    # All placeholders are made before any element is visited.
    element_phfts = list(map(ctx.make_placeholder, elements))

    for element, element_phft in zip(elements, element_phfts):
        yield from expression(ctx, element, element_phft)

    yield FromTupleConstraint(ctx, inp.sourceref, phft, element_phfts)
def expression_constant_struct(ctx: Context, inp: ourlang.ConstantStruct, phft: TypeVariable) -> ConstraintGenerator:
    """
    Generate the constraints for a constant struct.

    Each member gets a placeholder and its own constraints. The struct is then
    treated like a constructor call: a function type built from the struct's
    field entries and ending in the struct type is unified with a function
    type built from the member placeholders and ending in `phft`.
    """
    member_phfts = [ctx.make_placeholder(member) for member in inp.value]

    for member, member_phft in zip(inp.value, member_phfts):
        yield from expression(ctx, member, member_phft)

    struct_type5 = inp.struct_type5
    # Index 1 of each field entry is used as the field's type
    # (presumably entries are (name, type) pairs - verify against the struct definition).
    field_type5s = [field[1] for field in struct_type5.fields]

    declared = ctx.build.type5_make_function([*field_type5s, struct_type5])
    inferred = ctx.build.type5_make_function([*member_phfts, phft])

    yield UnifyTypesConstraint(ctx, inp.sourceref, declared, inferred)
def expression_constant_memory_stored(ctx: Context, inp: ourlang.ConstantMemoryStored, phft: TypeVariable) -> ConstraintGenerator:
    """
    Dispatch a memory stored constant to the generator for its concrete kind.

    Handles bytes, tuple and struct constants, checked in that order.

    Raises NotImplementedError for any other kind.
    """
    if isinstance(inp, ourlang.ConstantBytes):
        yield from expression_constant_bytes(ctx, inp, phft)
    elif isinstance(inp, ourlang.ConstantTuple):
        yield from expression_constant_tuple(ctx, inp, phft)
    elif isinstance(inp, ourlang.ConstantStruct):
        yield from expression_constant_struct(ctx, inp, phft)
    else:
        raise NotImplementedError(inp)
def expression_constant(ctx: Context, inp: ourlang.Constant, phft: TypeVariable) -> ConstraintGenerator:
    """
    Dispatch a constant to the primitive or the memory stored generator.

    Raises NotImplementedError when the constant is neither.
    """
    if isinstance(inp, ourlang.ConstantPrimitive):
        yield from expression_constant_primitive(ctx, inp, phft)
    elif isinstance(inp, ourlang.ConstantMemoryStored):
        yield from expression_constant_memory_stored(ctx, inp, phft)
    else:
        raise NotImplementedError(inp)
def expression_variable_reference(ctx: Context, inp: ourlang.VariableReference, phft: TypeVariable) -> ConstraintGenerator:
    """
    Yield one constraint unifying the referenced variable's type5
    with the placeholder `phft`.
    """
    variable_type5 = inp.variable.type5
    yield UnifyTypesConstraint(ctx, inp.sourceref, variable_type5, phft)
def expression_binary_operator(ctx: Context, inp: ourlang.BinaryOp, phft: TypeVariable) -> ConstraintGenerator:
    """
    Generate the constraints for a binary operator by rewriting it as a
    function call and delegating to `expression_function_call`.
    """
    as_call = _binary_op_to_function(inp)
    yield from expression_function_call(ctx, as_call, phft)
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: TypeVariable) -> ConstraintGenerator:
    """
    Generate the constraints for a function call.

    Steps, in order:
    - every argument gets a placeholder and its own constraints;
    - the called function's type5 is instantiated; when it is a ConstrainedExpr,
      each of its type class constraints yields a TypeClassInstanceExistsConstraint;
    - the instantiated type is unified with a placeholder for the function instance;
    - that placeholder is unified with a function type made from the argument
      placeholders followed by `phft`.

    Raises NotImplementedError for a constraint on the function type that is
    not a TypeClassConstraint.
    """
    argument_phfts: list[TypeVariable] = []
    for argument in inp.arguments:
        # Placeholder first, then the argument's constraints, one argument at a time.
        argument_phft = ctx.make_placeholder(argument)
        yield from expression(ctx, argument, argument_phft)
        argument_phfts.append(argument_phft)

    def fresh_placeholder(x: KindExpr, p: str) -> TypeVariable:
        # Positional adapter handed to instantiate_constrained.
        return ctx.make_placeholder(kind=x, prefix=p)

    declared = inp.function_instance.function.type5
    assert declared is not None

    if not isinstance(declared, ConstrainedExpr):
        callee_type5 = instantiate(declared, {})
    else:
        constrained = instantiate_constrained(declared, {}, fresh_placeholder)

        for type_constraint in constrained.constraints:
            if not isinstance(type_constraint, TypeClassConstraint):
                raise NotImplementedError(type_constraint)

            yield TypeClassInstanceExistsConstraint(
                ctx,
                inp.sourceref,
                type_constraint.cls.name,
                type_constraint.variables,
            )

        callee_type5 = constrained.expr

    # We need an extra placeholder so that the inp.function_instance gets updated
    instance_phft = ctx.make_placeholder(inp.function_instance)
    yield UnifyTypesConstraint(ctx, inp.sourceref, callee_type5, instance_phft)

    call_type5 = ctx.build.type5_make_function([*argument_phfts, phft])
    yield UnifyTypesConstraint(ctx, inp.sourceref, instance_phft, call_type5)
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: TypeVariable) -> ConstraintGenerator:
    """
    Yield one constraint unifying the referenced function's type with `phft`.

    When the function's type5 is a ConstrainedExpr only its `expr` part is
    used; its constraints are not turned into constraints here.
    """
    declared = inp.function.type5
    assert declared is not None  # Todo: Make not nullable

    referenced = declared.expr if isinstance(declared, ConstrainedExpr) else declared

    yield UnifyTypesConstraint(ctx, inp.sourceref, referenced, phft)
def expression_tuple_instantiation(ctx: Context, inp: ourlang.TupleInstantiation, phft: TypeVariable) -> ConstraintGenerator:
    """
    Generate the constraints for a tuple instantiation.

    Each element gets a placeholder followed directly by its own constraints;
    afterwards a FromTupleConstraint relates `phft` to the element placeholders.
    """
    element_phfts: list[TypeVariable] = []

    for element in inp.elements:
        element_phft = ctx.make_placeholder(element)
        yield from expression(ctx, element, element_phft)
        element_phfts.append(element_phft)

    yield FromTupleConstraint(ctx, inp.sourceref, phft, element_phfts)
def expression_subscript(ctx: Context, inp: ourlang.Subscript, phft: TypeVariable) -> ConstraintGenerator:
    """
    Generate the constraints for a subscript expression.

    Placeholders are made for the subscripted value and for the index, their
    constraints are generated, and then a CanBeSubscriptedConstraint is yielded.
    When the index is a primitive constant holding an `int`, that integer is
    passed along as the last argument; otherwise None is passed.
    """
    container_phft = ctx.make_placeholder(inp.varref)
    position_phft = ctx.make_placeholder(inp.index)

    yield from expression(ctx, inp.varref, container_phft)
    yield from expression(ctx, inp.index, position_phft)

    # Only a literal integer index is known at this point.
    literal_index = (
        inp.index.value
        if isinstance(inp.index, ourlang.ConstantPrimitive) and isinstance(inp.index.value, int)
        else None
    )

    yield CanBeSubscriptedConstraint(ctx, inp.sourceref, phft, container_phft, position_phft, literal_index)
def expression_access_struct_member(ctx: Context, inp: ourlang.AccessStructMember, phft: TypeVariable) -> ConstraintGenerator:
    """
    Generate the constraints for accessing a member of a struct variable.

    The variable reference gets its own placeholder and constraints, after
    which a CanAccessStructMemberConstraint is yielded for the member.
    """
    struct_phft = ctx.make_placeholder(inp.varref)

    yield from expression_variable_reference(ctx, inp.varref, struct_phft)

    yield CanAccessStructMemberConstraint(ctx, inp.sourceref, phft, struct_phft, inp.member)
def expression(ctx: Context, inp: ourlang.Expression, phft: TypeVariable) -> ConstraintGenerator:
    """
    Dispatch an expression node to the generator that handles its kind.

    The kinds are checked in a fixed order: constant, variable reference,
    binary operator, function call, function reference, tuple instantiation,
    subscript and struct member access.

    Raises NotImplementedError for any other node.
    """
    if isinstance(inp, ourlang.Constant):
        yield from expression_constant(ctx, inp, phft)
    elif isinstance(inp, ourlang.VariableReference):
        yield from expression_variable_reference(ctx, inp, phft)
    elif isinstance(inp, ourlang.BinaryOp):
        yield from expression_binary_operator(ctx, inp, phft)
    elif isinstance(inp, ourlang.FunctionCall):
        yield from expression_function_call(ctx, inp, phft)
    elif isinstance(inp, ourlang.FunctionReference):
        yield from expression_function_reference(ctx, inp, phft)
    elif isinstance(inp, ourlang.TupleInstantiation):
        yield from expression_tuple_instantiation(ctx, inp, phft)
    elif isinstance(inp, ourlang.Subscript):
        yield from expression_subscript(ctx, inp, phft)
    elif isinstance(inp, ourlang.AccessStructMember):
        yield from expression_access_struct_member(ctx, inp, phft)
    else:
        raise NotImplementedError(inp)
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator:
    """
    Generate the constraints for a return statement inside `fun`.

    The returned expression's constraints are generated and its placeholder is
    unified with the last entry of the function's type. If `type5_is_io`
    gives a truthy result for that entry, the result is used in its place.

    Raises NotImplementedError when the function has no type5.
    """
    value_phft = ctx.make_placeholder(inp.value)

    if fun.type5 is None:
        raise NotImplementedError("Deducing function type - you'll have to annotate it.")

    signature = ctx.build.type5_is_function(fun.type5)
    assert signature is not None
    declared_return = signature[-1]

    # This is a hack to allow return statement in pure and non pure functions
    io_inner = ctx.build.type5_is_io(declared_return)
    expected = io_inner if io_inner else declared_return

    yield from expression(ctx, inp.value, value_phft)
    yield UnifyTypesConstraint(ctx, inp.sourceref, expected, value_phft)
def statement_if(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementIf) -> ConstraintGenerator:
    """
    Generate the constraints for an if statement inside `fun`.

    The test expression's placeholder is unified with the bool type, then the
    statements of the main branch and of the else branch are processed in turn.
    """
    condition_phft = ctx.make_placeholder(inp.test)

    yield from expression(ctx, inp.test, condition_phft)
    yield UnifyTypesConstraint(ctx, inp.test.sourceref, condition_phft, ctx.build.bool_type5)

    # Main branch first, else branch second.
    for branch in (inp.statements, inp.else_statements):
        for branch_statement in branch:
            yield from statement(ctx, fun, branch_statement)
def statement_call(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementCall) -> ConstraintGenerator:
    """
    Generate the constraints for a call used as a statement inside `fun`.

    With fresh placeholders `t` (kind Star >> Star) and `a` (kind Star), after
    the call's own constraints the following are yielded:
    - a 'Monad' type class instance must exist for `t`;
    - `t a` unifies with the last entry of the enclosing function's type;
    - `t` applied to the unit type unifies with the type of the call.

    Raises NotImplementedError when the function has no type5.
    """
    if fun.type5 is None:
        raise NotImplementedError("Deducing function type - you'll have to annotate it.")

    signature = ctx.build.type5_is_function(fun.type5)
    assert signature is not None
    enclosing_return = signature[-1]

    call_phft = ctx.make_placeholder(inp.call)

    star = Star()
    monad_phft = ctx.make_placeholder(kind=star >> star)
    result_phft = ctx.make_placeholder(kind=star)

    yield from expression_function_call(ctx, inp.call, call_phft)

    yield TypeClassInstanceExistsConstraint(ctx, inp.sourceref, 'Monad', [monad_phft])

    monad_of_result = TypeApplication(constructor=monad_phft, argument=result_phft)
    yield UnifyTypesConstraint(ctx, inp.sourceref, monad_of_result, enclosing_return)

    monad_of_unit = TypeApplication(constructor=monad_phft, argument=ctx.build.unit_type5)
    yield UnifyTypesConstraint(ctx, inp.sourceref, monad_of_unit, call_phft)
def statement(ctx: Context, fun: ourlang.Function, inp: ourlang.Statement) -> ConstraintGenerator:
    """
    Dispatch a statement of `fun` to the generator that handles its kind.

    Handles return, if and call statements, checked in that order.

    Raises NotImplementedError for any other statement.
    """
    if isinstance(inp, ourlang.StatementReturn):
        yield from statement_return(ctx, fun, inp)
    elif isinstance(inp, ourlang.StatementIf):
        yield from statement_if(ctx, fun, inp)
    elif isinstance(inp, ourlang.StatementCall):
        yield from statement_call(ctx, fun, inp)
    else:
        raise NotImplementedError(inp)
def function(ctx: Context, inp: ourlang.Function) -> ConstraintGenerator:
    """
    Generate the constraints for every statement of a function, in order.
    """
    for body_statement in inp.statements:
        yield from statement(ctx, inp, body_statement)

    # TODO: If function is imported or exported, it should be an IO[..] function
def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> ConstraintGenerator:
    """
    Generate the constraints for a module level constant definition.

    The constant's own constraints come first, followed by a constraint
    unifying the definition's type5 with the constant's placeholder.
    """
    constant_phft = ctx.make_placeholder(inp.constant)

    yield from expression_constant(ctx, inp.constant, constant_phft)

    yield UnifyTypesConstraint(ctx, inp.sourceref, inp.type5, constant_phft)
def module(ctx: Context, inp: ourlang.Module[Any]) -> ConstraintGenerator:
    """
    Generate the constraints for a whole module.

    Constant definitions are processed first, then every function that is
    not marked as imported.
    """
    for constant_def in inp.constant_defs.values():
        yield from module_constant_def(ctx, constant_def)

    for func in inp.functions.values():
        # Imported functions are skipped entirely.
        if not func.imported:
            yield from function(ctx, func)

    # TODO: Generalize?
def _binary_op_to_function(inp: ourlang.BinaryOp) -> ourlang.FunctionCall:
    """
    Rewrite a binary operator as the equivalent function call.

    For typing purposes the two are the same thing; the operator form is
    only syntactic sugar - e.g. `1 + 2` vs `+(1, 2)`. The operator becomes
    the called function and the left and right operands become its arguments.
    """
    sourceref = inp.sourceref
    assert sourceref is not None  # TODO: sourceref required

    as_call = ourlang.FunctionCall(inp.operator, sourceref)
    as_call.arguments = [inp.left, inp.right]

    return as_call