492 lines
16 KiB
Python
492 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
from typing import Any, Callable, Iterable, Protocol, Sequence
|
|
|
|
from ..build.base import BuildBase
|
|
from ..ourlang import SourceRef
|
|
from ..type3 import types as type3types
|
|
from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt64
|
|
from .kindexpr import KindExpr, Star
|
|
from .record import Record
|
|
from .typeexpr import (
|
|
AtomicType,
|
|
TypeExpr,
|
|
TypeVariable,
|
|
is_concrete,
|
|
replace_variable,
|
|
)
|
|
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
|
|
|
|
|
|
class ExpressionProtocol(Protocol):
    """
    Structural type for objects that must be updated on substitution.

    Context.make_placeholder records such an object next to each placeholder
    type variable it creates.
    """

    # The type to update; may be None.
    type5: TypeExpr | None
|
|
|
|
class Context:
    """
    Shared state for a constraint-solving run.

    Holds the build being type checked and a registry of placeholder type
    variables. Each placeholder is mapped to the expression whose ``type5``
    should be updated when the placeholder gets substituted (or None).
    """

    __slots__ = ("build", "placeholder_update", )

    build: BuildBase[Any]
    # Every placeholder created by make_placeholder, mapped to the expression
    # it was created for (or None).
    placeholder_update: dict[TypeVariable, ExpressionProtocol | None]

    def __init__(self, build: BuildBase[Any]) -> None:
        self.build = build
        self.placeholder_update = dict()

    def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star()) -> TypeVariable:
        """
        Create a fresh placeholder type variable and register it.

        The placeholder is named ``p_<n>``, where ``n`` is the number of
        placeholders registered before it, so names are unique per context.

        :param arg: Expression to associate with the placeholder, if any.
        :param kind: Kind of the new type variable.
        :return: The new placeholder type variable.
        """
        index = len(self.placeholder_update)
        placeholder = TypeVariable(kind, f"p_{index}")
        self.placeholder_update[placeholder] = arg
        return placeholder
|
|
|
|
@dataclasses.dataclass
|
|
class CheckResult:
|
|
_: dataclasses.KW_ONLY
|
|
done: bool = True
|
|
actions: ActionList = dataclasses.field(default_factory=ActionList)
|
|
new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
|
|
failures: list[Failure] = dataclasses.field(default_factory=list)
|
|
|
|
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
|
if not self.done and not self.actions and not self.new_constraints and not self.failures:
|
|
return '(skip for now)'
|
|
|
|
if self.done and not self.actions and not self.new_constraints and not self.failures:
|
|
return '(ok)'
|
|
|
|
if self.done and self.actions and not self.new_constraints and not self.failures:
|
|
return self.actions.to_str(type_namer)
|
|
|
|
if self.done and not self.actions and self.new_constraints and not self.failures:
|
|
return f'(got {len(self.new_constraints)} new constraints)'
|
|
|
|
if self.done and not self.actions and not self.new_constraints and self.failures:
|
|
return 'ERR: ' + '; '.join(x.msg for x in self.failures)
|
|
|
|
return f'{self.actions.to_str(type_namer)} {self.failures} {self.new_constraints} {self.done}'
|
|
|
|
def skip_for_now() -> CheckResult:
    """Return a result saying the constraint cannot be decided yet."""
    undecided = CheckResult(done=False)
    return undecided
|
|
|
|
def new_constraints(lst: Iterable[ConstraintBase]) -> CheckResult:
    """Return a done result carrying the given constraints (the iterable is materialized)."""
    return CheckResult(new_constraints=[*lst])
|
|
|
|
def ok() -> CheckResult:
    """Return a done result with no actions, constraints or failures."""
    # CheckResult.done defaults to True.
    return CheckResult()
|
|
|
|
def fail(msg: str) -> CheckResult:
    """Return a result carrying a single failure with the given message."""
    failure = Failure(msg)
    return CheckResult(failures=[failure])
