Improved unification

This commit is contained in:
Johan B.W. de Vries 2022-09-17 17:14:17 +02:00
parent 6f3d9a5bcc
commit b2816164f9
6 changed files with 114 additions and 36 deletions

View File

@ -132,20 +132,25 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
Compile: Any expression Compile: Any expression
""" """
if isinstance(inp, ourlang.ConstantPrimitive): if isinstance(inp, ourlang.ConstantPrimitive):
assert inp.type_var is not None
stp = typing.simplify(inp.type_var) stp = typing.simplify(inp.type_var)
if stp is None: if stp is None:
raise NotImplementedError(f'Constants with type {inp.type_var}') raise NotImplementedError(f'Constants with type {inp.type_var}')
if stp == 'u8': if stp == 'u8':
# No native u8 type - treat as i32, with caution # No native u8 type - treat as i32, with caution
assert isinstance(inp.value, int)
wgn.i32.const(inp.value) wgn.i32.const(inp.value)
return return
if stp in ('i32', 'u32'): if stp in ('i32', 'u32'):
assert isinstance(inp.value, int)
wgn.i32.const(inp.value) wgn.i32.const(inp.value)
return return
if stp in ('i64', 'u64'): if stp in ('i64', 'u64'):
assert isinstance(inp.value, int)
wgn.i64.const(inp.value) wgn.i64.const(inp.value)
return return
@ -321,7 +326,8 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
""" """
Compile: Fold expression Compile: Fold expression
""" """
mtyp = LOAD_STORE_TYPE_MAP.get(inp.base.type.__class__) assert inp.base.type_var is not None
mtyp = typing.simplify(inp.base.type_var)
if mtyp is None: if mtyp is None:
# In the future might extend this by having structs or tuples # In the future might extend this by having structs or tuples
# as members of struct or tuples # as members of struct or tuples

View File

@ -361,11 +361,12 @@ class ModuleConstantDef:
""" """
A constant definition within a module A constant definition within a module
""" """
__slots__ = ('name', 'lineno', 'type', 'constant', 'data_block', ) __slots__ = ('name', 'lineno', 'type', 'type_var', 'constant', 'data_block', )
name: str name: str
lineno: int lineno: int
type: TypeBase type: TypeBase
type_var: Optional[TypeVar]
constant: Constant constant: Constant
data_block: Optional['ModuleDataBlock'] data_block: Optional['ModuleDataBlock']
@ -373,6 +374,7 @@ class ModuleConstantDef:
self.name = name self.name = name
self.lineno = lineno self.lineno = lineno
self.type = type_ self.type = type_
self.type_var = None
self.constant = constant self.constant = constant
self.data_block = data_block self.data_block = data_block

View File

@ -40,7 +40,7 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
return inp.variable.type_var return inp.variable.type_var
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
if inp.operator not in ('+', '-', '|', '&', '^'): if inp.operator not in ('+', '-', '*', '|', '&', '^'):
raise NotImplementedError(expression, inp, inp.operator) raise NotImplementedError(expression, inp, inp.operator)
left = expression(ctx, inp.left) left = expression(ctx, inp.left)
@ -55,6 +55,10 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
return inp.function.returns_type_var return inp.function.returns_type_var
if isinstance(inp, ourlang.ModuleConstantReference):
assert inp.definition.type_var is not None
return inp.definition.type_var
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def function(ctx: 'Context', inp: ourlang.Function) -> None: def function(ctx: 'Context', inp: ourlang.Function) -> None:
@ -64,7 +68,11 @@ def function(ctx: 'Context', inp: ourlang.Function) -> None:
assert inp.returns_type_var is not None assert inp.returns_type_var is not None
ctx.unify(inp.returns_type_var, typ) ctx.unify(inp.returns_type_var, typ)
return
def module_constant_def(ctx: 'Context', inp: ourlang.ModuleConstantDef) -> None:
inp.type_var = _convert_old_type(ctx, inp.type, inp.name)
constant(ctx, inp.constant)
ctx.unify(inp.type_var, inp.constant.type_var)
def module(inp: ourlang.Module) -> None: def module(inp: ourlang.Module) -> None:
ctx = Context() ctx = Context()
@ -74,6 +82,9 @@ def module(inp: ourlang.Module) -> None:
for param in func.posonlyargs: for param in func.posonlyargs:
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}') param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}')
for cdef in inp.constant_defs.values():
module_constant_def(ctx, cdef)
for func in inp.functions.values(): for func in inp.functions.values():
function(ctx, func) function(ctx, func)

View File

