phasm/phasm/type5/constraints.py
Johan B.W. de Vries 21cea93c3c Notes
2025-08-02 13:28:23 +02:00

500 lines
17 KiB
Python

from __future__ import annotations
import dataclasses
from typing import Any, Callable, Iterable, Protocol, Sequence
from ..build.base import BuildBase
from ..ourlang import ConstantStruct, ConstantTuple, SourceRef
from ..type3 import types as type3types
from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt64
from .kindexpr import KindExpr, Star
from .typeexpr import (
AtomicType,
TypeExpr,
TypeVariable,
is_concrete,
replace_variable,
)
from .record import Record
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
class ExpressionProtocol(Protocol):
    """
    Structural type for objects whose type should be updated on substitution.

    Anything exposing a ``type5`` attribute of the annotated type satisfies
    this protocol.
    """

    # The type to update. May be None (presumably before a type has been
    # assigned -- TODO confirm against the callers).
    type5: TypeExpr | None
class Context:
    """
    Shared state for a constraint-solving run.

    Holds the build (which the constraints query for type information) and a
    registry of every placeholder type variable created via make_placeholder.
    """

    __slots__ = ("build", "placeholder_update", )

    build: BuildBase[Any]
    # Each placeholder created so far, mapped to the expression it was made
    # for (None if it was not made for any expression).
    placeholder_update: dict[TypeVariable, ExpressionProtocol | None]

    def __init__(self, build: BuildBase[Any]) -> None:
        self.placeholder_update = {}
        self.build = build

    def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star()) -> TypeVariable:
        """
        Create a fresh placeholder type variable and register it.

        :param arg: The expression this placeholder belongs to, if any.
        :param kind: Kind of the new variable; defaults to ``Star()``.
        :return: A new type variable named ``p_<n>``, where ``n`` is the
            number of placeholders registered before this call.
        """
        registry = self.placeholder_update
        fresh = TypeVariable(kind, f"p_{len(registry)}")
        registry[fresh] = arg
        return fresh
@dataclasses.dataclass
class CheckResult:
_: dataclasses.KW_ONLY
done: bool = True
actions: ActionList = dataclasses.field(default_factory=ActionList)
new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
failures: list[Failure] = dataclasses.field(default_factory=list)
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
if not self.done and not self.actions and not self.new_constraints and not self.failures:
return '(skip for now)'
if self.done and not self.actions and not self.new_constraints and not self.failures:
return '(ok)'
if self.done and self.actions and not self.new_constraints and not self.failures:
return self.actions.to_str(type_namer)
if self.done and not self.actions and self.new_constraints and not self.failures:
return f'(got {len(self.new_constraints)} new constraints)'
if self.done and not self.actions and not self.new_constraints and self.failures:
return 'ERR: ' + '; '.join(x.msg for x in self.failures)
return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}'
def skip_for_now() -> CheckResult:
    """
    Return a deferred result: the constraint could not be decided yet.

    The result has ``done=False`` and carries no actions, constraints or failures.
    """
    return CheckResult(done=False)
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
    """
    Return a finished result carrying the given follow-up constraints.

    The iterable is consumed immediately into a list.
    """
    return CheckResult(new_constraints=[*lst])
def ok() -> CheckResult:
    """Return a finished, successful result with no actions, constraints or failures."""
    return CheckResult(done=True)
def fail(msg: str) -> CheckResult:
    """Return a finished result recording a single failure with message ``msg``."""
    failure = Failure(msg)
    return CheckResult(failures=[failure])