|
|
|
|
class ConstraintBase:
|
|
__slots__ = ("ctx", "sourceref", )
|
|
|
|
ctx: Context
|
|
sourceref: SourceRef
|
|
|
|
def __init__(self, ctx: Context, sourceref: SourceRef) -> None:
|
|
self.ctx = ctx
|
|
self.sourceref = sourceref
|
|
|
|
def check(self) -> CheckResult:
|
|
raise NotImplementedError(self)
|
|
|
|
def apply(self, action: Action) -> None:
|
|
if isinstance(action, ReplaceVariable):
|
|
self.replace_variable(action.var, action.typ)
|
|
return
|
|
|
|
raise NotImplementedError(action)
|
|
|
|
def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
|
|
pass
|
|
|
|
class FromLiteralInteger(ConstraintBase):
    """
    Constraint that an integer literal can be represented by ``type5``.

    It is decided once ``type5`` is concrete. The type must have type info
    backed by WasmTypeInt32 or WasmTypeInt64. The literal must also fit in
    the type's ``alloc_size`` bytes with the type's signedness.
    """

    __slots__ = ('type5', 'literal', )

    type5: TypeExpr  # the type the literal is converted to
    literal: int  # the literal value

    def __init__(self, ctx: Context, sourceref: SourceRef, type5: TypeExpr, literal: int) -> None:
        super().__init__(ctx, sourceref)

        self.type5 = type5
        self.literal = literal

    def check(self) -> CheckResult:
        if not is_concrete(self.type5):
            return skip_for_now()

        info = self.ctx.build.type_info_map.get(self.type5.name)
        is_int_backed = info is not None and (
            info.wasm_type is WasmTypeInt32 or info.wasm_type is WasmTypeInt64
        )
        if not is_int_backed:
            return fail('Cannot convert from literal integer')

        assert info.signed is not None  # type hint

        # int.to_bytes raises OverflowError when the value does not fit.
        try:
            self.literal.to_bytes(info.alloc_size, 'big', signed=info.signed)
        except OverflowError:
            return fail(f'Must fit in {info.alloc_size} byte(s)')
        else:
            return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.type5 = replace_variable(self.type5, var, typ)

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type5)
        return f'FromLiteralInteger {type_name} ~ {self.literal!r}'
|
|
|
|
class FromLiteralFloat(ConstraintBase):
    """
    Constraint that a float literal can be represented by ``type5``.

    It is decided once ``type5`` is concrete. The type must have type info
    backed by WasmTypeFloat32 or WasmTypeFloat64. The literal's precision
    is not checked yet.
    """

    __slots__ = ('type5', 'literal', )

    type5: TypeExpr  # the type the literal is converted to
    literal: float  # the literal value

    def __init__(self, ctx: Context, sourceref: SourceRef, type5: TypeExpr, literal: float) -> None:
        super().__init__(ctx, sourceref)

        self.type5 = type5
        self.literal = literal

    def check(self) -> CheckResult:
        if not is_concrete(self.type5):
            return skip_for_now()

        type_info = self.ctx.build.type_info_map.get(self.type5.name)
        if type_info is None:
            return fail('Cannot convert from literal float')

        if type_info.wasm_type is not WasmTypeFloat32 and type_info.wasm_type is not WasmTypeFloat64:
            return fail('Cannot convert from literal float')

        # TODO: Precision check

        return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.type5 = replace_variable(self.type5, var, typ)

    def __str__(self) -> str:
        # Label must name this constraint, not FromLiteralInteger.
        return f'FromLiteralFloat {self.ctx.build.type5_name(self.type5)} ~ {self.literal!r}'
