phasm/phasm/type5/constraints.py
Johan B.W. de Vries d1854d7a38 Cleanup to type5 solver
- Replaces weird CheckResult class with proper data classes.
- When a constraint results in new constraints, those are
  treated first
- Minor readability change to function return type
- Code cleanup
- TODO cleanup
- Typo fixes
2025-08-30 14:45:12 +02:00

639 lines
21 KiB
Python

from __future__ import annotations
import dataclasses
from typing import Any, Callable, Iterable, Protocol, Sequence, TypeAlias
from ..build.base import BuildBase
from ..ourlang import SourceRef
from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt64
from .kindexpr import KindExpr, Star
from .record import Record
from .typeexpr import (
AtomicType,
TypeApplication,
TypeConstructor,
TypeExpr,
TypeLevelNat,
TypeVariable,
is_concrete,
occurs,
replace_variable,
)
class ExpressionProtocol(Protocol):
    """
    Structural interface for classes that should be updated on substitution.

    Any object exposing a `type5` attribute of the annotated type satisfies
    this protocol.
    """

    # The type to update.
    type5: TypeExpr | None
# Mapping from a type variable to a type expression; this is the `arg` type
# accepted by Context.register_polytype_subsitutes.
PolytypeSubsituteMap: TypeAlias = dict[TypeVariable, TypeExpr]
class Context:
    """
    Shared state handed to every constraint.

    Holds the build object and two bookkeeping tables keyed by type variable.
    """
    __slots__ = ("build", "placeholder_update", "ptst_update", )

    build: BuildBase[Any]
    # Per placeholder created by `make_placeholder`: the expression passed
    # for it, or None when no expression was given.
    placeholder_update: dict[TypeVariable, ExpressionProtocol | None]
    # Per registered placeholder: the substitute map to fill in and the
    # original variable to use as key in that map.
    ptst_update: dict[TypeVariable, tuple[PolytypeSubsituteMap, TypeVariable]]

    def __init__(self, build: BuildBase[Any]) -> None:
        """
        Store `build` and start with empty bookkeeping tables.
        """
        self.build = build
        self.placeholder_update = dict()
        self.ptst_update = dict()

    def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star(), prefix: str = 'p') -> TypeVariable:
        """
        Create a new type variable named `<prefix>_<n>` and register it.

        `n` is the number of placeholders registered so far. `arg` is stored
        in `placeholder_update` under the new variable, which is returned.
        """
        index = len(self.placeholder_update)
        placeholder = TypeVariable(kind, f"{prefix}_{index}")
        self.placeholder_update[placeholder] = arg
        return placeholder

    def register_polytype_subsitutes(self, tvar: TypeVariable, arg: PolytypeSubsituteMap, orig_var: TypeVariable) -> None:
        """
        When `tvar` gets substituted, also set the result in arg with orig_var as key

        e.g.

        (-) :: Callable[a, a, a]

        def foo() -> u32:
            return 2 - 1

        During typing, we instantiate a into a_3, and get the following constraints:
        - u8 ~ p_1
        - u8 ~ p_2
        - Exists NatNum a_3
        - Callable[a_3, a_3, a_3] ~ Callable[p_1, p_2, p_0]
        - u8 ~ p_0

        When we resolve a_3, then on the call to `-`, we should note that a_3 got resolved
        to u32. But we need to use `a` as key, since that's what's used on the definition

        `tvar` must be a known placeholder that has not been registered here before.
        """
        assert tvar in self.placeholder_update
        assert tvar not in self.ptst_update

        entry = (arg, orig_var)
        self.ptst_update[tvar] = entry
class Success:
    """
    Result of a constraint check: the constraint checks out.

    Nothing new was learned and nothing new needs to be checked.
    """

    def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
        """
        Return the fixed display text for this result; `type_namer` is not used.
        """
        text = '(ok)'
        return text
class SkipForNow:
    """
    Result of a constraint check: not enough information to resolve
    the constraint.
    """

    def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
        """
        Return the fixed display text for this result; `type_namer` is not used.
        """
        text = '(skip for now)'
        return text
class ConstraintList(list['ConstraintBase']):
    """
    Result of a constraint check: a new list of constraints.

    Sometimes, checking one constraint means you get a list of new constraints.
    """

    def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
        """
        Return a display text reporting how many constraints the list holds;
        `type_namer` is not used.
        """
        count = len(self)
        return f'(got {count} new constraints)'
