# phasm/phasm/type5/constraints.py
# Johan B.W. de Vries 67def3dff7 Notes
# 2025-07-30 19:16:46 +02:00
#
# 356 lines
# 13 KiB
# Python

from __future__ import annotations
import dataclasses
from typing import Any, Callable, Iterable, Protocol
from ..build.base import BuildBase
from ..ourlang import ConstantStruct, ConstantTuple, SourceRef
from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt64
from .kindexpr import KindExpr, Star
from .typeexpr import (
TypeExpr,
TypeVariable,
is_concrete,
replace_variable,
)
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
class ExpressionProtocol(Protocol):
    """
    Structural type for objects whose type gets refreshed on substitution.

    Any object exposing a ``type5`` attribute satisfies this protocol.
    """

    # The type to update (may be None).
    type5: TypeExpr | None
class Context:
    """
    Shared state for one constraint-checking run.

    Holds the build (used by constraints for type lookups) and a registry
    mapping each placeholder type variable to the expression it was created
    for.
    """

    __slots__ = ("build", "placeholder_update", )

    build: BuildBase[Any]
    placeholder_update: dict[TypeVariable, ExpressionProtocol]

    def __init__(self, build: BuildBase[Any]) -> None:
        self.placeholder_update = {}
        self.build = build

    def make_placeholder(self, arg: ExpressionProtocol, kind: KindExpr = Star()) -> TypeVariable:
        """
        Create a fresh placeholder type variable and register ``arg`` under it.

        Placeholders are named ``p_0``, ``p_1``, ... in creation order, based on
        how many placeholders were registered before this call.

        NOTE(review): the default ``Star()`` is built once at definition time
        and shared between calls; this is fine as long as Star is immutable.
        """
        placeholder_name = f"p_{len(self.placeholder_update)}"
        placeholder = TypeVariable(kind, placeholder_name)
        self.placeholder_update[placeholder] = arg
        return placeholder
@dataclasses.dataclass
class CheckResult:
    """
    Outcome of checking a single constraint.

    ``done`` is False when the constraint could not be decided yet;
    ``actions`` lists the substitutions it produced and ``failures`` the
    errors it found. All fields are keyword-only.
    """

    _: dataclasses.KW_ONLY
    done: bool = True
    actions: ActionList = dataclasses.field(default_factory=ActionList)
    failures: list[Failure] = dataclasses.field(default_factory=list)

    def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
        """
        Render this result for debugging, naming types via ``type_namer``.
        """
        has_actions = bool(self.actions)
        has_failures = bool(self.failures)

        if not has_actions and not has_failures:
            return '(ok)' if self.done else '(skip for now)'

        if self.done and not has_failures:
            # Only actions remain to report.
            return self.actions.to_str(type_namer)

        if self.done and not has_actions:
            # Only failures remain to report.
            return 'ERR: ' + '; '.join(failure.msg for failure in self.failures)

        # Any other mix: dump everything.
        return f'{self.actions.to_str(type_namer)} {self.failures} {self.done}'
def fail(msg: str) -> CheckResult:
    """
    Build a CheckResult (with ``done`` left at its default) holding one failure.
    """
    failure = Failure(msg)
    return CheckResult(failures=[failure])
def combine_check_result(cr_list: Iterable[CheckResult]) -> CheckResult:
    """
    Merge several check results into one.

    The merged result is done only if every input is done; actions and
    failures are concatenated in input order. ``cr_list`` is iterated exactly
    once, so a generator is acceptable. An empty input yields a done result
    with no actions and no failures.
    """
    all_done = True
    merged_actions = ActionList()
    merged_failures: list[Failure] = []

    for partial in cr_list:
        all_done = all_done and partial.done
        merged_actions.extend(partial.actions)
        merged_failures.extend(partial.failures)

    return CheckResult(
        done=all_done,
        actions=merged_actions,
        failures=merged_failures,
    )