|
|
|
|
class FromLiteralBytes(ConstraintBase):
    """
    Constraint that a bytes literal can be represented by ``type5``.

    It is decided once ``type5`` is concrete. Only a dynamic array whose
    element type equals the build's ``u8_type5`` is accepted.
    """

    __slots__ = ('type5', 'literal', )

    type5: TypeExpr  # the type the literal is converted to
    literal: bytes  # the literal value

    def __init__(self, ctx: Context, sourceref: SourceRef, type5: TypeExpr, literal: bytes) -> None:
        super().__init__(ctx, sourceref)

        self.type5 = type5
        self.literal = literal

    def check(self) -> CheckResult:
        if not is_concrete(self.type5):
            return skip_for_now()

        build = self.ctx.build
        element_type5 = build.type5_is_dynamic_array(self.type5)

        if element_type5 is None or element_type5 != build.u8_type5:
            return fail('Cannot convert from literal bytes')
        return ok()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.type5 = replace_variable(self.type5, var, typ)

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type5)
        return f'FromLiteralBytes {type_name} ~ {self.literal!r}'
|
|
|
|
class UnifyTypesConstraint(ConstraintBase):
    """
    Constraint that two types must be equal (``lft ~ rgt``).

    The check delegates to ``unify``. A Failure becomes a failed result;
    otherwise the returned actions are passed on in the result.
    """

    __slots__ = ("lft", "rgt",)

    lft: TypeExpr  # left-hand side of the equality
    rgt: TypeExpr  # right-hand side of the equality

    def __init__(self, ctx: Context, sourceref: SourceRef, lft: TypeExpr, rgt: TypeExpr) -> None:
        super().__init__(ctx, sourceref)

        self.lft = lft
        self.rgt = rgt

    def check(self) -> CheckResult:
        outcome = unify(self.lft, self.rgt)

        if not isinstance(outcome, Failure):
            return CheckResult(actions=outcome)
        return CheckResult(failures=[outcome])

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.lft = replace_variable(self.lft, var, typ)
        self.rgt = replace_variable(self.rgt, var, typ)

    def __str__(self) -> str:
        name_of = self.ctx.build.type5_name
        return f"{name_of(self.lft)} ~ {name_of(self.rgt)}"
|
|
|
|
class CanBeSubscriptedConstraint(ConstraintBase):
    """
    Constraint for a subscript expression ``container[index]``.

    Once ``container_type5`` is concrete, it is resolved into simpler
    constraints:

    - dynamic array: the element type unifies with the result type;
    - static array: the same, after a bounds check when the index is a literal;
    - tuple: the index must be a literal within bounds, and the element at
      that position unifies with the result type.

    In these three cases the index type is also unified with ``u32``. Any
    other container type must have a ``Subscriptable`` type class instance.
    """

    __slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )

    ret_type5: TypeExpr  # type of the subscript expression's result
    container_type5: TypeExpr  # type of the value being subscripted
    index_type5: TypeExpr  # type of the index expression
    index_const: int | None  # the index when it is an integer literal, else None

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        container_type5: TypeExpr,
        index_type5: TypeExpr,
        index_const: int | None,
    ) -> None:
        super().__init__(ctx, sourceref)

        self.ret_type5 = ret_type5
        self.container_type5 = container_type5
        self.index_type5 = index_type5
        self.index_const = index_const

    def _unify_element_and_index(self, element_type5: TypeExpr) -> CheckResult:
        """Unify element_type5 with the result type and the index type with u32."""
        return new_constraints([
            UnifyTypesConstraint(self.ctx, self.sourceref, element_type5, self.ret_type5),
            UnifyTypesConstraint(self.ctx, self.sourceref, self.ctx.build.u32_type5, self.index_type5),
        ])

    def check(self) -> CheckResult:
        if not is_concrete(self.container_type5):
            return skip_for_now()

        build = self.ctx.build
        index = self.index_const

        da_arg = build.type5_is_dynamic_array(self.container_type5)
        if da_arg is not None:
            return self._unify_element_and_index(da_arg)

        sa_args = build.type5_is_static_array(self.container_type5)
        if sa_args is not None:
            sa_len, sa_typ = sa_args
            # Only literal indices can be bounds-checked at compile time.
            if index is not None and not 0 <= index < sa_len:
                return fail('Tuple index out of range')
            return self._unify_element_and_index(sa_typ)

        tp_args = build.type5_is_tuple(self.container_type5)
        if tp_args is not None:
            # Tuple members differ in type, so the index must be known statically.
            if index is None:
                return fail('Must index with integer literal')
            if not 0 <= index < len(tp_args):
                return fail('Tuple index out of range')
            return self._unify_element_and_index(tp_args[index])

        return new_constraints([
            TypeClassInstanceExistsConstraint(self.ctx, self.sourceref, 'Subscriptable', [self.container_type5]),
        ])

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.container_type5 = replace_variable(self.container_type5, var, typ)
        self.index_type5 = replace_variable(self.index_type5, var, typ)

    def __str__(self) -> str:
        name_of = self.ctx.build.type5_name
        container = name_of(self.container_type5)
        index = name_of(self.index_type5)
        ret = name_of(self.ret_type5)
        return f"[] :: t a -> b -> a ~ {container} -> {index} -> {ret}"
