356 lines
13 KiB
Python
356 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
from typing import Any, Callable, Iterable, Protocol
|
|
|
|
from ..build.base import BuildBase
|
|
from ..ourlang import ConstantStruct, ConstantTuple, SourceRef
|
|
from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt64
|
|
from .kindexpr import KindExpr, Star
|
|
from .typeexpr import (
|
|
TypeExpr,
|
|
TypeVariable,
|
|
is_concrete,
|
|
replace_variable,
|
|
)
|
|
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
|
|
|
|
|
|
class ExpressionProtocol(Protocol):
    """
    Structural type for objects whose type should be updated on substitution.

    Any object exposing a ``type5`` attribute satisfies this protocol.
    """

    # The type to update; may be None.
    type5: TypeExpr | None
|
|
|
|
class Context:
    """
    State shared by all constraints of one type-checking run.

    Holds the build whose type helpers the constraints query, and a registry
    that maps every placeholder type variable handed out by
    ``make_placeholder`` to the expression it was created for (presumably so
    the expression's ``type5`` can be updated once the placeholder is
    substituted -- confirm against the solver).
    """

    __slots__ = ("build", "placeholder_update", )

    build: BuildBase[Any]  # build object providing type lookups and names
    placeholder_update: dict[TypeVariable, ExpressionProtocol]  # placeholder -> owning expression

    def __init__(self, build: BuildBase[Any]) -> None:
        self.build = build
        self.placeholder_update = dict()

    def make_placeholder(self, arg: ExpressionProtocol, kind: KindExpr = Star()) -> TypeVariable:
        """
        Create a new placeholder type variable for ``arg`` and register it.

        The variable is named ``p_<n>``, where ``n`` is the number of
        placeholders registered before this call.
        """
        index = len(self.placeholder_update)
        placeholder = TypeVariable(kind, f"p_{index}")
        self.placeholder_update[placeholder] = arg
        return placeholder
|
|
|
|
@dataclasses.dataclass
|
|
class CheckResult:
|
|
_: dataclasses.KW_ONLY
|
|
done: bool = True
|
|
actions: ActionList = dataclasses.field(default_factory=ActionList)
|
|
failures: list[Failure] = dataclasses.field(default_factory=list)
|
|
|
|
def to_str(self, type_namer: Callable[[TypeExpr], str]) -> str:
|
|
if not self.done and not self.actions and not self.failures:
|
|
return '(skip for now)'
|
|
|
|
if self.done and not self.actions and not self.failures:
|
|
return '(ok)'
|
|
|
|
if self.done and self.actions and not self.failures:
|
|
return self.actions.to_str(type_namer)
|
|
|
|
if self.done and not self.actions and self.failures:
|
|
return 'ERR: ' + '; '.join(x.msg for x in self.failures)
|
|
|
|
return f'{self.actions.to_str(type_namer)} {self.failures} {self.done}'
|
|
|
|
def fail(msg: str) -> CheckResult:
    """Return a finished (``done=True``) CheckResult holding one failure with ``msg``."""
    failure = Failure(msg)
    return CheckResult(failures=[failure])
|
|
|
|
def combine_check_result(cr_list: Iterable[CheckResult]) -> CheckResult:
    """
    Merge several check results into one.

    The merged result is done only if every input is done; actions and
    failures are concatenated in input order. ``cr_list`` is iterated exactly
    once, so a lazy generator of ``.check()`` calls is evaluated one at a time.
    """
    merged_actions = ActionList()
    merged_failures: list[Failure] = []
    all_done = True

    for result in cr_list:
        all_done = all_done and result.done
        merged_actions.extend(result.actions)
        merged_failures.extend(result.failures)

    return CheckResult(done=all_done, actions=merged_actions, failures=merged_failures)