@dataclasses.dataclass
class Failure:
    """
    Result of a constraint check: the constraint cannot be satisfied.

    For unification this means both types are already different and
    cannot be unified.
    """
    # Description of what went wrong.
    msg: str

    def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
        """
        Return the message prefixed with `ERR: `; `type_namer` is not used.
        """
        return 'ERR: {}'.format(self.msg)
@dataclasses.dataclass
class ReplaceVariable:
    """
    Result of a constraint check: a variable should be replaced.

    The replacement is either another variable or a (concrete) type.
    """
    # The variable to replace.
    var: TypeVariable
    # What to replace it with.
    typ: TypeExpr

    def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
        """
        Return `{<variable name> := <type name>}`, where the type name
        is produced by `type_namer`.
        """
        var_name = self.var.name
        typ_name = type_namer(self.typ)
        return '{' + f'{var_name} := {typ_name}' + '}'
# Union of the result classes that a constraint's `check()` may return.
CheckResult: TypeAlias = Success | SkipForNow | ConstraintList | Failure | ReplaceVariable
def ok() -> CheckResult:
    """
    Build the result for a constraint that checks out.
    """
    result = Success()
    return result
def skip_for_now() -> CheckResult:
    """
    Build the result for a constraint that cannot be resolved yet.
    """
    result = SkipForNow()
    return result
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
    """
    Build the result that carries the constraints from `lst`, in order.
    """
    result = ConstraintList()
    result.extend(lst)
    return result
def fail(msg: str) -> CheckResult:
    """
    Build the failure result carrying `msg`.
    """
    return Failure(msg=msg)
def replace(var: TypeVariable, typ: TypeExpr) -> CheckResult:
    """
    Build the result stating that `var` should be replaced by `typ`.
    """
    return ReplaceVariable(var=var, typ=typ)
class ConstraintBase:
    """
    Base class for all constraints.

    Stores the shared context and the source reference the constraint
    belongs to. Subclasses override `check` and `complexity`, and
    `replace_variable` when they hold type expressions.
    """
    __slots__ = ("ctx", "sourceref", )

    ctx: Context
    sourceref: SourceRef

    def __init__(self, ctx: Context, sourceref: SourceRef) -> None:
        """
        Remember the context and source reference.
        """
        self.ctx, self.sourceref = ctx, sourceref

    def check(self) -> CheckResult:
        """
        Evaluate the constraint. Not implemented on the base class.
        """
        raise NotImplementedError(self)

    def complexity(self) -> int:
        """
        Return an integer score for the constraint. Not implemented on the base class.
        """
        raise NotImplementedError

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute `typ` for `var` in this constraint. The base class holds
        no type expressions, so this does nothing.
        """
        return None
class FromLiteralInteger(ConstraintBase):
    """
    Constraint stating that an integer literal must be representable
    in the type `type5`.
    """
    __slots__ = ('type5', 'literal', )

    type5: TypeExpr
    literal: int

    def __init__(self, ctx: Context, sourceref: SourceRef, type5: TypeExpr, literal: int) -> None:
        """
        Store the target type and the literal value.
        """
        super().__init__(ctx, sourceref)
        self.type5 = type5
        self.literal = literal

    def check(self) -> CheckResult:
        """
        Check the literal against the target type.

        Skips while the type is not concrete. Fails when the type has no
        type info, is not a 32 or 64 bit wasm integer, is unsigned while
        the literal is negative, or when the literal does not fit in the
        type's allocation size. Succeeds otherwise.
        """
        if not is_concrete(self.type5):
            return skip_for_now()

        type_info = self.ctx.build.type_info_map.get(self.type5.name)
        if type_info is None:
            return fail('Cannot convert from literal integer')

        wasm_type = type_info.wasm_type
        is_integer = wasm_type is WasmTypeInt32 or wasm_type is WasmTypeInt64
        if not is_integer:
            return fail('Cannot convert from literal integer')

        signed = type_info.signed
        assert signed is not None  # type hint

        if self.literal < 0 and not signed:
            return fail('May not be negative')

        size = type_info.alloc_size
        try:
            # Only used as a range check; the resulting bytes are discarded.
            self.literal.to_bytes(size, 'big', signed=signed)
        except OverflowError:
            return fail(f'Must fit in {size} byte(s)')
        else:
            return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute `typ` for `var` in the target type.
        """
        self.type5 = replace_variable(self.type5, var, typ)

    def complexity(self) -> int:
        """
        Return 100 plus the complexity of the target type.
        """
        type_score = complexity(self.type5)
        return 100 + type_score

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type5)
        return f'FromLiteralInteger {type_name} ~ {self.literal!r}'