class ConstraintBase:
    """
    Common base for all constraints.

    Stores the shared Context, an optional source reference and an optional
    free-form comment. Subclasses implement ``check`` and, when they hold
    type expressions, ``replace_variable``.
    """

    __slots__ = ("ctx", "sourceref", "comment",)

    ctx: Context
    sourceref: SourceRef | None
    comment: str | None

    def __init__(self, ctx: Context, sourceref: SourceRef | None, comment: str | None = None) -> None:
        self.ctx = ctx
        self.sourceref = sourceref
        self.comment = comment

    def check(self) -> CheckResult:
        """
        Evaluate the constraint; must be overridden by subclasses.
        """
        raise NotImplementedError(self)

    def apply(self, action: Action) -> None:
        """
        Apply a solver action to this constraint.

        Only ReplaceVariable is supported; any other action raises
        NotImplementedError.
        """
        if not isinstance(action, ReplaceVariable):
            raise NotImplementedError(action)

        self.replace_variable(action.var, action.typ)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute ``typ`` for ``var``; the base class holds no types, so this is a no-op.
        """
class LiteralFitsConstraint(ConstraintBase):
    """
    Constraint: the constant ``literal`` must be representable in ``type``.

    Integer and float primitives are validated using the build's
    ``type_info_map``. Dynamic arrays, static arrays, records and tuples are
    validated structurally by checking each element against its element type.
    """

    __slots__ = ("type", "literal",)

    def __init__(self, ctx: Context, sourceref: SourceRef | None, type: TypeExpr, literal: Any, *, comment: str | None = None) -> None:
        super().__init__(ctx, sourceref, comment)
        self.type = type
        self.literal = literal

    def check(self) -> CheckResult:
        """
        Check the literal against the type.

        Returns a not-done result while the type is not concrete yet.
        Raises NotImplementedError for type shapes not handled below.
        """
        if not is_concrete(self.type):
            return CheckResult(done=False)

        build = self.ctx.build

        type_info = build.type_info_map.get(self.type.name)
        if type_info is not None:
            wasm_type = type_info.wasm_type
            if wasm_type is WasmTypeInt32 or wasm_type is WasmTypeInt64:
                return self._check_integer(type_info)
            if wasm_type is WasmTypeFloat32 or wasm_type is WasmTypeFloat64:
                return self._check_float()

        da_element = build.type5_is_dynamic_array(self.type)
        if da_element is not None:
            if da_element == build.u8_type5:
                # A u8 dynamic array takes a bytes literal rather than a tuple.
                if isinstance(self.literal.value, bytes):
                    return CheckResult()
                return fail('Must be bytes')
            if not isinstance(self.literal, ConstantTuple):
                return fail('Must be tuple')
            return self._check_members(
                (member, da_element) for member in self.literal.value
            )

        static_array = build.type5_is_static_array(self.type)
        if static_array is not None:
            expected_len, sa_element = static_array
            if not isinstance(self.literal, ConstantTuple):
                return fail('Must be tuple')
            if len(self.literal.value) != expected_len:
                return fail('Tuple element count mismatch')
            return self._check_members(
                (member, sa_element) for member in self.literal.value
            )

        record_fields = build.type5_is_record(self.type)
        if record_fields is not None:
            if not isinstance(self.literal, ConstantStruct):
                return fail('Must be struct')
            # TODO: Name based check is wonky
            if self.literal.struct_type3.name != self.type.name:
                return fail('Must be right struct')
            if len(self.literal.value) != len(record_fields):
                return fail('Struct member count mismatch')
            return self._check_members(
                (member, field_type)
                for member, (_, field_type) in zip(self.literal.value, record_fields, strict=True)
            )

        tuple_elements = build.type5_is_tuple(self.type)
        if tuple_elements is not None:
            if not isinstance(self.literal, ConstantTuple):
                return fail('Must be tuple')
            if len(self.literal.value) != len(tuple_elements):
                return fail('Tuple element count mismatch')
            return self._check_members(
                zip(self.literal.value, tuple_elements, strict=True)
            )

        raise NotImplementedError(self.type, type_info)

    def _check_integer(self, type_info: Any) -> CheckResult:
        """
        Require an int whose value fits in ``type_info.alloc_size`` bytes.
        """
        assert type_info.signed is not None
        value = self.literal.value
        if not isinstance(value, int):
            return fail('Must be integer')
        try:
            value.to_bytes(type_info.alloc_size, 'big', signed=type_info.signed)
        except OverflowError:
            return fail(f'Must fit in {type_info.alloc_size} byte(s)')
        return CheckResult()

    def _check_float(self) -> CheckResult:
        """
        Require a float value (no precision/range check yet).
        """
        if not isinstance(self.literal.value, float):
            return fail('Must be real')
        # FIXME: Bit check
        return CheckResult()

    def _check_members(self, pairs: Iterable[tuple[Any, TypeExpr]]) -> CheckResult:
        """
        Check each (member literal, member type) pair recursively and merge the results.
        """
        return combine_check_result(
            LiteralFitsConstraint(self.ctx, member.sourceref, member_type, member).check()
            for member, member_type in pairs
        )

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.type = replace_variable(self.type, var, typ)

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type)
        return f"{type_name} can contain {self.literal!r}"
class UnifyTypesConstraint(ConstraintBase):
    """
    Constraint: ``lft`` and ``rgt`` must unify.

    Checking delegates to ``unify``: a Failure becomes a failed result;
    anything else is returned as the result's actions.
    """

    __slots__ = ("lft", "rgt",)

    def __init__(self, ctx: Context, sourceref: SourceRef | None, lft: TypeExpr, rgt: TypeExpr, *, comment: str | None = None) -> None:
        super().__init__(ctx, sourceref, comment)
        self.lft, self.rgt = lft, rgt

    def check(self) -> CheckResult:
        outcome = unify(self.lft, self.rgt)
        if not isinstance(outcome, Failure):
            return CheckResult(actions=outcome)
        return CheckResult(failures=[outcome])

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.lft = replace_variable(self.lft, var, typ)
        self.rgt = replace_variable(self.rgt, var, typ)

    def __str__(self) -> str:
        name_of = self.ctx.build.type5_name
        return f"{name_of(self.lft)} ~ {name_of(self.rgt)}"
class CanBeSubscriptedConstraint(ConstraintBase):
    """
    Constraint for a subscript expression ``container[index]``.

    Supported containers:
    - dynamic arrays: element type ~ ``ret_type5``, no static range check;
    - static arrays: element type ~ ``ret_type5``; a literal index
      (``index_const``) is range-checked against the array length;
    - tuples: the index must be a literal, is range-checked, and selects
      which member type must unify with ``ret_type5``.
    In every supported case ``index_type5`` must unify with u32.
    The check is deferred (not done) until the container type is concrete.
    """

    __slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )

    ret_type5: TypeExpr
    container_type5: TypeExpr
    index_type5: TypeExpr
    index_const: int | None

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef | None,
        ret_type5: TypeExpr,
        container_type5: TypeExpr,
        index_type5: TypeExpr,
        index_const: int | None,
        *,
        comment: str | None = None,
    ) -> None:
        # ``comment`` is keyword-only with a default, matching the other
        # constraints, so existing positional callers are unaffected.
        super().__init__(ctx, sourceref, comment)
        self.ret_type5 = ret_type5
        self.container_type5 = container_type5
        self.index_type5 = index_type5
        self.index_const = index_const

    def _unify_element_and_index(self, element_type: TypeExpr) -> CheckResult:
        """
        Unify ``element_type`` with the result type and the index type with u32.
        """
        return combine_check_result([
            UnifyTypesConstraint(self.ctx, self.sourceref, element_type, self.ret_type5).check(),
            UnifyTypesConstraint(self.ctx, self.sourceref, self.ctx.build.u32_type5, self.index_type5).check(),
        ])

    def check(self) -> CheckResult:
        if not is_concrete(self.container_type5):
            return CheckResult(done=False)

        da_args = self.ctx.build.type5_is_dynamic_array(self.container_type5)
        if da_args is not None:
            # No static range check is performed for dynamic arrays.
            return self._unify_element_and_index(da_args)

        sa_args = self.ctx.build.type5_is_static_array(self.container_type5)
        if sa_args is not None:
            sa_len, sa_typ = sa_args
            # Only a literal index can be range-checked here.
            if self.index_const is not None and (self.index_const < 0 or sa_len <= self.index_const):
                return fail('Tuple index out of range')
            return self._unify_element_and_index(sa_typ)

        tp_args = self.ctx.build.type5_is_tuple(self.container_type5)
        if tp_args is not None:
            # The member type is picked by position, so the index must be a literal.
            if self.index_const is None:
                return fail('Must index with integer literal')
            if self.index_const < 0 or len(tp_args) <= self.index_const:
                return fail('Tuple index out of range')
            return self._unify_element_and_index(tp_args[self.index_const])

        # NOTE(review): 'instantation' is misspelled; kept verbatim because
        # callers or tests may match this exact message.
        return fail(f'Missing type class instantation: Subscriptable {self.container_type5.name}')

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.container_type5 = replace_variable(self.container_type5, var, typ)
        self.index_type5 = replace_variable(self.index_type5, var, typ)

    def __str__(self) -> str:
        return f"[] :: t a -> b -> a ~ {self.ctx.build.type5_name(self.container_type5)} -> {self.ctx.build.type5_name(self.index_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"
class CanAccessStructMemberConstraint(ConstraintBase):
    """
    Constraint for a member access ``struct.member_name``.

    Once ``struct_type5`` is concrete it must be a record with a field named
    ``member_name``; that field's type must unify with ``ret_type5``.
    The check is deferred (not done) until the struct type is concrete.
    """

    __slots__ = ('ret_type5', 'struct_type5', 'member_name', )

    ret_type5: TypeExpr
    struct_type5: TypeExpr
    member_name: str

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef | None,
        ret_type5: TypeExpr,
        struct_type5: TypeExpr,
        member_name: str,
        *,
        comment: str | None = None,
    ) -> None:
        # ``comment`` is keyword-only with a default, matching the other
        # constraints, so existing positional callers are unaffected.
        super().__init__(ctx, sourceref, comment)
        self.ret_type5 = ret_type5
        self.struct_type5 = struct_type5
        self.member_name = member_name

    def check(self) -> CheckResult:
        if not is_concrete(self.struct_type5):
            return CheckResult(done=False)

        st_args = self.ctx.build.type5_is_record(self.struct_type5)
        if st_args is None:
            return fail('Must be a struct')

        member_dict = dict(st_args)
        if self.member_name not in member_dict:
            return fail('Must have a field with this name')

        return UnifyTypesConstraint(self.ctx, self.sourceref, self.ret_type5, member_dict[self.member_name]).check()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.struct_type5 = replace_variable(self.struct_type5, var, typ)

    def __str__(self) -> str:
        # The struct type may not be resolved to a record yet; fall back to a
        # generic signature in that case.
        st_args = self.ctx.build.type5_is_record(self.struct_type5)
        member_dict = dict(st_args or [])
        member_typ = member_dict.get(self.member_name)
        if member_typ is None:
            expect = 'a -> b'
        else:
            expect = f'{self.ctx.build.type5_name(self.struct_type5)} -> {self.ctx.build.type5_name(member_typ)}'
        return f".{self.member_name} :: {expect} ~ {self.ctx.build.type5_name(self.struct_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"