|
|
|
|
class ConstraintBase:
    """
    Base class for type constraints.

    Subclasses override ``check`` to evaluate the constraint and
    ``replace_variable`` to substitute type variables in the types they hold.
    """

    __slots__ = ("ctx", "sourceref", "comment",)

    ctx: Context  # shared context of the current type-checking run
    sourceref: SourceRef | None  # source reference for this constraint, if any
    comment: str | None  # optional free-form note attached to the constraint

    def __init__(self, ctx: Context, sourceref: SourceRef | None, comment: str | None = None) -> None:
        self.ctx, self.sourceref, self.comment = ctx, sourceref, comment

    def check(self) -> CheckResult:
        """Evaluate the constraint. Subclasses must override this."""
        raise NotImplementedError(self)

    def apply(self, action: Action) -> None:
        """
        Apply an action produced while solving constraints.

        Only ``ReplaceVariable`` is supported; any other action raises
        NotImplementedError.
        """
        if not isinstance(action, ReplaceVariable):
            raise NotImplementedError(action)

        self.replace_variable(action.var, action.typ)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute ``typ`` for ``var`` in this constraint; a no-op by default."""
        return None
|
|
|
|
|
|
class LiteralFitsConstraint(ConstraintBase):
    """
    Constraint stating that a constant literal can be stored in a given type.

    The check is deferred until the type is concrete. Integer types require an
    int that fits in the type's allocation size, float types require a float,
    and dynamic arrays, static arrays, structs and tuples require a literal of
    matching shape whose members are checked recursively.
    """

    __slots__ = ("type", "literal",)

    def __init__(self, ctx: Context, sourceref: SourceRef | None, type: TypeExpr, literal: Any, *, comment: str | None = None) -> None:
        super().__init__(ctx, sourceref, comment)

        # `type` shadows the builtin, but it is part of the keyword interface.
        self.type = type
        self.literal = literal

    def _check_members(self, pairs: Iterable[tuple[Any, TypeExpr]]) -> CheckResult:
        """Check each (member literal, member type) pair lazily and merge the outcomes."""
        return combine_check_result(
            LiteralFitsConstraint(self.ctx, member.sourceref, member_type, member).check()
            for member, member_type in pairs
        )

    def _check_integer(self, info: Any) -> CheckResult:
        """Check that the literal is an int that fits in ``info.alloc_size`` bytes."""
        assert info.signed is not None

        number = self.literal.value
        if not isinstance(number, int):
            return fail('Must be integer')

        # to_bytes raises OverflowError exactly when the value needs more bytes
        # than alloc_size with the given signedness.
        try:
            number.to_bytes(info.alloc_size, 'big', signed=info.signed)
        except OverflowError:
            return fail(f'Must fit in {info.alloc_size} byte(s)')

        return CheckResult()

    def check(self) -> CheckResult:
        """Decide whether the literal fits the type, deferring while the type is not concrete."""
        if not is_concrete(self.type):
            return CheckResult(done=False)

        build = self.ctx.build
        info = build.type_info_map.get(self.type.name)
        wasm_type = None if info is None else info.wasm_type

        if wasm_type is WasmTypeInt32 or wasm_type is WasmTypeInt64:
            return self._check_integer(info)

        if wasm_type is WasmTypeFloat32 or wasm_type is WasmTypeFloat64:
            # FIXME: Bit check -- only the Python type is verified here
            # (presumably the value's representability at the target width should be too).
            return CheckResult() if isinstance(self.literal.value, float) else fail('Must be real')

        element_type = build.type5_is_dynamic_array(self.type)
        if element_type is not None:
            if element_type == build.u8_type5:
                # A dynamic array of u8 must be given as a bytes value.
                return CheckResult() if isinstance(self.literal.value, bytes) else fail('Must be bytes')

            if not isinstance(self.literal, ConstantTuple):
                return fail('Must be tuple')

            return self._check_members((item, element_type) for item in self.literal.value)

        static_array = build.type5_is_static_array(self.type)
        if static_array is not None:
            length, item_type = static_array

            if not isinstance(self.literal, ConstantTuple):
                return fail('Must be tuple')

            if len(self.literal.value) != length:
                return fail('Tuple element count mismatch')

            return self._check_members((item, item_type) for item in self.literal.value)

        fields = build.type5_is_record(self.type)
        if fields is not None:
            if not isinstance(self.literal, ConstantStruct):
                return fail('Must be struct')

            if self.literal.struct_type3.name != self.type.name:  # TODO: Name based check is wonky
                return fail('Must be right struct')

            if len(self.literal.value) != len(fields):
                return fail('Struct member count mismatch')

            return self._check_members(
                (member, member_type)
                for member, (_, member_type) in zip(self.literal.value, fields, strict=True)
            )

        element_types = build.type5_is_tuple(self.type)
        if element_types is not None:
            if not isinstance(self.literal, ConstantTuple):
                return fail('Must be tuple')

            if len(self.literal.value) != len(element_types):
                return fail('Tuple element count mismatch')

            return self._check_members(zip(self.literal.value, element_types, strict=True))

        raise NotImplementedError(self.type, info)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute ``typ`` for ``var`` in the target type."""
        self.type = replace_variable(self.type, var, typ)

    def __str__(self) -> str:
        type_name = self.ctx.build.type5_name(self.type)
        return f"{type_name} can contain {self.literal!r}"