class FromLiteralFloat(ConstraintBase):
    """
    Constraint stating that a float literal must be representable
    in the type `type5`.
    """
    __slots__ = ('type5', 'literal', )

    type5: TypeExpr
    literal: float

    def __init__(self, ctx: Context, sourceref: SourceRef, type5: TypeExpr, literal: float) -> None:
        """
        Store the target type and the literal value.
        """
        super().__init__(ctx, sourceref)
        self.type5 = type5
        self.literal = literal

    def check(self) -> CheckResult:
        """
        Check the literal against the target type.

        Skips while the type is not concrete. Fails when the type has no
        type info or is not a 32 or 64 bit wasm float. Succeeds otherwise;
        the literal value itself is not inspected.
        """
        if not is_concrete(self.type5):
            return skip_for_now()

        type_info = self.ctx.build.type_info_map.get(self.type5.name)
        if type_info is None:
            return fail('Cannot convert from literal float')

        wasm_type = type_info.wasm_type
        is_float = wasm_type is WasmTypeFloat32 or wasm_type is WasmTypeFloat64
        if not is_float:
            return fail('Cannot convert from literal float')

        # TODO: Precision check

        return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute `typ` for `var` in the target type.
        """
        self.type5 = replace_variable(self.type5, var, typ)

    def complexity(self) -> int:
        """
        Return 100 plus the complexity of the target type.
        """
        type_score = complexity(self.type5)
        return 100 + type_score

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type5)
        return f'FromLiteralFloat {type_name} ~ {self.literal!r}'
class FromLiteralBytes(ConstraintBase):
    """
    Constraint stating that a bytes literal must be representable
    in the type `type5`.
    """
    __slots__ = ('type5', 'literal', )

    type5: TypeExpr
    literal: bytes

    def __init__(self, ctx: Context, sourceref: SourceRef, type5: TypeExpr, literal: bytes) -> None:
        """
        Store the target type and the literal value.
        """
        super().__init__(ctx, sourceref)
        self.type5 = type5
        self.literal = literal

    def check(self) -> CheckResult:
        """
        Check the literal against the target type.

        Skips while the type is not concrete. Succeeds only when the type
        is a dynamic array whose element type equals the build's u8 type.
        """
        if not is_concrete(self.type5):
            return skip_for_now()

        build = self.ctx.build
        element_type5 = build.type5_is_dynamic_array(self.type5)

        if element_type5 is None:
            return fail('Cannot convert from literal bytes')

        if element_type5 != build.u8_type5:
            return fail('Cannot convert from literal bytes')

        return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute `typ` for `var` in the target type.
        """
        self.type5 = replace_variable(self.type5, var, typ)

    def complexity(self) -> int:
        """
        Return 100 plus the complexity of the target type.
        """
        type_score = complexity(self.type5)
        return 100 + type_score

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type5)
        return f'FromLiteralBytes {type_name} ~ {self.literal!r}'