|
|
|
|
class CanAccessStructMemberConstraint(ConstraintBase):
    """
    Constraint for a member access ``struct.member_name``.

    Once ``struct_type5`` is concrete it must be a record with a field named
    ``member_name``. That field's type is then unified with ``ret_type5``
    right away, within this same check.
    """

    __slots__ = ('ret_type5', 'struct_type5', 'member_name', )

    ret_type5: TypeExpr  # type of the member access expression
    struct_type5: TypeExpr  # type of the value being accessed
    member_name: str  # name of the accessed field

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        struct_type5: TypeExpr,
        member_name: str,
    ) -> None:
        super().__init__(ctx, sourceref)

        self.ret_type5 = ret_type5
        self.struct_type5 = struct_type5
        self.member_name = member_name

    def check(self) -> CheckResult:
        if not is_concrete(self.struct_type5):
            return skip_for_now()

        fields = self.ctx.build.type5_is_record(self.struct_type5)
        if fields is None:
            return fail('Must be a struct')

        try:
            field_type5 = dict(fields)[self.member_name]
        except KeyError:
            return fail('Must have a field with this name')

        return UnifyTypesConstraint(self.ctx, self.sourceref, self.ret_type5, field_type5).check()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.struct_type5 = replace_variable(self.struct_type5, var, typ)

    def __str__(self) -> str:
        build = self.ctx.build
        name_of = build.type5_name

        # Falls back to a generic signature while the record is not known yet.
        fields = dict(build.type5_is_record(self.struct_type5) or [])
        field_type5 = fields.get(self.member_name)
        if field_type5 is None:
            expect = 'a -> b'
        else:
            expect = f'{name_of(self.struct_type5)} -> {name_of(field_type5)}'

        return f".{self.member_name} :: {expect} ~ {name_of(self.struct_type5)} -> {name_of(self.ret_type5)}"
|
|
|
|
class FromTupleConstraint(ConstraintBase):
    """
    Constraint for a tuple literal whose value has type ``ret_type5``.

    Once ``ret_type5`` is concrete, it is resolved into simpler constraints:

    - dynamic array: every member type unifies with the element type;
    - static array: the length must equal the member count, and every member
      type unifies with the element type;
    - tuple: the arity must equal the member count, and each member type
      unifies with the tuple element type at the same position.

    Any other concrete type is reported as a type error.
    """

    __slots__ = ('ret_type5', 'member_type5_list', )

    ret_type5: TypeExpr  # type the tuple literal must produce
    member_type5_list: list[TypeExpr]  # types of the literal's members, in order

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        ret_type5: TypeExpr,
        member_type5_list: Sequence[TypeExpr],
    ) -> None:
        super().__init__(ctx, sourceref)

        self.ret_type5 = ret_type5
        self.member_type5_list = list(member_type5_list)

    def check(self) -> CheckResult:
        if not is_concrete(self.ret_type5):
            return skip_for_now()

        da_arg = self.ctx.build.type5_is_dynamic_array(self.ret_type5)
        if da_arg is not None:
            return new_constraints(
                UnifyTypesConstraint(self.ctx, self.sourceref, da_arg, x)
                for x in self.member_type5_list
            )

        sa_args = self.ctx.build.type5_is_static_array(self.ret_type5)
        if sa_args is not None:
            sa_len, sa_typ = sa_args
            if sa_len != len(self.member_type5_list):
                return fail('Tuple element count mismatch')

            return new_constraints(
                UnifyTypesConstraint(self.ctx, self.sourceref, sa_typ, x)
                for x in self.member_type5_list
            )

        tp_args = self.ctx.build.type5_is_tuple(self.ret_type5)
        if tp_args is not None:
            if len(tp_args) != len(self.member_type5_list):
                return fail('Tuple element count mismatch')

            return new_constraints(
                UnifyTypesConstraint(self.ctx, self.sourceref, act_typ, exp_typ)
                for act_typ, exp_typ in zip(tp_args, self.member_type5_list, strict=True)
            )

        # A tuple literal cannot produce any other concrete type. That is a
        # user-level type error, so report it as a failure (like the other
        # constraints do) instead of crashing the checker.
        return fail('Cannot convert from tuple')

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.member_type5_list = [
            replace_variable(x, var, typ)
            for x in self.member_type5_list
        ]

    def __str__(self) -> str:
        args = ', '.join(self.ctx.build.type5_name(x) for x in self.member_type5_list)
        return f'FromTuple {self.ctx.build.type5_name(self.ret_type5)} ~ ({args}, )'