|
|
|
|
class UnifyTypesConstraint(ConstraintBase):
    """Constraint requiring two types to unify, written ``lft ~ rgt``."""

    __slots__ = ("lft", "rgt",)

    def __init__(self, ctx: Context, sourceref: SourceRef | None, lft: TypeExpr, rgt: TypeExpr, *, comment: str | None = None) -> None:
        super().__init__(ctx, sourceref, comment)

        self.lft, self.rgt = lft, rgt

    def check(self) -> CheckResult:
        """Run unification: a Failure becomes a failed result, otherwise its actions are returned."""
        outcome = unify(self.lft, self.rgt)
        return CheckResult(failures=[outcome]) if isinstance(outcome, Failure) else CheckResult(actions=outcome)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute ``typ`` for ``var`` on both sides."""
        self.lft = replace_variable(self.lft, var, typ)
        self.rgt = replace_variable(self.rgt, var, typ)

    def __str__(self) -> str:
        name_of = self.ctx.build.type5_name
        return f"{name_of(self.lft)} ~ {name_of(self.rgt)}"
|
|
|
|
class CanBeSubscriptedConstraint(ConstraintBase):
    """
    Constraint for a subscript expression ``container[index]``.

    Once the container type is concrete it must be a dynamic array, a static
    array or a tuple. The selected element type is unified with ``ret_type5``
    and the index type with u32. Tuples additionally require the index to be
    an integer literal (``index_const``) so the element type can be chosen.
    """

    __slots__ = ('ret_type5', 'container_type5', 'index_type5', 'index_const', )

    ret_type5: TypeExpr  # type of the whole subscript expression
    container_type5: TypeExpr  # type of the value being subscripted
    index_type5: TypeExpr  # type of the index expression
    index_const: int | None  # the index value when it is an integer literal, else None

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef | None,
        ret_type5: TypeExpr,
        container_type5: TypeExpr,
        index_type5: TypeExpr,
        index_const: int | None,
        *,
        comment: str | None = None,
    ) -> None:
        # `comment` is keyword-only and optional, matching the other
        # constraints, and is forwarded to the base class.
        super().__init__(ctx, sourceref, comment)

        self.ret_type5 = ret_type5
        self.container_type5 = container_type5
        self.index_type5 = index_type5
        self.index_const = index_const

    def _unify_element_and_index(self, element_type5: TypeExpr) -> CheckResult:
        """Unify ``element_type5`` with the result type and the index type with u32."""
        return combine_check_result([
            UnifyTypesConstraint(self.ctx, self.sourceref, element_type5, self.ret_type5).check(),
            UnifyTypesConstraint(self.ctx, self.sourceref, self.ctx.build.u32_type5, self.index_type5).check(),
        ])

    def check(self) -> CheckResult:
        """Check the subscript, deferring while the container type is not concrete."""
        if not is_concrete(self.container_type5):
            return CheckResult(done=False)

        da_args = self.ctx.build.type5_is_dynamic_array(self.container_type5)
        if da_args is not None:
            return self._unify_element_and_index(da_args)

        sa_args = self.ctx.build.type5_is_static_array(self.container_type5)
        if sa_args is not None:
            sa_len, sa_typ = sa_args

            # Only a literal index can be bounds-checked here.
            if self.index_const is not None and (self.index_const < 0 or sa_len <= self.index_const):
                return fail('Tuple index out of range')

            return self._unify_element_and_index(sa_typ)

        tp_args = self.ctx.build.type5_is_tuple(self.container_type5)
        if tp_args is not None:
            # Each tuple position has its own type, so the index must be known statically.
            if self.index_const is None:
                return fail('Must index with integer literal')

            if self.index_const < 0 or len(tp_args) <= self.index_const:
                return fail('Tuple index out of range')

            return self._unify_element_and_index(tp_args[self.index_const])

        return fail(f'Missing type class instantation: Subscriptable {self.container_type5.name}')

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute ``typ`` for ``var`` in the result, container and index types."""
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.container_type5 = replace_variable(self.container_type5, var, typ)
        self.index_type5 = replace_variable(self.index_type5, var, typ)

    def __str__(self) -> str:
        return f"[] :: t a -> b -> a ~ {self.ctx.build.type5_name(self.container_type5)} -> {self.ctx.build.type5_name(self.index_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"
