173 lines
5.7 KiB
Python
173 lines
5.7 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any, Protocol
|
|
|
|
import dataclasses
|
|
|
|
from ..build.base import BuildBase
|
|
from ..ourlang import ConstantTuple, SourceRef
|
|
from ..wasm import WasmTypeFloat32, WasmTypeFloat64, WasmTypeInt32, WasmTypeInt64
|
|
from .kindexpr import KindExpr, Star
|
|
from .typeexpr import (
|
|
TypeApplication,
|
|
TypeExpr,
|
|
TypeVariable,
|
|
is_concrete,
|
|
replace_variable,
|
|
)
|
|
from .unify import Action, ActionList, Failure, ReplaceVariable, unify
|
|
|
|
|
|
class ExpressionProtocol(Protocol):
    """
    Structural interface for objects that get updated on substitution.

    Any object exposing a ``type5`` attribute satisfies this protocol.
    """

    # The type to update; None is permitted by the annotation.
    type5: TypeExpr | None
|
|
|
|
class Context:
    """
    State shared by constraints: the build object and the placeholder registry.

    ``placeholder_update`` maps every placeholder type variable created through
    :meth:`make_placeholder` to the object it was created for.
    """
    __slots__ = ("build", "placeholder_update", )

    build: BuildBase[Any]
    placeholder_update: dict[TypeVariable, ExpressionProtocol]

    def __init__(self, build: BuildBase[Any]) -> None:
        """Store *build* and start with an empty placeholder registry."""
        self.build = build
        self.placeholder_update = {}

    def make_placeholder(self, arg: ExpressionProtocol, kind: KindExpr = Star()) -> TypeVariable:
        """
        Create a fresh type variable of the given *kind* and register *arg* for it.

        The variable is named ``p_<n>``, where ``<n>`` is the number of entries
        in the registry before this call. Returns the new type variable.
        """
        # The name has to be computed before the registry grows.
        name = f"p_{len(self.placeholder_update)}"
        placeholder = TypeVariable(kind, name)
        self.placeholder_update[placeholder] = arg
        return placeholder
|
|
|
|
@dataclasses.dataclass
class CheckResult:
    """
    Outcome of a single constraint check.

    All fields are keyword-only. With no arguments the result is ``done`` and
    carries no actions, new constraints or failures.
    """
    _: dataclasses.KW_ONLY
    # False is what a check returns when it cannot decide yet.
    done: bool = True
    # Actions produced by the check.
    actions: ActionList = dataclasses.field(default_factory=ActionList)
    # Additional constraints produced by the check.
    new_constraints: list[ConstraintBase] = dataclasses.field(default_factory=list)
    # Failures found by the check.
    failures: list[Failure] = dataclasses.field(default_factory=list)

    def __str__(self) -> str:
        """
        Render the result compactly.

        An undone result with nothing else set prints as ``(skip for now)``; a
        done result carrying only actions prints as those actions; anything
        else prints all four fields.
        """
        only_actions_or_nothing = not self.new_constraints and not self.failures

        if only_actions_or_nothing:
            if not self.done and not self.actions:
                return '(skip for now)'

            if self.done and self.actions:
                return str(self.actions)

        return f'{self.actions} {self.new_constraints} {self.failures} {self.done}'
