Improved unification
This commit is contained in:
parent
6f3d9a5bcc
commit
b2816164f9
@ -132,20 +132,25 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
|||||||
Compile: Any expression
|
Compile: Any expression
|
||||||
"""
|
"""
|
||||||
if isinstance(inp, ourlang.ConstantPrimitive):
|
if isinstance(inp, ourlang.ConstantPrimitive):
|
||||||
|
assert inp.type_var is not None
|
||||||
|
|
||||||
stp = typing.simplify(inp.type_var)
|
stp = typing.simplify(inp.type_var)
|
||||||
if stp is None:
|
if stp is None:
|
||||||
raise NotImplementedError(f'Constants with type {inp.type_var}')
|
raise NotImplementedError(f'Constants with type {inp.type_var}')
|
||||||
|
|
||||||
if stp == 'u8':
|
if stp == 'u8':
|
||||||
# No native u8 type - treat as i32, with caution
|
# No native u8 type - treat as i32, with caution
|
||||||
|
assert isinstance(inp.value, int)
|
||||||
wgn.i32.const(inp.value)
|
wgn.i32.const(inp.value)
|
||||||
return
|
return
|
||||||
|
|
||||||
if stp in ('i32', 'u32'):
|
if stp in ('i32', 'u32'):
|
||||||
|
assert isinstance(inp.value, int)
|
||||||
wgn.i32.const(inp.value)
|
wgn.i32.const(inp.value)
|
||||||
return
|
return
|
||||||
|
|
||||||
if stp in ('i64', 'u64'):
|
if stp in ('i64', 'u64'):
|
||||||
|
assert isinstance(inp.value, int)
|
||||||
wgn.i64.const(inp.value)
|
wgn.i64.const(inp.value)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -321,7 +326,8 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
|
|||||||
"""
|
"""
|
||||||
Compile: Fold expression
|
Compile: Fold expression
|
||||||
"""
|
"""
|
||||||
mtyp = LOAD_STORE_TYPE_MAP.get(inp.base.type.__class__)
|
assert inp.base.type_var is not None
|
||||||
|
mtyp = typing.simplify(inp.base.type_var)
|
||||||
if mtyp is None:
|
if mtyp is None:
|
||||||
# In the future might extend this by having structs or tuples
|
# In the future might extend this by having structs or tuples
|
||||||
# as members of struct or tuples
|
# as members of struct or tuples
|
||||||
|
|||||||
@ -361,11 +361,12 @@ class ModuleConstantDef:
|
|||||||
"""
|
"""
|
||||||
A constant definition within a module
|
A constant definition within a module
|
||||||
"""
|
"""
|
||||||
__slots__ = ('name', 'lineno', 'type', 'constant', 'data_block', )
|
__slots__ = ('name', 'lineno', 'type', 'type_var', 'constant', 'data_block', )
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
lineno: int
|
lineno: int
|
||||||
type: TypeBase
|
type: TypeBase
|
||||||
|
type_var: Optional[TypeVar]
|
||||||
constant: Constant
|
constant: Constant
|
||||||
data_block: Optional['ModuleDataBlock']
|
data_block: Optional['ModuleDataBlock']
|
||||||
|
|
||||||
@ -373,6 +374,7 @@ class ModuleConstantDef:
|
|||||||
self.name = name
|
self.name = name
|
||||||
self.lineno = lineno
|
self.lineno = lineno
|
||||||
self.type = type_
|
self.type = type_
|
||||||
|
self.type_var = None
|
||||||
self.constant = constant
|
self.constant = constant
|
||||||
self.data_block = data_block
|
self.data_block = data_block
|
||||||
|
|
||||||
|
|||||||
@ -40,7 +40,7 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
|
|||||||
return inp.variable.type_var
|
return inp.variable.type_var
|
||||||
|
|
||||||
if isinstance(inp, ourlang.BinaryOp):
|
if isinstance(inp, ourlang.BinaryOp):
|
||||||
if inp.operator not in ('+', '-', '|', '&', '^'):
|
if inp.operator not in ('+', '-', '*', '|', '&', '^'):
|
||||||
raise NotImplementedError(expression, inp, inp.operator)
|
raise NotImplementedError(expression, inp, inp.operator)
|
||||||
|
|
||||||
left = expression(ctx, inp.left)
|
left = expression(ctx, inp.left)
|
||||||
@ -55,6 +55,10 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
|
|||||||
|
|
||||||
return inp.function.returns_type_var
|
return inp.function.returns_type_var
|
||||||
|
|
||||||
|
if isinstance(inp, ourlang.ModuleConstantReference):
|
||||||
|
assert inp.definition.type_var is not None
|
||||||
|
return inp.definition.type_var
|
||||||
|
|
||||||
raise NotImplementedError(expression, inp)
|
raise NotImplementedError(expression, inp)
|
||||||
|
|
||||||
def function(ctx: 'Context', inp: ourlang.Function) -> None:
|
def function(ctx: 'Context', inp: ourlang.Function) -> None:
|
||||||
@ -64,7 +68,11 @@ def function(ctx: 'Context', inp: ourlang.Function) -> None:
|
|||||||
|
|
||||||
assert inp.returns_type_var is not None
|
assert inp.returns_type_var is not None
|
||||||
ctx.unify(inp.returns_type_var, typ)
|
ctx.unify(inp.returns_type_var, typ)
|
||||||
return
|
|
||||||
|
def module_constant_def(ctx: 'Context', inp: ourlang.ModuleConstantDef) -> None:
|
||||||
|
inp.type_var = _convert_old_type(ctx, inp.type, inp.name)
|
||||||
|
constant(ctx, inp.constant)
|
||||||
|
ctx.unify(inp.type_var, inp.constant.type_var)
|
||||||
|
|
||||||
def module(inp: ourlang.Module) -> None:
|
def module(inp: ourlang.Module) -> None:
|
||||||
ctx = Context()
|
ctx = Context()
|
||||||
@ -74,6 +82,9 @@ def module(inp: ourlang.Module) -> None:
|
|||||||
for param in func.posonlyargs:
|
for param in func.posonlyargs:
|
||||||
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}')
|
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}')
|
||||||
|
|
||||||
|
for cdef in inp.constant_defs.values():
|
||||||
|
module_constant_def(ctx, cdef)
|
||||||
|
|
||||||
for func in inp.functions.values():
|
for func in inp.functions.values():
|
||||||
function(ctx, func)
|
function(ctx, func)
|
||||||
|
|
||||||
|
|||||||
105
phasm/typing.py
105
phasm/typing.py
@ -301,62 +301,107 @@ class TypeConstraintBitWidth(TypeConstraintBase):
|
|||||||
return f'BitWidth={self.minb}..{self.maxb}'
|
return f'BitWidth={self.minb}..{self.maxb}'
|
||||||
|
|
||||||
class TypeVar:
|
class TypeVar:
|
||||||
def __init__(self, ctx: 'Context') -> None:
|
__slots__ = ('ctx', 'ctx_id', )
|
||||||
self.context = ctx
|
|
||||||
self.constraints: Dict[Type[TypeConstraintBase], TypeConstraintBase] = {}
|
ctx: 'Context'
|
||||||
self.locations: List[str] = []
|
ctx_id: int
|
||||||
|
|
||||||
|
def __init__(self, ctx: 'Context', ctx_id: int) -> None:
|
||||||
|
self.ctx = ctx
|
||||||
|
self.ctx_id = ctx_id
|
||||||
|
|
||||||
def add_constraint(self, newconst: TypeConstraintBase) -> None:
|
def add_constraint(self, newconst: TypeConstraintBase) -> None:
|
||||||
if newconst.__class__ in self.constraints:
|
csts = self.ctx.var_constraints[self.ctx_id]
|
||||||
self.constraints[newconst.__class__] = self.constraints[newconst.__class__].narrow(newconst)
|
|
||||||
|
if newconst.__class__ in csts:
|
||||||
|
csts[newconst.__class__] = csts[newconst.__class__].narrow(newconst)
|
||||||
else:
|
else:
|
||||||
self.constraints[newconst.__class__] = newconst
|
csts[newconst.__class__] = newconst
|
||||||
|
|
||||||
def add_location(self, ref: str) -> None:
|
def add_location(self, ref: str) -> None:
|
||||||
self.locations.append(ref)
|
self.ctx.var_locations[self.ctx_id].append(ref)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return (
|
return (
|
||||||
'TypeVar<'
|
'TypeVar<'
|
||||||
+ '; '.join(map(repr, self.constraints.values()))
|
+ '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
|
||||||
+ '; locations: '
|
+ '; locations: '
|
||||||
+ ', '.join(self.locations)
|
+ ', '.join(self.ctx.var_locations[self.ctx_id])
|
||||||
+ '>'
|
+ '>'
|
||||||
)
|
)
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
# Variables are unified (or entangled, if you will)
|
||||||
|
# that means that each TypeVar within a context has an ID,
|
||||||
|
# and all TypeVars with the same ID are the same TypeVar,
|
||||||
|
# even if they are a different instance
|
||||||
|
self.next_ctx_id = 1
|
||||||
|
self.vars_by_id: Dict[int, List[TypeVar]] = {}
|
||||||
|
|
||||||
|
# Store the TypeVar properties as a lookup
|
||||||
|
# so we can update these when unifying
|
||||||
|
self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
|
||||||
|
self.var_locations: Dict[int, List[str]] = {}
|
||||||
|
|
||||||
def new_var(self) -> TypeVar:
|
def new_var(self) -> TypeVar:
|
||||||
return TypeVar(self)
|
ctx_id = self.next_ctx_id
|
||||||
|
self.next_ctx_id += 1
|
||||||
|
|
||||||
|
result = TypeVar(self, ctx_id)
|
||||||
|
|
||||||
|
self.vars_by_id[ctx_id] = [result]
|
||||||
|
self.var_constraints[ctx_id] = {}
|
||||||
|
self.var_locations[ctx_id] = []
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def unify(self, l: 'TypeVar', r: 'TypeVar') -> None:
|
def unify(self, l: 'TypeVar', r: 'TypeVar') -> None:
|
||||||
newtypevar = self.new_var()
|
assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return
|
||||||
|
|
||||||
|
# Backup some values that we'll overwrite
|
||||||
|
l_ctx_id = l.ctx_id
|
||||||
|
r_ctx_id = r.ctx_id
|
||||||
|
l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id]
|
||||||
|
|
||||||
|
# Create a new TypeVar, with the combined contraints
|
||||||
|
# and locations of the old ones
|
||||||
|
n = self.new_var()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for const in l.constraints.values():
|
for const in self.var_constraints[l_ctx_id].values():
|
||||||
newtypevar.add_constraint(const)
|
n.add_constraint(const)
|
||||||
for const in r.constraints.values():
|
for const in self.var_constraints[r_ctx_id].values():
|
||||||
newtypevar.add_constraint(const)
|
n.add_constraint(const)
|
||||||
except TypingNarrowProtoError as ex:
|
except TypingNarrowProtoError as exc:
|
||||||
raise TypingNarrowError(l, r, str(ex)) from None
|
raise TypingNarrowError(l, r, str(exc)) from None
|
||||||
|
|
||||||
newtypevar.locations.extend(l.locations)
|
self.var_locations[n.ctx_id].extend(self.var_locations[l_ctx_id])
|
||||||
newtypevar.locations.extend(r.locations)
|
self.var_locations[n.ctx_id].extend(self.var_locations[r_ctx_id])
|
||||||
|
|
||||||
# Make pointer locations to the constraints and locations
|
# ##
|
||||||
# so they get linked together throughout the unification
|
# And unify (or entangle) the old ones
|
||||||
|
|
||||||
l.constraints = newtypevar.constraints
|
# First update the IDs, so they all point to the new list
|
||||||
l.locations = newtypevar.locations
|
for type_var in l_r_var_list:
|
||||||
|
type_var.ctx_id = n.ctx_id
|
||||||
|
|
||||||
r.constraints = newtypevar.constraints
|
# Update our registry of TypeVars by ID, so we can find them
|
||||||
r.locations = newtypevar.locations
|
# on the next unify
|
||||||
|
self.vars_by_id[n.ctx_id].extend(l_r_var_list)
|
||||||
|
|
||||||
return
|
# Then delete the old values for the now gone variables
|
||||||
|
# Do this last, so exceptions thrown in the code above
|
||||||
|
# still have a valid context
|
||||||
|
del self.var_constraints[l_ctx_id]
|
||||||
|
del self.var_constraints[r_ctx_id]
|
||||||
|
del self.var_locations[l_ctx_id]
|
||||||
|
del self.var_locations[r_ctx_id]
|
||||||
|
|
||||||
def simplify(inp: TypeVar) -> Optional[str]:
|
def simplify(inp: TypeVar) -> Optional[str]:
|
||||||
tc_prim = inp.constraints.get(TypeConstraintPrimitive)
|
tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive)
|
||||||
tc_bits = inp.constraints.get(TypeConstraintBitWidth)
|
tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth)
|
||||||
tc_sign = inp.constraints.get(TypeConstraintSigned)
|
tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned)
|
||||||
|
|
||||||
if tc_prim is None:
|
if tc_prim is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
2
pylintrc
2
pylintrc
@ -1,5 +1,5 @@
|
|||||||
[MASTER]
|
[MASTER]
|
||||||
disable=C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223
|
disable=C0103,C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223
|
||||||
|
|
||||||
max-line-length=180
|
max-line-length=180
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,21 @@ import pytest
|
|||||||
from .helpers import Suite
|
from .helpers import Suite
|
||||||
|
|
||||||
@pytest.mark.integration_test
|
@pytest.mark.integration_test
|
||||||
def test_i32():
|
def test_i32_asis():
|
||||||
|
code_py = """
|
||||||
|
CONSTANT: i32 = 13
|
||||||
|
|
||||||
|
@exported
|
||||||
|
def testEntry() -> i32:
|
||||||
|
return CONSTANT
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = Suite(code_py).run_code()
|
||||||
|
|
||||||
|
assert 13 == result.returned_value
|
||||||
|
|
||||||
|
@pytest.mark.integration_test
|
||||||
|
def test_i32_binop():
|
||||||
code_py = """
|
code_py = """
|
||||||
CONSTANT: i32 = 13
|
CONSTANT: i32 = 13
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user