From b2816164f98b3be5ba1157a1971823d98968b328 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 17 Sep 2022 17:14:17 +0200 Subject: [PATCH] Improved unification --- phasm/compiler.py | 8 ++- phasm/ourlang.py | 4 +- phasm/typer.py | 15 +++- phasm/typing.py | 105 ++++++++++++++++++++-------- pylintrc | 2 +- tests/integration/test_constants.py | 16 ++++- 6 files changed, 114 insertions(+), 36 deletions(-) diff --git a/phasm/compiler.py b/phasm/compiler.py index 6505982..e3282bc 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -132,20 +132,25 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: Compile: Any expression """ if isinstance(inp, ourlang.ConstantPrimitive): + assert inp.type_var is not None + stp = typing.simplify(inp.type_var) if stp is None: raise NotImplementedError(f'Constants with type {inp.type_var}') if stp == 'u8': # No native u8 type - treat as i32, with caution + assert isinstance(inp.value, int) wgn.i32.const(inp.value) return if stp in ('i32', 'u32'): + assert isinstance(inp.value, int) wgn.i32.const(inp.value) return if stp in ('i64', 'u64'): + assert isinstance(inp.value, int) wgn.i64.const(inp.value) return @@ -321,7 +326,8 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: """ Compile: Fold expression """ - mtyp = LOAD_STORE_TYPE_MAP.get(inp.base.type.__class__) + assert inp.base.type_var is not None + mtyp = typing.simplify(inp.base.type_var) if mtyp is None: # In the future might extend this by having structs or tuples # as members of struct or tuples diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 6496a2e..30a2527 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -361,11 +361,12 @@ class ModuleConstantDef: """ A constant definition within a module """ - __slots__ = ('name', 'lineno', 'type', 'constant', 'data_block', ) + __slots__ = ('name', 'lineno', 'type', 'type_var', 'constant', 'data_block', ) name: str lineno: int type: TypeBase + type_var: Optional[TypeVar] constant: Constant data_block: Optional['ModuleDataBlock'] @@ -373,6 +374,7 @@ class ModuleConstantDef: self.name = name self.lineno = lineno self.type = type_ + self.type_var = None self.constant = constant self.data_block = data_block diff --git a/phasm/typer.py b/phasm/typer.py index e835c96..169a774 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -40,7 +40,7 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': return inp.variable.type_var if isinstance(inp, ourlang.BinaryOp): - if inp.operator not in ('+', '-', '|', '&', '^'): + if inp.operator not in ('+', '-', '*', '|', '&', '^'): raise NotImplementedError(expression, inp, inp.operator) left = expression(ctx, inp.left) @@ -55,6 +55,10 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': return inp.function.returns_type_var + if isinstance(inp, ourlang.ModuleConstantReference): + assert inp.definition.type_var is not None + return inp.definition.type_var + raise NotImplementedError(expression, inp) def function(ctx: 'Context', inp: ourlang.Function) -> None: @@ -64,7 +68,11 @@ def function(ctx: 'Context', inp: ourlang.Function) -> None: assert inp.returns_type_var is not None ctx.unify(inp.returns_type_var, typ) - return + +def module_constant_def(ctx: 'Context', inp: ourlang.ModuleConstantDef) -> None: + inp.type_var = _convert_old_type(ctx, inp.type, inp.name) + constant(ctx, inp.constant) + ctx.unify(inp.type_var, inp.constant.type_var) def module(inp: ourlang.Module) -> None: ctx = Context() @@ -74,6 +82,9 @@ def module(inp: ourlang.Module) -> None: for param in func.posonlyargs: param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}') + for cdef in inp.constant_defs.values(): + module_constant_def(ctx, cdef) + for func in inp.functions.values(): function(ctx, func) diff --git a/phasm/typing.py b/phasm/typing.py index 72a5827..ffd8e16 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -301,62 +301,107 @@ class TypeConstraintBitWidth(TypeConstraintBase): return f'BitWidth={self.minb}..{self.maxb}' class TypeVar: - def __init__(self, ctx: 'Context') -> None: - self.context = ctx - self.constraints: Dict[Type[TypeConstraintBase], TypeConstraintBase] = {} - self.locations: List[str] = [] + __slots__ = ('ctx', 'ctx_id', ) + + ctx: 'Context' + ctx_id: int + + def __init__(self, ctx: 'Context', ctx_id: int) -> None: + self.ctx = ctx + self.ctx_id = ctx_id def add_constraint(self, newconst: TypeConstraintBase) -> None: - if newconst.__class__ in self.constraints: - self.constraints[newconst.__class__] = self.constraints[newconst.__class__].narrow(newconst) + csts = self.ctx.var_constraints[self.ctx_id] + + if newconst.__class__ in csts: + csts[newconst.__class__] = csts[newconst.__class__].narrow(newconst) else: - self.constraints[newconst.__class__] = newconst + csts[newconst.__class__] = newconst def add_location(self, ref: str) -> None: - self.locations.append(ref) + self.ctx.var_locations[self.ctx_id].append(ref) def __repr__(self) -> str: return ( 'TypeVar<' - + '; '.join(map(repr, self.constraints.values())) + + '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values())) + '; locations: ' - + ', '.join(self.locations) + + ', '.join(self.ctx.var_locations[self.ctx_id]) + '>' ) class Context: + def __init__(self) -> None: + # Variables are unified (or entangled, if you will) + # that means that each TypeVar within a context has an ID, + # and all TypeVars with the same ID are the same TypeVar, + # even if they are a different instance + self.next_ctx_id = 1 + self.vars_by_id: Dict[int, List[TypeVar]] = {} + + # Store the TypeVar properties as a lookup + # so we can update these when unifying + self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {} + self.var_locations: Dict[int, List[str]] = {} + def new_var(self) -> TypeVar: - return TypeVar(self) + ctx_id = self.next_ctx_id + self.next_ctx_id += 1 + + result = TypeVar(self, ctx_id) + + self.vars_by_id[ctx_id] = [result] + self.var_constraints[ctx_id] = {} + self.var_locations[ctx_id] = [] + + return result def unify(self, l: 'TypeVar', r: 'TypeVar') -> None: - newtypevar = self.new_var() + assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return + + # Backup some values that we'll overwrite + l_ctx_id = l.ctx_id + r_ctx_id = r.ctx_id + l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id] + + # Create a new TypeVar, with the combined contraints + # and locations of the old ones + n = self.new_var() try: - for const in l.constraints.values(): - newtypevar.add_constraint(const) - for const in r.constraints.values(): - newtypevar.add_constraint(const) - except TypingNarrowProtoError as ex: - raise TypingNarrowError(l, r, str(ex)) from None + for const in self.var_constraints[l_ctx_id].values(): + n.add_constraint(const) + for const in self.var_constraints[r_ctx_id].values(): + n.add_constraint(const) + except TypingNarrowProtoError as exc: + raise TypingNarrowError(l, r, str(exc)) from None - newtypevar.locations.extend(l.locations) - newtypevar.locations.extend(r.locations) + self.var_locations[n.ctx_id].extend(self.var_locations[l_ctx_id]) + self.var_locations[n.ctx_id].extend(self.var_locations[r_ctx_id]) - # Make pointer locations to the constraints and locations - # so they get linked together throughout the unification + # ## + # And unify (or entangle) the old ones - l.constraints = newtypevar.constraints - l.locations = newtypevar.locations + # First update the IDs, so they all point to the new list + for type_var in l_r_var_list: + type_var.ctx_id = n.ctx_id - r.constraints = newtypevar.constraints - r.locations = newtypevar.locations + # Update our registry of TypeVars by ID, so we can find them + # on the next unify + self.vars_by_id[n.ctx_id].extend(l_r_var_list) - return + # Then delete the old values for the now gone variables + # Do this last, so exceptions thrown in the code above + # still have a valid context + del self.var_constraints[l_ctx_id] + del self.var_constraints[r_ctx_id] + del self.var_locations[l_ctx_id] + del self.var_locations[r_ctx_id] def simplify(inp: TypeVar) -> Optional[str]: - tc_prim = inp.constraints.get(TypeConstraintPrimitive) - tc_bits = inp.constraints.get(TypeConstraintBitWidth) - tc_sign = inp.constraints.get(TypeConstraintSigned) + tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive) + tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth) + tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned) if tc_prim is None: return None diff --git a/pylintrc b/pylintrc index 0591be3..3759e6b 100644 --- a/pylintrc +++ b/pylintrc @@ -1,5 +1,5 @@ [MASTER] -disable=C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 +disable=C0103,C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 max-line-length=180 diff --git a/tests/integration/test_constants.py b/tests/integration/test_constants.py index 19f0203..accf9e2 100644 --- a/tests/integration/test_constants.py +++ b/tests/integration/test_constants.py @@ -3,7 +3,21 @@ import pytest from .helpers import Suite @pytest.mark.integration_test -def test_i32(): +def test_i32_asis(): + code_py = """ +CONSTANT: i32 = 13 + +@exported +def testEntry() -> i32: + return CONSTANT +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + +@pytest.mark.integration_test +def test_i32_binop(): code_py = """ CONSTANT: i32 = 13