|
|
|
|
class CanAccessStructMemberConstraint(ConstraintBase):
    """
    Constraint for a member access ``value.member_name``.

    Once the struct type is concrete it must be a record with a field called
    ``member_name``. That field's type is then unified with ``ret_type5``.
    """

    __slots__ = ('ret_type5', 'struct_type5', 'member_name', )

    ret_type5: TypeExpr  # type of the member-access expression
    struct_type5: TypeExpr  # type of the value whose member is accessed
    member_name: str  # name of the accessed field

    def __init__(
        self,
        ctx: Context,
        sourceref: SourceRef | None,
        ret_type5: TypeExpr,
        struct_type5: TypeExpr,
        member_name: str,
        *,
        comment: str | None = None,
    ) -> None:
        # `comment` is keyword-only and optional, matching the other
        # constraints, and is forwarded to the base class.
        super().__init__(ctx, sourceref, comment)

        self.ret_type5 = ret_type5
        self.struct_type5 = struct_type5
        self.member_name = member_name

    def check(self) -> CheckResult:
        """Check the member access, deferring while the struct type is not concrete."""
        if not is_concrete(self.struct_type5):
            return CheckResult(done=False)

        st_args = self.ctx.build.type5_is_record(self.struct_type5)
        if st_args is None:
            return fail('Must be a struct')

        # st_args holds (field name, field type) pairs.
        member_dict = dict(st_args)

        if self.member_name not in member_dict:
            return fail('Must have a field with this name')

        return UnifyTypesConstraint(self.ctx, self.sourceref, self.ret_type5, member_dict[self.member_name]).check()

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute ``typ`` for ``var`` in the result and struct types."""
        self.ret_type5 = replace_variable(self.ret_type5, var, typ)
        self.struct_type5 = replace_variable(self.struct_type5, var, typ)

    def __str__(self) -> str:
        st_args = self.ctx.build.type5_is_record(self.struct_type5)
        member_dict = dict(st_args or [])
        member_typ = member_dict.get(self.member_name)

        # Fall back to a generic signature when the field type is not known (yet).
        if member_typ is None:
            expect = 'a -> b'
        else:
            expect = f'{self.ctx.build.type5_name(self.struct_type5)} -> {self.ctx.build.type5_name(member_typ)}'

        return f".{self.member_name} :: {expect} ~ {self.ctx.build.type5_name(self.struct_type5)} -> {self.ctx.build.type5_name(self.ret_type5)}"
|