phasm/phasm/type5/constraints.py
Johan B.W. de Vries 38e43944c7 Replaces type3 with type5
type5 is much more first-principles based, so we get a lot
of weird quirks removed:

- FromLiteral no longer needs to understand AST
- Type unification works more like Haskell
- Function types are just ordinary types, saving a lot of
  manual busywork

and more.
2025-08-10 14:51:17 +02:00

459 lines
16 KiB
Python

from __future__ import annotations
import dataclasses
from typing import Any, Callable, Iterable, Protocol, Sequence
from ..build.base import BuildBase
from ..ourlang import SourceRef
from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt64
from .kindexpr import KindExpr, Star
from .record import Record
from .typeexpr import (
TypeExpr,
TypeVariable,
is_concrete,
replace_variable,
)
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
class ExpressionProtocol(Protocol):
    """
    Structural interface for objects that should be updated on substitution.

    Any object exposing a ``type5`` attribute matches this protocol; no
    explicit inheritance is required.
    """

    # The type to update; may be None.
    type5: TypeExpr | None
class Context:
    """
    Shared state for one type5 constraint-checking run.

    Keeps a reference to the build (consulted by constraints for type
    information) and a registry of every placeholder type variable created
    through ``make_placeholder``, mapped to the expression associated with it
    (or None).
    """
    __slots__ = ("build", "placeholder_update", )

    build: BuildBase[Any]
    placeholder_update: dict[TypeVariable, ExpressionProtocol | None]

    def __init__(self, build: BuildBase[Any]) -> None:
        self.build = build
        self.placeholder_update = {}

    def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star(), prefix: str = 'p') -> TypeVariable:
        """
        Create a fresh type variable and register it in ``placeholder_update``.

        :param arg: Expression to associate with the placeholder, or None.
        :param kind: Kind of the new variable.
        :param prefix: Name prefix; the variable is named ``<prefix>_<n>``
            where ``n`` is the number of placeholders registered so far.
        :return: The newly created type variable.
        """
        sequence_number = len(self.placeholder_update)
        placeholder = TypeVariable(kind, f"{prefix}_{sequence_number}")
        self.placeholder_update[placeholder] = arg
        return placeholder
@dataclasses.dataclass
class CheckResult:
    """
    Outcome of checking a single constraint.

    All fields are keyword-only.
    """
    _: dataclasses.KW_ONLY
    # False when the constraint could not be decided yet (rendered as '(skip for now)').
    done: bool = True
    # Actions (such as variable replacements) produced by the check.
    actions: ActionList = dataclasses.field(default_factory=ActionList)
    # Further constraints produced by the check.
    new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
    # Reasons the constraint was rejected.
    failures: list[Failure] = dataclasses.field(default_factory=list)

    def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
        """
        Render a short human-readable summary of this result.

        ``type_namer`` is forwarded to ``ActionList.to_str`` to name types.
        Results that carry exactly one kind of payload get a compact form;
        anything else falls back to a raw dump of all fields.
        """
        has_actions = bool(self.actions)
        has_new = bool(self.new_constraints)
        has_failures = bool(self.failures)

        if not (has_actions or has_new or has_failures):
            return '(ok)' if self.done else '(skip for now)'

        if self.done:
            payload = (has_actions, has_new, has_failures)
            if payload == (True, False, False):
                return self.actions.to_str(type_namer)
            if payload == (False, True, False):
                return f'(got {len(self.new_constraints)} new constraints)'
            if payload == (False, False, True):
                return 'ERR: ' + '; '.join(x.msg for x in self.failures)

        return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}'
def skip_for_now() -> CheckResult:
    """Return a result marking the constraint as not done yet ('(skip for now)')."""
    pending = CheckResult(done=False)
    return pending
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
    """Return a done result carrying the given constraints as ``new_constraints``."""
    produced = [constraint for constraint in lst]
    return CheckResult(new_constraints=produced)
def ok() -> CheckResult:
    """Return a done result with no actions, new constraints or failures."""
    satisfied = CheckResult(done=True)
    return satisfied
def fail(msg: str) -> CheckResult:
    """Return a result carrying a single failure with message ``msg``."""
    failure = Failure(msg)
    return CheckResult(failures=[failure])