|
|
|
|
class TypeClassInstanceExistsConstraint(ConstraintBase):
    """
    Constraint that the named type class has an instance for ``arg_list``.

    It waits until every argument is concrete. Then it looks up the
    ``(type class, arguments)`` pair in the build's ``type_class_instances``.
    """

    __slots__ = ('typeclass', 'arg_list', )

    typeclass: str  # type class name, used as a key into build.type_classes
    arg_list: list[TypeExpr]  # type arguments of the required instance

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef,
        typeclass: str,
        arg_list: Sequence[TypeExpr]
    ) -> None:
        super().__init__(ctx, sourceref)

        self.typeclass = typeclass
        self.arg_list = list(arg_list)

    def check(self) -> CheckResult:
        if not all(is_concrete(arg) for arg in self.arg_list):
            return skip_for_now()

        build = self.ctx.build
        tcls = build.type_classes[self.typeclass]

        # Temporary hack while we are converting from type3 to type5:
        # instances are looked up by their type3 types / type constructors.
        try:
            targs = tuple(_type5_to_type3_or_type3_const(build, arg) for arg in self.arg_list)
        except RecordFoundException:
            # Records cannot be converted to type3; treat them as lacking an instance.
            return fail('Missing type class instance')

        found = (tcls, targs, ) in build.type_class_instances
        return ok() if found else fail('Missing type class instance')

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        self.arg_list = [replace_variable(arg, var, typ) for arg in self.arg_list]

    def __str__(self) -> str:
        args = ' '.join(self.ctx.build.type5_name(arg) for arg in self.arg_list)
        return f'Exists {self.typeclass} {args}'
|
|
|
|
class RecordFoundException(Exception):
    """Raised by _type5_to_type3_or_type3_const when it is given a Record."""
|
|
|
|
def _type5_to_type3_or_type3_const(build: BuildBase[Any], type5: TypeExpr) -> type3types.Type3 | type3types.TypeConstructor_Base[Any]:
    """
    Map a type5 expression onto its type3 counterpart.

    Atomic types map to the type3 type with the same name. Dynamic arrays,
    static arrays and tuples map only to their type3 type constructor; their
    type arguments are not converted.

    :raises RecordFoundException: if ``type5`` is a Record.
    :raises NotImplementedError: for any other kind of type.
    """
    if isinstance(type5, Record):
        raise RecordFoundException()

    if isinstance(type5, AtomicType):
        return build.types[type5.name]

    if build.type5_is_dynamic_array(type5) is not None:
        return build.dynamic_array

    if build.type5_is_static_array(type5) is not None:
        return build.static_array

    if build.type5_is_tuple(type5) is not None:
        return build.tuple_

    raise NotImplementedError(type5)
|