class ConstraintBase:
    """
    Base class for all type5 constraints.

    Subclasses implement check() and override replace_variable() for every
    type expression they hold, so that apply() can push ReplaceVariable
    actions into them.
    """

    __slots__ = ("ctx", "sourceref", "comment",)

    ctx: Context
    sourceref: SourceRef
    # Optional free-form note attached to the constraint.
    comment: str | None

    def __init__(self, ctx: Context, sourceref: SourceRef, comment: str | None = None) -> None:
        self.ctx, self.sourceref, self.comment = ctx, sourceref, comment

    def check(self) -> CheckResult:
        """
        Try to resolve this constraint.

        :raises NotImplementedError: Always; subclasses must override this.
        """
        raise NotImplementedError(self)

    def apply(self, action: Action) -> None:
        """
        Apply a solver action to this constraint.

        :raises NotImplementedError: For any action other than ReplaceVariable.
        """
        if not isinstance(action, ReplaceVariable):
            raise NotImplementedError(action)
        self.replace_variable(action.var, action.typ)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute ``typ`` for ``var``; a no-op for constraints holding no types."""
class LiteralFitsConstraint(ConstraintBase):
    """
    Requires that a constant literal is representable in a given type.

    The check waits until the type is concrete. Integer and float types are
    validated directly against the literal's value. Dynamic arrays, static
    arrays, records and tuples are validated by emitting one nested
    LiteralFitsConstraint per element.
    """

    __slots__ = ("type", "literal",)

    def __init__(self, ctx: Context, sourceref: SourceRef, type: TypeExpr, literal: Any, *, comment: str | None = None) -> None:
        super().__init__(ctx, sourceref, comment)
        self.type = type
        self.literal = literal

    def check(self) -> CheckResult:
        if not is_concrete(self.type):
            return skip_for_now()

        build = self.ctx.build

        type_info = build.type_info_map.get(self.type.name)
        if type_info is not None:
            wasm_type = type_info.wasm_type
            if wasm_type is WasmTypeInt32 or wasm_type is WasmTypeInt64:
                return self._check_integer(type_info)
            if wasm_type is WasmTypeFloat32 or wasm_type is WasmTypeFloat64:
                return self._check_float()

        da_elem = build.type5_is_dynamic_array(self.type)
        if da_elem is not None:
            if da_elem == build.u8_type5:
                # A u8 dynamic array literal must be a bytes constant.
                return ok() if isinstance(self.literal.value, bytes) else fail('Must be bytes')
            if not isinstance(self.literal, ConstantTuple):
                return fail('Must be tuple')
            return self._each_element_fits((node, da_elem) for node in self.literal.value)

        static_array = build.type5_is_static_array(self.type)
        if static_array is not None:
            sa_length, sa_elem = static_array
            if not isinstance(self.literal, ConstantTuple):
                return fail('Must be tuple')
            if len(self.literal.value) != sa_length:
                return fail('Tuple element count mismatch')
            return self._each_element_fits((node, sa_elem) for node in self.literal.value)

        fields = build.type5_is_record(self.type)
        if fields is not None:
            if not isinstance(self.literal, ConstantStruct):
                return fail('Must be struct')
            if self.literal.struct_type3.name != self.type.name:  # TODO: Name based check is wonky
                return fail('Must be right struct')
            if len(self.literal.value) != len(fields):
                return fail('Struct member count mismatch')
            return self._each_element_fits(
                (node, field_type)
                for node, (_, field_type) in zip(self.literal.value, fields, strict=True)
            )

        members = build.type5_is_tuple(self.type)
        if members is not None:
            if not isinstance(self.literal, ConstantTuple):
                return fail('Must be tuple')
            if len(self.literal.value) != len(members):
                return fail('Tuple element count mismatch')
            return self._each_element_fits(zip(self.literal.value, members, strict=True))

        raise NotImplementedError(self.type, type_info)

    def _check_integer(self, type_info: Any) -> CheckResult:
        """Check an integer literal against the byte size and signedness of its type."""
        assert type_info.signed is not None
        value = self.literal.value
        if not isinstance(value, int):
            return fail('Must be integer')
        try:
            # int.to_bytes raises OverflowError when the value does not fit.
            value.to_bytes(type_info.alloc_size, 'big', signed=type_info.signed)
        except OverflowError:
            return fail(f'Must fit in {type_info.alloc_size} byte(s)')
        return ok()

    def _check_float(self) -> CheckResult:
        """Check that a float type receives a float literal."""
        # FIXME: Bit check
        if isinstance(self.literal.value, float):
            return ok()
        return fail('Must be real')

    def _each_element_fits(self, pairs: Iterable[tuple[Any, TypeExpr]]) -> CheckResult:
        """Emit one LiteralFitsConstraint per (literal node, expected type) pair."""
        return new_constraints(
            LiteralFitsConstraint(self.ctx, node.sourceref, node_type, node)
            for node, node_type in pairs
        )

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.type = replace_variable(self.type, var, typ)

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type)
        return f"{type_name} can contain {self.literal!r}"
class UnifyTypesConstraint(ConstraintBase):
    """
    Requires two type expressions to be equal (``lft ~ rgt``).

    Checking delegates to unify(). A Failure becomes a failed result; any
    other outcome is returned as the actions to apply.
    """

    __slots__ = ("lft", "rgt",)

    def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr, *, comment: str | None = None) -> None:
        super().__init__(ctx, sourceref, comment)
        self.lft, self.rgt = lft, rgt

    def check(self) -> CheckResult:
        outcome = unify(self.lft, self.rgt)
        if not isinstance(outcome, Failure):
            return CheckResult(actions=outcome)
        return CheckResult(failures=[outcome])

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.lft, self.rgt = (
            replace_variable(self.lft, var, typ),
            replace_variable(self.rgt, var, typ),
        )

    def __str__(self) -> str:
        name_of = self.ctx.build.type5_name
        return f"{name_of(self.lft)} ~ {name_of(self.rgt)}"
class CanBeSubscriptedConstraint(ConstraintBase):
    """
    Types a subscript expression: ``container[index] :: ret``.

    Once the container type is concrete:

    * dynamic arrays, static arrays and tuples unify the result with the
      selected element type and the index type with u32;
    * static arrays reject a constant index that is out of range;
    * tuples require a constant index that is in range;
    * any other container must have a ``Subscriptable`` type class instance.
    """

    __slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )

    ret_type5: TypeExpr
    container_type5: TypeExpr
    index_type5: TypeExpr
    # The index value when it is known at compile time, otherwise None.
    index_const: int | None

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        container_type5: TypeExpr,
        index_type5: TypeExpr,
        index_const: int | None,
        *,
        comment: str | None = None,
    ) -> None:
        # `comment` is keyword-only, matching LiteralFitsConstraint and UnifyTypesConstraint.
        super().__init__(ctx, sourceref, comment)
        self.ret_type5 = ret_type5
        self.container_type5 = container_type5
        self.index_type5 = index_type5
        self.index_const = index_const

    def check(self) -> CheckResult:
        if not is_concrete(self.container_type5):
            return skip_for_now()

        build = self.ctx.build

        da_arg = build.type5_is_dynamic_array(self.container_type5)
        if da_arg is not None:
            return self._element_at_u32_index(da_arg)

        sa_args = build.type5_is_static_array(self.container_type5)
        if sa_args is not None:
            sa_len, sa_typ = sa_args
            # A non-constant index cannot be range-checked at compile time.
            if self.index_const is not None and not 0 <= self.index_const < sa_len:
                return fail('Tuple index out of range')
            return self._element_at_u32_index(sa_typ)

        tp_args = build.type5_is_tuple(self.container_type5)
        if tp_args is not None:
            # Tuple members can have different types, so the index must be
            # known to pick the result type.
            if self.index_const is None:
                return fail('Must index with integer literal')
            if not 0 <= self.index_const < len(tp_args):
                return fail('Tuple index out of range')
            return self._element_at_u32_index(tp_args[self.index_const])

        return new_constraints([
            TypeClassInstanceExistsConstraint(self.ctx, self.sourceref, 'Subscriptable', [self.container_type5]),
        ])

    def _element_at_u32_index(self, element_type5: TypeExpr) -> CheckResult:
        """Unify the result type with ``element_type5`` and the index type with u32."""
        return new_constraints([
            UnifyTypesConstraint(self.ctx, self.sourceref, element_type5, self.ret_type5),
            UnifyTypesConstraint(self.ctx, self.sourceref, self.ctx.build.u32_type5, self.index_type5),
        ])

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.container_type5 = replace_variable(self.container_type5, var, typ)
        self.index_type5 = replace_variable(self.index_type5, var, typ)

    def __str__(self) -> str:
        return f"[] :: t a -> b -> a ~ {self.ctx.build.type5_name(self.container_type5)} -> {self.ctx.build.type5_name(self.index_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"
class CanAccessStructMemberConstraint(ConstraintBase):
    """
    Types a member access expression: ``struct.member_name :: ret``.

    Once the struct type is concrete, it must be a record that has a field
    called ``member_name``. The result type is then unified with that field's
    type.
    """

    __slots__ = ('ret_type5', 'struct_type5', 'member_name', )

    ret_type5: TypeExpr
    struct_type5: TypeExpr
    member_name: str

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        struct_type5: TypeExpr,
        member_name: str,
        *,
        comment: str | None = None,
    ) -> None:
        # `comment` is keyword-only, matching LiteralFitsConstraint and UnifyTypesConstraint.
        super().__init__(ctx, sourceref, comment)
        self.ret_type5 = ret_type5
        self.struct_type5 = struct_type5
        self.member_name = member_name

    def check(self) -> CheckResult:
        if not is_concrete(self.struct_type5):
            return skip_for_now()

        st_args = self.ctx.build.type5_is_record(self.struct_type5)
        if st_args is None:
            return fail('Must be a struct')

        member_dict = dict(st_args)
        if self.member_name not in member_dict:
            return fail('Must have a field with this name')

        # Unify right away and return that result, rather than emitting a
        # follow-up constraint.
        return UnifyTypesConstraint(self.ctx, self.sourceref, self.ret_type5, member_dict[self.member_name]).check()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.struct_type5 = replace_variable(self.struct_type5, var, typ)

    def __str__(self) -> str:
        st_args = self.ctx.build.type5_is_record(self.struct_type5)
        member_dict = dict(st_args or [])
        member_typ = member_dict.get(self.member_name)
        if member_typ is None:
            expect = 'a -> b'
        else:
            expect = f'{self.ctx.build.type5_name(self.struct_type5)} -> {self.ctx.build.type5_name(member_typ)}'
        return f".{self.member_name} :: {expect} ~ {self.ctx.build.type5_name(self.struct_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"
class FromTupleConstraint(ConstraintBase):
    """
    Relates the member types of a tuple-like expression to its result type.

    Once ``ret_type5`` is concrete, each member type is unified with the
    matching element type: the single element type of a dynamic or static
    array, or the positional member types of a tuple. Static arrays and
    tuples also require the member count to match.
    """

    __slots__ = ('ret_type5', 'member_type5_list', )

    ret_type5: TypeExpr
    member_type5_list: list[TypeExpr]

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        member_type5_list: Sequence[TypeExpr],
        *,
        comment: str | None = None,
    ) -> None:
        # `comment` is keyword-only, matching LiteralFitsConstraint and UnifyTypesConstraint.
        super().__init__(ctx, sourceref, comment)
        self.ret_type5 = ret_type5
        # Stored as a list regardless of the Sequence type passed in.
        self.member_type5_list = list(member_type5_list)

    def check(self) -> CheckResult:
        if not is_concrete(self.ret_type5):
            return skip_for_now()

        build = self.ctx.build

        da_arg = build.type5_is_dynamic_array(self.ret_type5)
        if da_arg is not None:
            # Dynamic arrays accept any number of members.
            return new_constraints(
                UnifyTypesConstraint(self.ctx, self.sourceref, da_arg, x)
                for x in self.member_type5_list
            )

        sa_args = build.type5_is_static_array(self.ret_type5)
        if sa_args is not None:
            sa_len, sa_typ = sa_args
            if sa_len != len(self.member_type5_list):
                return fail('Tuple element count mismatch')
            return new_constraints(
                UnifyTypesConstraint(self.ctx, self.sourceref, sa_typ, x)
                for x in self.member_type5_list
            )

        tp_args = build.type5_is_tuple(self.ret_type5)
        if tp_args is not None:
            if len(tp_args) != len(self.member_type5_list):
                return fail('Tuple element count mismatch')
            return new_constraints(
                UnifyTypesConstraint(self.ctx, self.sourceref, act_typ, exp_typ)
                for act_typ, exp_typ in zip(tp_args, self.member_type5_list, strict=True)
            )

        raise NotImplementedError(self.ret_type5)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.member_type5_list = [
            replace_variable(x, var, typ)
            for x in self.member_type5_list
        ]

    def __str__(self) -> str:
        args = ', '.join(self.ctx.build.type5_name(x) for x in self.member_type5_list)
        return f'FromTuple {self.ctx.build.type5_name(self.ret_type5)} ~ ({args}, )'
class TypeClassInstanceExistsConstraint(ConstraintBase):
    """
    Requires an instance of ``typeclass`` to exist for ``arg_list``.

    The lookup waits until every argument is concrete. It then converts the
    arguments to their type3 equivalents and looks them up in the build's
    registered type class instances.
    """

    __slots__ = ('typeclass', 'arg_list', )

    # Name of the type class, used as a key into build.type_classes.
    typeclass: str
    arg_list: list[TypeExpr]

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        typeclass: str,
        arg_list: Sequence[TypeExpr],
        *,
        comment: str | None = None,
    ) -> None:
        # `comment` is keyword-only, matching LiteralFitsConstraint and UnifyTypesConstraint.
        super().__init__(ctx, sourceref, comment)
        self.typeclass = typeclass
        self.arg_list = list(arg_list)

    def check(self) -> CheckResult:
        if not all(is_concrete(x) for x in self.arg_list):
            return skip_for_now()

        tcls = self.ctx.build.type_classes[self.typeclass]

        # Temporary hack while we are converting from type3 to type5
        try:
            targs = tuple(
                _type5_to_type3_or_type3_const(self.ctx.build, x)
                for x in self.arg_list
            )
        except RecordFoundException:
            # Record arguments are not converted, so no registered instance can match.
            return fail('Missing type class instance')

        if (tcls, targs, ) in self.ctx.build.type_class_instances:
            return ok()
        return fail('Missing type class instance')

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.arg_list = [
            replace_variable(x, var, typ)
            for x in self.arg_list
        ]

    def __str__(self) -> str:
        args = ' '.join(self.ctx.build.type5_name(x) for x in self.arg_list)
        return f'Exists {self.typeclass} {args}'
class RecordFoundException(Exception):
    """
    Raised when a Record type is met during the type5 -> type3 conversion.

    The type class instance check catches it and reports a missing instance.
    """
def _type5_to_type3_or_type3_const(build: BuildBase[Any], type5: TypeExpr) -> type3types.Type3 | type3types.TypeConstructor_Base[Any]:
    """
    Map a type5 expression onto its type3 equivalent for instance lookup.

    Atomic types map to the type3 type registered under the same name.
    Dynamic arrays, static arrays and tuples map to their bare type3
    constructor; their arguments are not converted.

    :raises RecordFoundException: If ``type5`` is a Record.
    :raises NotImplementedError: For any other kind of type expression.
    """
    if isinstance(type5, Record):
        raise RecordFoundException
    if isinstance(type5, AtomicType):
        return build.types[type5.name]

    # Only the constructor matters here, so the probe results are discarded.
    if build.type5_is_dynamic_array(type5) is not None:
        return build.dynamic_array
    if build.type5_is_static_array(type5) is not None:
        return build.static_array
    if build.type5_is_tuple(type5) is not None:
        return build.tuple_

    raise NotImplementedError(type5)