Improved unification
This commit is contained in:
parent
6f3d9a5bcc
commit
b2816164f9
@ -132,20 +132,25 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
Compile: Any expression
|
||||
"""
|
||||
if isinstance(inp, ourlang.ConstantPrimitive):
|
||||
assert inp.type_var is not None
|
||||
|
||||
stp = typing.simplify(inp.type_var)
|
||||
if stp is None:
|
||||
raise NotImplementedError(f'Constants with type {inp.type_var}')
|
||||
|
||||
if stp == 'u8':
|
||||
# No native u8 type - treat as i32, with caution
|
||||
assert isinstance(inp.value, int)
|
||||
wgn.i32.const(inp.value)
|
||||
return
|
||||
|
||||
if stp in ('i32', 'u32'):
|
||||
assert isinstance(inp.value, int)
|
||||
wgn.i32.const(inp.value)
|
||||
return
|
||||
|
||||
if stp in ('i64', 'u64'):
|
||||
assert isinstance(inp.value, int)
|
||||
wgn.i64.const(inp.value)
|
||||
return
|
||||
|
||||
@ -321,7 +326,8 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
|
||||
"""
|
||||
Compile: Fold expression
|
||||
"""
|
||||
mtyp = LOAD_STORE_TYPE_MAP.get(inp.base.type.__class__)
|
||||
assert inp.base.type_var is not None
|
||||
mtyp = typing.simplify(inp.base.type_var)
|
||||
if mtyp is None:
|
||||
# In the future might extend this by having structs or tuples
|
||||
# as members of struct or tuples
|
||||
|
||||
@ -361,11 +361,12 @@ class ModuleConstantDef:
|
||||
"""
|
||||
A constant definition within a module
|
||||
"""
|
||||
__slots__ = ('name', 'lineno', 'type', 'constant', 'data_block', )
|
||||
__slots__ = ('name', 'lineno', 'type', 'type_var', 'constant', 'data_block', )
|
||||
|
||||
name: str
|
||||
lineno: int
|
||||
type: TypeBase
|
||||
type_var: Optional[TypeVar]
|
||||
constant: Constant
|
||||
data_block: Optional['ModuleDataBlock']
|
||||
|
||||
@ -373,6 +374,7 @@ class ModuleConstantDef:
|
||||
self.name = name
|
||||
self.lineno = lineno
|
||||
self.type = type_
|
||||
self.type_var = None
|
||||
self.constant = constant
|
||||
self.data_block = data_block
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
|
||||
return inp.variable.type_var
|
||||
|
||||
if isinstance(inp, ourlang.BinaryOp):
|
||||
if inp.operator not in ('+', '-', '|', '&', '^'):
|
||||
if inp.operator not in ('+', '-', '*', '|', '&', '^'):
|
||||
raise NotImplementedError(expression, inp, inp.operator)
|
||||
|
||||
left = expression(ctx, inp.left)
|
||||
@ -55,6 +55,10 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
|
||||
|
||||
return inp.function.returns_type_var
|
||||
|
||||
if isinstance(inp, ourlang.ModuleConstantReference):
|
||||
assert inp.definition.type_var is not None
|
||||
return inp.definition.type_var
|
||||
|
||||
raise NotImplementedError(expression, inp)
|
||||
|
||||
def function(ctx: 'Context', inp: ourlang.Function) -> None:
|
||||
@ -64,7 +68,11 @@ def function(ctx: 'Context', inp: ourlang.Function) -> None:
|
||||
|
||||
assert inp.returns_type_var is not None
|
||||
ctx.unify(inp.returns_type_var, typ)
|
||||
return
|
||||
|
||||
def module_constant_def(ctx: 'Context', inp: ourlang.ModuleConstantDef) -> None:
|
||||
inp.type_var = _convert_old_type(ctx, inp.type, inp.name)
|
||||
constant(ctx, inp.constant)
|
||||
ctx.unify(inp.type_var, inp.constant.type_var)
|
||||
|
||||
def module(inp: ourlang.Module) -> None:
|
||||
ctx = Context()
|
||||
@ -74,6 +82,9 @@ def module(inp: ourlang.Module) -> None:
|
||||
for param in func.posonlyargs:
|
||||
param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}')
|
||||
|
||||
for cdef in inp.constant_defs.values():
|
||||
module_constant_def(ctx, cdef)
|
||||
|
||||
for func in inp.functions.values():
|
||||
function(ctx, func)
|
||||
|
||||
|
||||
105
phasm/typing.py
105
phasm/typing.py
@ -301,62 +301,107 @@ class TypeConstraintBitWidth(TypeConstraintBase):
|
||||
return f'BitWidth={self.minb}..{self.maxb}'
|
||||
|
||||
class TypeVar:
|
||||
def __init__(self, ctx: 'Context') -> None:
|
||||
self.context = ctx
|
||||
self.constraints: Dict[Type[TypeConstraintBase], TypeConstraintBase] = {}
|
||||
self.locations: List[str] = []
|
||||
__slots__ = ('ctx', 'ctx_id', )
|
||||
|
||||
ctx: 'Context'
|
||||
ctx_id: int
|
||||
|
||||
def __init__(self, ctx: 'Context', ctx_id: int) -> None:
|
||||
self.ctx = ctx
|
||||
self.ctx_id = ctx_id
|
||||
|
||||
def add_constraint(self, newconst: TypeConstraintBase) -> None:
|
||||
if newconst.__class__ in self.constraints:
|
||||
self.constraints[newconst.__class__] = self.constraints[newconst.__class__].narrow(newconst)
|
||||
csts = self.ctx.var_constraints[self.ctx_id]
|
||||
|
||||
if newconst.__class__ in csts:
|
||||
csts[newconst.__class__] = csts[newconst.__class__].narrow(newconst)
|
||||
else:
|
||||
self.constraints[newconst.__class__] = newconst
|
||||
csts[newconst.__class__] = newconst
|
||||
|
||||
def add_location(self, ref: str) -> None:
|
||||
self.locations.append(ref)
|
||||
self.ctx.var_locations[self.ctx_id].append(ref)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
'TypeVar<'
|
||||
+ '; '.join(map(repr, self.constraints.values()))
|
||||
+ '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
|
||||
+ '; locations: '
|
||||
+ ', '.join(self.locations)
|
||||
+ ', '.join(self.ctx.var_locations[self.ctx_id])
|
||||
+ '>'
|
||||
)
|
||||
|
||||
class Context:
|
||||
def __init__(self) -> None:
|
||||
# Variables are unified (or entangled, if you will)
|
||||
# that means that each TypeVar within a context has an ID,
|
||||
# and all TypeVars with the same ID are the same TypeVar,
|
||||
# even if they are a different instance
|
||||
self.next_ctx_id = 1
|
||||
self.vars_by_id: Dict[int, List[TypeVar]] = {}
|
||||
|
||||
# Store the TypeVar properties as a lookup
|
||||
# so we can update these when unifying
|
||||
self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
|
||||
self.var_locations: Dict[int, List[str]] = {}
|
||||
|
||||
def new_var(self) -> TypeVar:
|
||||
return TypeVar(self)
|
||||
ctx_id = self.next_ctx_id
|
||||
self.next_ctx_id += 1
|
||||
|
||||
result = TypeVar(self, ctx_id)
|
||||
|
||||
self.vars_by_id[ctx_id] = [result]
|
||||
self.var_constraints[ctx_id] = {}
|
||||
self.var_locations[ctx_id] = []
|
||||
|
||||
return result
|
||||
|
||||
def unify(self, l: 'TypeVar', r: 'TypeVar') -> None:
|
||||
newtypevar = self.new_var()
|
||||
assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return
|
||||
|
||||
# Backup some values that we'll overwrite
|
||||
l_ctx_id = l.ctx_id
|
||||
r_ctx_id = r.ctx_id
|
||||
l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id]
|
||||
|
||||
# Create a new TypeVar, with the combined contraints
|
||||
# and locations of the old ones
|
||||
n = self.new_var()
|
||||
|
||||
try:
|
||||
for const in l.constraints.values():
|
||||
newtypevar.add_constraint(const)
|
||||
for const in r.constraints.values():
|
||||
newtypevar.add_constraint(const)
|
||||
except TypingNarrowProtoError as ex:
|
||||
raise TypingNarrowError(l, r, str(ex)) from None
|
||||
for const in self.var_constraints[l_ctx_id].values():
|
||||
n.add_constraint(const)
|
||||
for const in self.var_constraints[r_ctx_id].values():
|
||||
n.add_constraint(const)
|
||||
except TypingNarrowProtoError as exc:
|
||||
raise TypingNarrowError(l, r, str(exc)) from None
|
||||
|
||||
newtypevar.locations.extend(l.locations)
|
||||
newtypevar.locations.extend(r.locations)
|
||||
self.var_locations[n.ctx_id].extend(self.var_locations[l_ctx_id])
|
||||
self.var_locations[n.ctx_id].extend(self.var_locations[r_ctx_id])
|
||||
|
||||
# Make pointer locations to the constraints and locations
|
||||
# so they get linked together throughout the unification
|
||||
# ##
|
||||
# And unify (or entangle) the old ones
|
||||
|
||||
l.constraints = newtypevar.constraints
|
||||
l.locations = newtypevar.locations
|
||||
# First update the IDs, so they all point to the new list
|
||||
for type_var in l_r_var_list:
|
||||
type_var.ctx_id = n.ctx_id
|
||||
|
||||
r.constraints = newtypevar.constraints
|
||||
r.locations = newtypevar.locations
|
||||
# Update our registry of TypeVars by ID, so we can find them
|
||||
# on the next unify
|
||||
self.vars_by_id[n.ctx_id].extend(l_r_var_list)
|
||||
|
||||
return
|
||||
# Then delete the old values for the now gone variables
|
||||
# Do this last, so exceptions thrown in the code above
|
||||
# still have a valid context
|
||||
del self.var_constraints[l_ctx_id]
|
||||
del self.var_constraints[r_ctx_id]
|
||||
del self.var_locations[l_ctx_id]
|
||||
del self.var_locations[r_ctx_id]
|
||||
|
||||
def simplify(inp: TypeVar) -> Optional[str]:
|
||||
tc_prim = inp.constraints.get(TypeConstraintPrimitive)
|
||||
tc_bits = inp.constraints.get(TypeConstraintBitWidth)
|
||||
tc_sign = inp.constraints.get(TypeConstraintSigned)
|
||||
tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive)
|
||||
tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth)
|
||||
tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned)
|
||||
|
||||
if tc_prim is None:
|
||||
return None
|
||||
|
||||
2
pylintrc
2
pylintrc
@ -1,5 +1,5 @@
|
||||
[MASTER]
|
||||
disable=C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223
|
||||
disable=C0103,C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223
|
||||
|
||||
max-line-length=180
|
||||
|
||||
|
||||
@ -3,7 +3,21 @@ import pytest
|
||||
from .helpers import Suite
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_i32():
|
||||
def test_i32_asis():
|
||||
code_py = """
|
||||
CONSTANT: i32 = 13
|
||||
|
||||
@exported
|
||||
def testEntry() -> i32:
|
||||
return CONSTANT
|
||||
"""
|
||||
|
||||
result = Suite(code_py).run_code()
|
||||
|
||||
assert 13 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_i32_binop():
|
||||
code_py = """
|
||||
CONSTANT: i32 = 13
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user