class UnifyTypesConstraint(ConstraintBase):
    """
    Constraint stating that two type expressions must unify (`lft ~ rgt`).

    `prefix`, when given, is only used as a label when printing the constraint.
    """
    __slots__ = ("lft", "rgt", "prefix", )

    def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr, prefix: str | None = None) -> None:
        """
        Store both sides of the unification and the optional label.
        """
        super().__init__(ctx, sourceref)
        self.lft = lft
        self.rgt = rgt
        self.prefix = prefix

    def check(self) -> CheckResult:
        """
        Perform one unification step.

        Equal sides succeed, differing kinds fail. After that the outcome
        depends on the classes of both sides: a variable is replaced by the
        other side (after an occurs check when that side is a type
        application), two type applications are split into constraints on
        their parts, and every other handled combination fails.

        Raises NotImplementedError for combinations that are not handled.
        """
        lft, rgt = self.lft, self.rgt

        if lft == rgt:
            return ok()

        if lft.kind != rgt.kind:
            return fail("Kind mismatch")

        # The checks below are grouped by the class of the left side. A group
        # that finds no match for the right side falls through to the next.

        if isinstance(lft, AtomicType):
            if isinstance(rgt, AtomicType):
                return fail("Not the same type")
            if isinstance(rgt, TypeVariable):
                return replace(rgt, lft)
            if isinstance(rgt, TypeConstructor):
                raise NotImplementedError  # Should have been caught by kind check above
            if isinstance(rgt, TypeApplication):
                return fail("Not the same type" if is_concrete(rgt) else "Type shape mismatch")

        if isinstance(lft, TypeVariable):
            if isinstance(rgt, (AtomicType, TypeVariable, TypeConstructor)):
                return replace(lft, rgt)
            if isinstance(rgt, TypeApplication):
                if occurs(lft, rgt):
                    return fail("One type occurs in the other")
                return replace(lft, rgt)

        if isinstance(lft, TypeConstructor):
            if isinstance(rgt, AtomicType):
                raise NotImplementedError  # Should have been caught by kind check above
            if isinstance(rgt, TypeVariable):
                return replace(rgt, lft)
            if isinstance(rgt, (TypeConstructor, TypeApplication)):
                return fail("Not the same type constructor")

        if isinstance(lft, TypeApplication):
            if isinstance(rgt, AtomicType):
                return fail("Not the same type" if is_concrete(lft) else "Type shape mismatch")
            if isinstance(rgt, TypeVariable):
                if occurs(rgt, lft):
                    return fail("One type occurs in the other")
                return replace(rgt, lft)
            if isinstance(rgt, TypeConstructor):
                return fail("Not the same type constructor")
            if isinstance(rgt, TypeApplication):
                ctx, sourceref = self.ctx, self.sourceref
                lft_con, rgt_con = lft.constructor, rgt.constructor

                ## USABILITY HACK
                ## Often, we have two type applications in the same go
                ## If so, resolve it in a single step
                ## (Helps with debugging function unification)
                ## This *should* not affect the actual type unification
                ## It's just one less call to UnifyTypesConstraint.check
                if isinstance(lft_con, TypeApplication) and isinstance(rgt_con, TypeApplication):
                    return new_constraints([
                        UnifyTypesConstraint(ctx, sourceref, lft_con.constructor, rgt_con.constructor),
                        UnifyTypesConstraint(ctx, sourceref, lft_con.argument, rgt_con.argument),
                        UnifyTypesConstraint(ctx, sourceref, lft.argument, rgt.argument),
                    ])

                return new_constraints([
                    UnifyTypesConstraint(ctx, sourceref, lft_con, rgt_con),
                    UnifyTypesConstraint(ctx, sourceref, lft.argument, rgt.argument),
                ])

        raise NotImplementedError(lft, rgt)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute `typ` for `var` on both sides.
        """
        self.lft = replace_variable(self.lft, var, typ)
        self.rgt = replace_variable(self.rgt, var, typ)

    def complexity(self) -> int:
        """
        Return the summed complexity of both sides.
        """
        return sum((complexity(self.lft), complexity(self.rgt)))

    def __str__(self) -> str:
        label = f'{self.prefix} :: ' if self.prefix else ''
        type_namer = self.ctx.build.type5_name
        lft_name = type_namer(self.lft)
        rgt_name = type_namer(self.rgt)
        return f"{label}{lft_name} ~ {rgt_name}"
class CanBeSubscriptedConstraint(ConstraintBase):
    """
    Constraint stating that `container[index]` is valid and yields `ret_type5`.

    `index_const` is the index value when it is known as a constant,
    otherwise None.
    """
    __slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )

    ret_type5: TypeExpr
    container_type5: TypeExpr
    index_type5: TypeExpr
    index_const: int | None

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        container_type5: TypeExpr,
        index_type5: TypeExpr,
        index_const: int | None,
    ) -> None:
        """
        Store the result, container and index types and the constant index.
        """
        super().__init__(ctx, sourceref)

        self.ret_type5 = ret_type5
        self.container_type5 = container_type5
        self.index_type5 = index_type5
        self.index_const = index_const

    def check(self) -> CheckResult:
        """
        Resolve the subscript once the container type is concrete.

        Tuples need a constant index within `0 <= index < len(tuple)`. The
        element type at that index is unified with the result type and the
        index type is unified with u32.

        Any other container must be a type application. A `Subscriptable`
        instance is then required for its constructor, its argument is
        unified with the result type and the index type with u32.
        """
        if not is_concrete(self.container_type5):
            return skip_for_now()

        tp_args = self.ctx.build.type5_is_tuple(self.container_type5)
        if tp_args is not None:
            if self.index_const is None:
                return fail('Must index with integer literal')

            # Reject negative indices as well: Python would silently wrap
            # `tp_args[-1]` around to the last element of the tuple.
            if not 0 <= self.index_const < len(tp_args):
                return fail('Tuple index out of range')

            return new_constraints([
                UnifyTypesConstraint(self.ctx, self.sourceref, tp_args[self.index_const], self.ret_type5),
                UnifyTypesConstraint(self.ctx, self.sourceref, self.ctx.build.u32_type5, self.index_type5),
            ])

        if not isinstance(self.container_type5, TypeApplication):
            return fail('Missing type class instance')

        return new_constraints([
            TypeClassInstanceExistsConstraint(
                self.ctx,
                self.sourceref,
                'Subscriptable',
                (self.container_type5.constructor, ),
            ),
            UnifyTypesConstraint(self.ctx, self.sourceref, self.container_type5.argument, self.ret_type5),
            UnifyTypesConstraint(self.ctx, self.sourceref, self.ctx.build.u32_type5, self.index_type5),
        ])

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute `typ` for `var` in the result, container and index types.
        """
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.container_type5 = replace_variable(self.container_type5, var, typ)
        self.index_type5 = replace_variable(self.index_type5, var, typ)

    def complexity(self) -> int:
        """
        Return 100 plus the complexity of the three type expressions.
        """
        return 100 + complexity(self.ret_type5) + complexity(self.container_type5) + complexity(self.index_type5)

    def __str__(self) -> str:
        type_namer = self.ctx.build.type5_name
        container_name = type_namer(self.container_type5)
        index_name = type_namer(self.index_type5)
        ret_name = type_namer(self.ret_type5)
        return f"[] :: t a -> b -> a ~ {container_name} -> {index_name} -> {ret_name}"