class ConstraintBase:
    """
    Base class for all type5 constraints.

    Subclasses implement ``check`` and, when they hold types, override
    ``replace_variable`` so solver actions can be applied to them.
    """
    __slots__ = ("ctx", "sourceref", )

    ctx: Context
    sourceref: SourceRef

    def __init__(self, ctx: Context, sourceref: SourceRef) -> None:
        self.ctx = ctx
        self.sourceref = sourceref

    def check(self) -> CheckResult:
        """Evaluate this constraint. Subclasses must override this."""
        raise NotImplementedError(self)

    def apply(self, action: Action) -> None:
        """
        Apply a solver action to this constraint.

        Only ``ReplaceVariable`` is supported; it is forwarded to
        ``replace_variable``. Any other action raises NotImplementedError.
        """
        if not isinstance(action, ReplaceVariable):
            raise NotImplementedError(action)

        self.replace_variable(action.var, action.typ)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute ``typ`` for ``var``; the base class holds no types, so this does nothing."""
        return None
class FromLiteralInteger(ConstraintBase):
    """
    Constraint: the integer ``literal`` must be representable as ``type5``.

    The check is deferred until ``type5`` is concrete. It then requires an
    entry in the build's ``type_info_map`` whose wasm type is i32 or i64, and
    the literal must fit in that type's ``alloc_size`` bytes with the type's
    signedness.
    """
    __slots__ = ('type5', 'literal', )

    type5: TypeExpr
    literal: int

    def __init__(self, ctx: Context, sourceref: SourceRef, type5: TypeExpr, literal: int) -> None:
        super().__init__(ctx, sourceref)
        self.type5 = type5
        self.literal = literal

    def check(self) -> CheckResult:
        if not is_concrete(self.type5):
            return skip_for_now()

        info = self.ctx.build.type_info_map.get(self.type5.name)
        integer_wasm_types = (WasmTypeInt32, WasmTypeInt64)
        if info is None or not any(info.wasm_type is wasm_type for wasm_type in integer_wasm_types):
            return fail('Cannot convert from literal integer')

        assert info.signed is not None  # type hint

        # Only whether the conversion raises matters; the produced bytes are
        # discarded, so the byte order is irrelevant.
        try:
            self.literal.to_bytes(info.alloc_size, 'big', signed=info.signed)
        except OverflowError:
            return fail(f'Must fit in {info.alloc_size} byte(s)')
        else:
            return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.type5 = replace_variable(self.type5, var, typ)

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type5)
        return f'FromLiteralInteger {type_name} ~ {self.literal!r}'
class FromLiteralFloat(ConstraintBase):
    """
    Constraint: the floating point ``literal`` must be convertible to ``type5``.

    The check is deferred until ``type5`` is concrete. It then requires an
    entry in the build's ``type_info_map`` whose wasm type is f32 or f64.
    The literal's precision is not checked yet.
    """
    __slots__ = ('type5', 'literal', )

    type5: TypeExpr
    literal: float

    def __init__(self, ctx: Context, sourceref: SourceRef, type5: TypeExpr, literal: float) -> None:
        super().__init__(ctx, sourceref)
        self.type5 = type5
        self.literal = literal

    def check(self) -> CheckResult:
        if not is_concrete(self.type5):
            return skip_for_now()

        type_info = self.ctx.build.type_info_map.get(self.type5.name)
        if type_info is None:
            return fail('Cannot convert from literal float')

        if type_info.wasm_type is not WasmTypeFloat32 and type_info.wasm_type is not WasmTypeFloat64:
            return fail('Cannot convert from literal float')

        # TODO: Precision check
        return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.type5 = replace_variable(self.type5, var, typ)

    def __str__(self) -> str:
        # Label this as a float constraint; it was previously mislabeled as
        # 'FromLiteralInteger' (copy-paste slip), which made debug output misleading.
        return f'FromLiteralFloat {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
class FromLiteralBytes(ConstraintBase):
    """
    Constraint: the bytes ``literal`` must be convertible to ``type5``.

    The check is deferred until ``type5`` is concrete. It then requires
    ``type5`` to be a dynamic array whose element type equals the build's
    ``u8_type5``.
    """
    __slots__ = ('type5', 'literal', )

    type5: TypeExpr
    literal: bytes

    def __init__(self, ctx: Context, sourceref: SourceRef, type5: TypeExpr, literal: bytes) -> None:
        super().__init__(ctx, sourceref)
        self.type5 = type5
        self.literal = literal

    def check(self) -> CheckResult:
        if not is_concrete(self.type5):
            return skip_for_now()

        build = self.ctx.build
        element_type5 = build.type5_is_dynamic_array(self.type5)
        if element_type5 is None or element_type5 != build.u8_type5:
            return fail('Cannot convert from literal bytes')

        return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.type5 = replace_variable(self.type5, var, typ)

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type5)
        return f'FromLiteralBytes {type_name} ~ {self.literal!r}'
class UnifyTypesConstraint(ConstraintBase):
    """
    Constraint: ``lft`` and ``rgt`` must unify (``lft ~ rgt``).

    ``check`` delegates to ``unify``. A ``Failure`` is reported in the
    result's failures; otherwise the returned actions become the result's
    actions.
    """
    __slots__ = ("lft", "rgt",)

    # Declared like every other constraint class, so the slot types are visible
    # to type checkers and readers.
    lft: TypeExpr
    rgt: TypeExpr

    def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr) -> None:
        super().__init__(ctx, sourceref)
        self.lft = lft
        self.rgt = rgt

    def check(self) -> CheckResult:
        result = unify(self.lft, self.rgt)
        if isinstance(result, Failure):
            return CheckResult(failures=[result])

        return CheckResult(actions=result)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.lft = replace_variable(self.lft, var, typ)
        self.rgt = replace_variable(self.rgt, var, typ)

    def __str__(self) -> str:
        return f"{self.ctx.build.type5_name(self.lft)} ~ {self.ctx.build.type5_name(self.rgt)}"
class CanBeSubscriptedConstraint(ConstraintBase):
    """
    Constraint: ``container_type5[index]`` is valid and has type ``ret_type5``.

    The check is deferred until the container type is concrete. Then:

    - dynamic array: the result is the element type and the index must be u32;
    - static array: as above, plus a constant index is bounds-checked;
    - tuple: the index must be an in-range integer constant, and the result is
      that member's type (the index is still unified with u32).

    Any other concrete container raises NotImplementedError.
    """
    __slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )

    ret_type5: TypeExpr
    container_type5: TypeExpr
    index_type5: TypeExpr
    index_const: int | None

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        container_type5: TypeExpr,
        index_type5: TypeExpr,
        index_const: int | None,
    ) -> None:
        super().__init__(ctx, sourceref)
        self.ret_type5 = ret_type5
        self.container_type5 = container_type5
        self.index_type5 = index_type5
        self.index_const = index_const

    def check(self) -> CheckResult:
        if not is_concrete(self.container_type5):
            return skip_for_now()

        build = self.ctx.build

        def element_access(element_type5: TypeExpr) -> CheckResult:
            # The result has the element's type; the index expression must be u32.
            return new_constraints([
                UnifyTypesConstraint(self.ctx, self.sourceref, element_type5, self.ret_type5),
                UnifyTypesConstraint(self.ctx, self.sourceref, build.u32_type5, self.index_type5),
            ])

        da_arg = build.type5_is_dynamic_array(self.container_type5)
        if da_arg is not None:
            return element_access(da_arg)

        sa_args = build.type5_is_static_array(self.container_type5)
        if sa_args is not None:
            sa_len, sa_typ = sa_args
            if self.index_const is not None and not 0 <= self.index_const < sa_len:
                return fail('Tuple index out of range')
            return element_access(sa_typ)

        tp_args = build.type5_is_tuple(self.container_type5)
        if tp_args is not None:
            if self.index_const is None:
                return fail('Must index with integer literal')
            if not 0 <= self.index_const < len(tp_args):
                return fail('Tuple index out of range')
            return element_access(tp_args[self.index_const])

        raise NotImplementedError('This should be converted to a TypeClass, also for da and sa, only tuple should be the exception')

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.container_type5 = replace_variable(self.container_type5, var, typ)
        self.index_type5 = replace_variable(self.index_type5, var, typ)

    def __str__(self) -> str:
        name_of = self.ctx.build.type5_name
        container = name_of(self.container_type5)
        index = name_of(self.index_type5)
        ret = name_of(self.ret_type5)
        return f"[] :: t a -> b -> a ~ {container} -> {index} -> {ret}"
class CanAccessStructMemberConstraint(ConstraintBase):
    """
    Constraint: ``struct_type5.member_name`` is valid and has type ``ret_type5``.

    The check is deferred until the struct type is concrete. It then requires
    a struct type that has a field called ``member_name`` and unifies that
    field's type with ``ret_type5`` immediately.
    """
    __slots__ = ('ret_type5', 'struct_type5', 'member_name', )

    ret_type5: TypeExpr
    struct_type5: TypeExpr
    member_name: str

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        struct_type5: TypeExpr,
        member_name: str,
    ) -> None:
        super().__init__(ctx, sourceref)
        self.ret_type5 = ret_type5
        self.struct_type5 = struct_type5
        self.member_name = member_name

    def check(self) -> CheckResult:
        if not is_concrete(self.struct_type5):
            return skip_for_now()

        struct_fields = self.ctx.build.type5_is_struct(self.struct_type5)
        if struct_fields is None:
            return fail('Must be a struct')

        fields_by_name = dict(struct_fields)
        try:
            member_type5 = fields_by_name[self.member_name]
        except KeyError:
            return fail('Must have a field with this name')

        unification = UnifyTypesConstraint(self.ctx, self.sourceref, self.ret_type5, member_type5)
        return unification.check()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.struct_type5 = replace_variable(self.struct_type5, var, typ)

    def __str__(self) -> str:
        build = self.ctx.build
        fields_by_name = dict(build.type5_is_struct(self.struct_type5) or [])
        member_type5 = fields_by_name.get(self.member_name)

        if member_type5 is not None:
            expect = f'{build.type5_name(self.struct_type5)} -> {build.type5_name(member_type5)}'
        else:
            expect = 'a -> b'

        struct_name = build.type5_name(self.struct_type5)
        ret_name = build.type5_name(self.ret_type5)
        return f".{self.member_name} :: {expect} ~ {struct_name} -> {ret_name}"
class FromTupleConstraint(ConstraintBase):
    """
    Constraint: a tuple expression with member types ``member_type5_list``
    produces a value of type ``ret_type5``.

    The check is deferred until ``ret_type5`` is concrete. Then:

    - dynamic array: every member must unify with the element type;
    - static array: the length must match and every member must unify with
      the element type;
    - tuple: the arity must match and members unify pairwise.

    Any other concrete type raises NotImplementedError.
    """
    __slots__ = ('ret_type5', 'member_type5_list', )

    ret_type5: TypeExpr
    member_type5_list: list[TypeExpr]

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        member_type5_list: Sequence[TypeExpr],
    ) -> None:
        super().__init__(ctx, sourceref)
        self.ret_type5 = ret_type5
        self.member_type5_list = list(member_type5_list)

    def check(self) -> CheckResult:
        if not is_concrete(self.ret_type5):
            return skip_for_now()

        build = self.ctx.build
        member_count = len(self.member_type5_list)

        def all_members_are(element_type5: TypeExpr) -> CheckResult:
            # Homogeneous containers: every member unifies with the element type.
            return new_constraints(
                UnifyTypesConstraint(self.ctx, self.sourceref, element_type5, member)
                for member in self.member_type5_list
            )

        da_arg = build.type5_is_dynamic_array(self.ret_type5)
        if da_arg is not None:
            return all_members_are(da_arg)

        sa_args = build.type5_is_static_array(self.ret_type5)
        if sa_args is not None:
            sa_len, sa_typ = sa_args
            if sa_len != member_count:
                return fail('Tuple element count mismatch')
            return all_members_are(sa_typ)

        tp_args = build.type5_is_tuple(self.ret_type5)
        if tp_args is not None:
            if len(tp_args) != member_count:
                return fail('Tuple element count mismatch')
            return new_constraints(
                UnifyTypesConstraint(self.ctx, self.sourceref, act_typ, exp_typ)
                for act_typ, exp_typ in zip(tp_args, self.member_type5_list, strict=True)
            )

        raise NotImplementedError(self.ret_type5)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.member_type5_list = [
            replace_variable(member, var, typ)
            for member in self.member_type5_list
        ]

    def __str__(self) -> str:
        name_of = self.ctx.build.type5_name
        member_names = ', '.join(name_of(member) for member in self.member_type5_list)
        return f'FromTuple {name_of(self.ret_type5)} ~ ({member_names}, )'
class TypeClassInstanceExistsConstraint(ConstraintBase):
    """
    Constraint: an instance of ``typeclass`` exists for ``arg_list``.

    The check is deferred until every argument is concrete. Arguments that
    are ``Record`` instances are always rejected for now. Otherwise the tuple
    of arguments is looked up in the build's ``type_class_instances`` for
    this type class.
    """
    __slots__ = ('typeclass', 'arg_list', )

    typeclass: str
    arg_list: list[TypeExpr]

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        typeclass: str,
        arg_list: Sequence[TypeExpr]
    ) -> None:
        super().__init__(ctx, sourceref)
        self.typeclass = typeclass
        self.arg_list = list(arg_list)

    def check(self) -> CheckResult:
        if not all(is_concrete(arg) for arg in self.arg_list):
            return skip_for_now()

        if any(isinstance(arg, Record) for arg in self.arg_list):
            # TODO: Allow users to implement type classes on their structs
            return fail('Missing type class instance')

        known_instances = self.ctx.build.type_class_instances[self.typeclass]
        if tuple(self.arg_list) not in known_instances:
            return fail('Missing type class instance')

        return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.arg_list = [
            replace_variable(arg, var, typ)
            for arg in self.arg_list
        ]

    def __str__(self) -> str:
        name_of = self.ctx.build.type5_name
        arg_names = ' '.join(name_of(arg) for arg in self.arg_list)
        return f'Exists {self.typeclass} {arg_names}'