@ -301,62 +301,107 @@ class TypeConstraintBitWidth(TypeConstraintBase):
return f'BitWidth={self.minb}..{self.maxb}' return f'BitWidth={self.minb}..{self.maxb}'
class TypeVar: class TypeVar:
def __init__(self, ctx: 'Context') -> None: __slots__ = ('ctx', 'ctx_id', )
self.context = ctx
self.constraints: Dict[Type[TypeConstraintBase], TypeConstraintBase] = {} ctx: 'Context'
self.locations: List[str] = [] ctx_id: int
def __init__(self, ctx: 'Context', ctx_id: int) -> None:
self.ctx = ctx
self.ctx_id = ctx_id
def add_constraint(self, newconst: TypeConstraintBase) -> None: def add_constraint(self, newconst: TypeConstraintBase) -> None:
if newconst.__class__ in self.constraints: csts = self.ctx.var_constraints[self.ctx_id]
self.constraints[newconst.__class__] = self.constraints[newconst.__class__].narrow(newconst)
if newconst.__class__ in csts:
csts[newconst.__class__] = csts[newconst.__class__].narrow(newconst)
else: else:
self.constraints[newconst.__class__] = newconst csts[newconst.__class__] = newconst
def add_location(self, ref: str) -> None: def add_location(self, ref: str) -> None:
self.locations.append(ref) self.ctx.var_locations[self.ctx_id].append(ref)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
'TypeVar<' 'TypeVar<'
+ '; '.join(map(repr, self.constraints.values())) + '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
+ '; locations: ' + '; locations: '
+ ', '.join(self.locations) + ', '.join(self.ctx.var_locations[self.ctx_id])
+ '>' + '>'
) )
class Context: class Context:
def __init__(self) -> None:
# Variables are unified (or entangled, if you will)
# that means that each TypeVar within a context has an ID,
# and all TypeVars with the same ID are the same TypeVar,
# even if they are a different instance
self.next_ctx_id = 1
self.vars_by_id: Dict[int, List[TypeVar]] = {}
# Store the TypeVar properties as a lookup
# so we can update these when unifying
self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
self.var_locations: Dict[int, List[str]] = {}
def new_var(self) -> TypeVar: def new_var(self) -> TypeVar:
return TypeVar(self) ctx_id = self.next_ctx_id
self.next_ctx_id += 1
result = TypeVar(self, ctx_id)
self.vars_by_id[ctx_id] = [result]
self.var_constraints[ctx_id] = {}
self.var_locations[ctx_id] = []
return result
def unify(self, l: 'TypeVar', r: 'TypeVar') -> None: def unify(self, l: 'TypeVar', r: 'TypeVar') -> None:
newtypevar = self.new_var() assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return
# Backup some values that we'll overwrite
l_ctx_id = l.ctx_id
r_ctx_id = r.ctx_id
l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id]
# Create a new TypeVar, with the combined contraints
# and locations of the old ones
n = self.new_var()
try: try:
for const in l.constraints.values(): for const in self.var_constraints[l_ctx_id].values():
newtypevar.add_constraint(const) n.add_constraint(const)
for const in r.constraints.values(): for const in self.var_constraints[r_ctx_id].values():
newtypevar.add_constraint(const) n.add_constraint(const)
except TypingNarrowProtoError as ex: except TypingNarrowProtoError as exc:
raise TypingNarrowError(l, r, str(ex)) from None raise TypingNarrowError(l, r, str(exc)) from None
newtypevar.locations.extend(l.locations) self.var_locations[n.ctx_id].extend(self.var_locations[l_ctx_id])
newtypevar.locations.extend(r.locations) self.var_locations[n.ctx_id].extend(self.var_locations[r_ctx_id])
# Make pointer locations to the constraints and locations # ##
# so they get linked together throughout the unification # And unify (or entangle) the old ones
l.constraints = newtypevar.constraints # First update the IDs, so they all point to the new list
l.locations = newtypevar.locations for type_var in l_r_var_list:
type_var.ctx_id = n.ctx_id
r.constraints = newtypevar.constraints # Update our registry of TypeVars by ID, so we can find them
r.locations = newtypevar.locations # on the next unify
self.vars_by_id[n.ctx_id].extend(l_r_var_list)
return # Then delete the old values for the now gone variables
# Do this last, so exceptions thrown in the code above
# still have a valid context
del self.var_constraints[l_ctx_id]
del self.var_constraints[r_ctx_id]
del self.var_locations[l_ctx_id]
del self.var_locations[r_ctx_id]
def simplify(inp: TypeVar) -> Optional[str]: def simplify(inp: TypeVar) -> Optional[str]:
tc_prim = inp.constraints.get(TypeConstraintPrimitive) tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive)
tc_bits = inp.constraints.get(TypeConstraintBitWidth) tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth)
tc_sign = inp.constraints.get(TypeConstraintSigned) tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned)
if tc_prim is None: if tc_prim is None:
return None return None

View File

@ -1,5 +1,5 @@
[MASTER] [MASTER]
disable=C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 disable=C0103,C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223
max-line-length=180 max-line-length=180

View File

@ -3,7 +3,21 @@ import pytest
from .helpers import Suite from .helpers import Suite
@pytest.mark.integration_test @pytest.mark.integration_test
def test_i32(): def test_i32_asis():
code_py = """
CONSTANT: i32 = 13
@exported
def testEntry() -> i32:
return CONSTANT
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
@pytest.mark.integration_test
def test_i32_binop():
code_py = """ code_py = """
CONSTANT: i32 = 13 CONSTANT: i32 = 13