class CanAccessStructMemberConstraint(ConstraintBase):
    """
    Constraint stating that `struct.member_name` is valid and yields `ret_type5`.
    """
    __slots__ = ('ret_type5', 'struct_type5', 'member_name', )

    ret_type5: TypeExpr
    struct_type5: TypeExpr
    member_name: str

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        struct_type5: TypeExpr,
        member_name: str,
    ) -> None:
        """
        Store the result type, the struct type and the member name.
        """
        super().__init__(ctx, sourceref)

        self.ret_type5 = ret_type5
        self.struct_type5 = struct_type5
        self.member_name = member_name

    def check(self) -> CheckResult:
        """
        Resolve the member access once the struct type is concrete.

        Fails when the type is not a struct or has no member with the given
        name. Otherwise returns the outcome of unifying the result type
        with the member's type.
        """
        if not is_concrete(self.struct_type5):
            return skip_for_now()

        st_args = self.ctx.build.type5_is_struct(self.struct_type5)
        if st_args is None:
            return fail('Must be a struct')

        members = dict(st_args)
        try:
            member_type5 = members[self.member_name]
        except KeyError:
            return fail('Must have a field with this name')

        # The unification is checked right away instead of being queued as a new constraint.
        unify = UnifyTypesConstraint(self.ctx, self.sourceref, self.ret_type5, member_type5)
        return unify.check()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute `typ` for `var` in the result and struct types.
        """
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.struct_type5 = replace_variable(self.struct_type5, var, typ)

    def complexity(self) -> int:
        """
        Return 100 plus the complexity of the result and struct types.
        """
        return 100 + complexity(self.ret_type5) + complexity(self.struct_type5)

    def __str__(self) -> str:
        type_namer = self.ctx.build.type5_name

        st_args = self.ctx.build.type5_is_struct(self.struct_type5)
        member_typ = dict(st_args or []).get(self.member_name)

        # Show the expected signature when the member is known, else a generic one.
        expect = (
            'a -> b'
            if member_typ is None
            else f'{type_namer(self.struct_type5)} -> {type_namer(member_typ)}'
        )

        return f".{self.member_name} :: {expect} ~ {type_namer(self.struct_type5)} -> {type_namer(self.ret_type5)}"
class FromTupleConstraint(ConstraintBase):
    """
    Constraint stating that a tuple of members with types
    `member_type5_list` must be convertible to `ret_type5`.
    """
    __slots__ = ('ret_type5', 'member_type5_list', )

    ret_type5: TypeExpr
    member_type5_list: list[TypeExpr]

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        member_type5_list: Sequence[TypeExpr],
    ) -> None:
        """
        Store the target type and a copy of the member types.
        """
        super().__init__(ctx, sourceref)

        self.ret_type5 = ret_type5
        self.member_type5_list = list(member_type5_list)

    def check(self) -> CheckResult:
        """
        Resolve the conversion once the target type is concrete.

        - Dynamic array: every member type is unified with the element type.
        - Static array: the member count must equal the array length, then
          every member type is unified with the element type.
        - Tuple: the member count must match, then member types are unified
          position by position.
        - Any other concrete type fails; a tuple cannot be converted to it.
        """
        if not is_concrete(self.ret_type5):
            return skip_for_now()

        da_arg = self.ctx.build.type5_is_dynamic_array(self.ret_type5)
        if da_arg is not None:
            return new_constraints([
                UnifyTypesConstraint(self.ctx, self.sourceref, da_arg, x)
                for x in self.member_type5_list
            ])

        sa_args = self.ctx.build.type5_is_static_array(self.ret_type5)
        if sa_args is not None:
            sa_len, sa_typ = sa_args
            if sa_len != len(self.member_type5_list):
                return fail('Tuple element count mismatch')

            return new_constraints([
                UnifyTypesConstraint(self.ctx, self.sourceref, sa_typ, x)
                for x in self.member_type5_list
            ])

        tp_args = self.ctx.build.type5_is_tuple(self.ret_type5)
        if tp_args is not None:
            if len(tp_args) != len(self.member_type5_list):
                return fail('Tuple element count mismatch')

            return new_constraints([
                UnifyTypesConstraint(self.ctx, self.sourceref, act_typ, exp_typ)
                for act_typ, exp_typ in zip(tp_args, self.member_type5_list, strict=True)
            ])

        # A concrete target that is neither an array nor a tuple is a type
        # error in the program being checked, so report it as a failure
        # instead of raising out of the solver.
        return fail('Cannot convert from literal tuple')

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute `typ` for `var` in the target type and in every member type.
        """
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.member_type5_list = [
            replace_variable(x, var, typ)
            for x in self.member_type5_list
        ]

    def complexity(self) -> int:
        """
        Return 100 plus the complexity of the target type and all member types.
        """
        return 100 + complexity(self.ret_type5) + sum(complexity(x) for x in self.member_type5_list)

    def __str__(self) -> str:
        args = ', '.join(self.ctx.build.type5_name(x) for x in self.member_type5_list)
        return f'FromTuple {self.ctx.build.type5_name(self.ret_type5)} ~ ({args}, )'