|
|
|
|
class ConstraintBase:
    """
    Common base for constraints.

    Holds the shared context, an optional source reference and an optional
    comment. Subclasses must override :meth:`check`, and override
    :meth:`replace_variable` when they hold types that need rewriting.
    """
    __slots__ = ("ctx", "sourceref", "comment",)

    ctx: Context
    sourceref: SourceRef | None
    comment: str | None

    def __init__(self, ctx: Context, sourceref: SourceRef | None, comment: str | None = None) -> None:
        """Store the context, source reference and optional comment."""
        self.comment = comment
        self.sourceref = sourceref
        self.ctx = ctx

    def check(self) -> CheckResult:
        """Evaluate the constraint; the base implementation always raises."""
        raise NotImplementedError(self)

    def apply(self, action: Action) -> None:
        """
        Apply *action* to this constraint.

        A ``ReplaceVariable`` action is forwarded to :meth:`replace_variable`;
        any other action raises ``NotImplementedError``.
        """
        if not isinstance(action, ReplaceVariable):
            raise NotImplementedError(action)

        self.replace_variable(action.var, action.typ)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute *typ* for *var*; the base implementation does nothing."""
        return None
|
|
|
|
|
|
class LiteralFitsConstraint(ConstraintBase):
    """
    Constraint stating that a literal can be held by a given type.
    """
    __slots__ = ("type", "literal",)

    def __init__(self, ctx: Context, sourceref: SourceRef | None, type: TypeExpr, literal: Any, *, comment: str | None = None) -> None:
        """Store the type to check against and the literal being checked."""
        super().__init__(ctx, sourceref, comment)

        self.type = type
        self.literal = literal

    def check(self) -> CheckResult:
        """
        Check whether the literal fits the type.

        Returns an undone result while the type is not concrete. Otherwise:

        - Int32/Int64 backed types need an ``int`` value that can be encoded
          in the type's ``alloc_size`` bytes with its signedness.
        - Float32/Float64 backed types need a ``float`` value.
        - A dynamic array of the build's u8 type needs a ``bytes`` value.
        - A tuple type needs a ``ConstantTuple`` literal.

        Raises ``NotImplementedError`` for dynamic arrays of other element
        types, for tuple literals (not checked yet), and for any type not
        listed above.
        """
        if not is_concrete(self.type):
            # Cannot decide yet.
            return CheckResult(done=False)

        build = self.ctx.build
        type_info = build.type_info_map.get(self.type.name)
        wasm_type = None if type_info is None else type_info.wasm_type

        if wasm_type is WasmTypeInt32 or wasm_type is WasmTypeInt64:
            assert type_info.signed is not None

            value = self.literal.value
            if not isinstance(value, int):
                return CheckResult(failures=[Failure('Must be integer')])

            # Encoding to the allocated width raises OverflowError when the
            # value is out of range for that width and signedness.
            try:
                value.to_bytes(type_info.alloc_size, 'big', signed=type_info.signed)
            except OverflowError:
                return CheckResult(failures=[Failure(f'Must fit in {type_info.alloc_size} byte(s)')])

            return CheckResult()

        if wasm_type is WasmTypeFloat32 or wasm_type is WasmTypeFloat64:
            if not isinstance(self.literal.value, float):
                return CheckResult(failures=[Failure('Must be real')])

            # FIXME: Bit check
            return CheckResult()

        da_arg = build.type5_is_dynamic_array(self.type)
        if da_arg is not None:
            if da_arg == build.u8_type5:
                if isinstance(self.literal.value, bytes):
                    return CheckResult()

                return CheckResult(failures=[Failure('Must be bytes')])

            # Only dynamic arrays of u8 are handled.
            raise NotImplementedError(type_info)

        tp_args = build.type5_is_tuple(self.type)
        if tp_args is not None:
            if isinstance(self.literal, ConstantTuple):
                # Tuple literals are recognised but not checked yet.
                raise NotImplementedError(type_info)

            return CheckResult(failures=[Failure('Must be tuple')])

        raise NotImplementedError(self.type, type_info)

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute *typ* for *var* inside the stored type."""
        self.type = replace_variable(self.type, var, typ)

    def __str__(self) -> str:
        """Render as ``<type name> can contain <literal repr>``."""
        type_name = self.ctx.build.type5_name(self.type)
        return f"{type_name} can contain {self.literal!r}"
|
|
|
|
class UnifyTypesConstraint(ConstraintBase):
    """
    Constraint stating that two type expressions must unify.
    """
    __slots__ = ("lft", "rgt",)

    def __init__(self, ctx: Context, sourceref: SourceRef | None, lft: TypeExpr, rgt: TypeExpr, *, comment: str | None = None) -> None:
        """Store the left and right type expressions."""
        super().__init__(ctx, sourceref, comment)

        self.lft = lft
        self.rgt = rgt

    def check(self) -> CheckResult:
        """
        Unify both sides.

        A ``Failure`` from ``unify`` is reported as the sole failure; any other
        result is passed on as the actions of the check result.
        """
        outcome = unify(self.lft, self.rgt)

        return (
            CheckResult(failures=[outcome])
            if isinstance(outcome, Failure)
            else CheckResult(actions=outcome)
        )

    def replace_variable(self, var: TypeVariable, typ: TypeExpr) -> None:
        """Substitute *typ* for *var* on the left side, then on the right side."""
        self.lft = replace_variable(self.lft, var, typ)
        self.rgt = replace_variable(self.rgt, var, typ)

    def __str__(self) -> str:
        """Render as ``<left type name> ~ <right type name>``."""
        name_of = self.ctx.build.type5_name
        left_name = name_of(self.lft)
        right_name = name_of(self.rgt)
        return f"{left_name} ~ {right_name}"
|