class TypeClassInstanceExistsConstraint(ConstraintBase):
    """
    Constraint stating that an instance of the type class `typeclass`
    exists for the types in `arg_list`.
    """
    __slots__ = ('typeclass', 'arg_list', )

    typeclass: str
    arg_list: list[TypeExpr]

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        typeclass: str,
        arg_list: Sequence[TypeExpr]
    ) -> None:
        """
        Store the type class name and a copy of the argument types.
        """
        super().__init__(ctx, sourceref)

        self.typeclass = typeclass
        self.arg_list = list(arg_list)

    def check(self) -> CheckResult:
        """
        Look up the instance once every argument type is concrete.

        Fails when any argument is a Record, or when the tuple of arguments
        is not among the build's registered instances for the type class.
        """
        if not all(is_concrete(arg) for arg in self.arg_list):
            return skip_for_now()

        for arg in self.arg_list:
            if isinstance(arg, Record):
                # TODO: Allow users to implement type classes on their structs
                return fail('Missing type class instance')

        key = tuple(self.arg_list)
        known_instances = self.ctx.build.type_class_instances[self.typeclass]
        if key not in known_instances:
            return fail('Missing type class instance')

        return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """
        Substitute `typ` for `var` in every argument type.
        """
        self.arg_list = [
            replace_variable(arg, var, typ)
            for arg in self.arg_list
        ]

    def complexity(self) -> int:
        """
        Return 100 plus the complexity of all argument types.
        """
        total = 100
        for arg in self.arg_list:
            total += complexity(arg)
        return total

    def __str__(self) -> str:
        type_namer = self.ctx.build.type5_name
        args = ' '.join(type_namer(arg) for arg in self.arg_list)
        return f'Exists {self.typeclass} {args}'
def complexity(expr: TypeExpr) -> int:
if isinstance(expr, AtomicType | TypeLevelNat):
return 1
if isinstance(expr, TypeConstructor):
return 2
if isinstance(expr, TypeVariable):
return 5
if isinstance(expr, TypeApplication):
return complexity(expr.constructor) + complexity(expr.argument)
raise NotImplementedError(expr)