This commit is contained in:
Johan B.W. de Vries 2023-01-07 16:20:33 +01:00
parent 0fa076bb93
commit 14b8065974
9 changed files with 136 additions and 81 deletions

View File

@ -31,13 +31,15 @@ def type3(inp: Type3OrPlaceholder) -> str:
return '(' + ', '.join( return '(' + ', '.join(
type3(x) type3(x)
for x in inp.args for x in inp.args
if isinstance(x, Type3) # Skip ints, not allowed here anyhow
) + ', )' ) + ', )'
if inp.base == type3types.static_array: if inp.base == type3types.static_array:
assert 1 == len(inp.args) assert 2 == len(inp.args)
assert isinstance(inp.args[0], Type3), TYPE3_ASSERTION_ERROR assert isinstance(inp.args[0], Type3), TYPE3_ASSERTION_ERROR
assert isinstance(inp.args[1], type3types.IntType3), TYPE3_ASSERTION_ERROR
return inp.args[0].name + '[3]' # FIXME: Where to store this value? return inp.args[0].name + '[' + inp.args[1].name + ']'
return inp.name return inp.name

View File

@ -386,16 +386,28 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if isinstance(inp.varref.type3, type3types.AppliedType3): if isinstance(inp.varref.type3, type3types.AppliedType3):
if inp.varref.type3.base == type3types.static_array: if inp.varref.type3.base == type3types.static_array:
assert 1 == len(inp.varref.type3.args) assert 2 == len(inp.varref.type3.args)
el_type = inp.varref.type3.args[0] el_type = inp.varref.type3.args[0]
assert isinstance(el_type, type3types.Type3) assert isinstance(el_type, type3types.Type3)
el_len = inp.varref.type3.args[1]
assert isinstance(el_len, type3types.IntType3)
# OPTIMIZE: If index is a constant, we can use offset instead of multiply # OPTIMIZE: If index is a constant, we can use offset instead of multiply
# and we don't need to do the out of bounds check
# FIXME: Out of bounds check
expression(wgn, inp.varref) expression(wgn, inp.varref)
tmp_var = wgn.temp_var_i32('index')
expression(wgn, inp.index) expression(wgn, inp.index)
wgn.local.tee(tmp_var)
# Out of bounds check based on el_len.value
wgn.i32.const(el_len.value)
wgn.i32.ge_u()
with wgn.if_():
wgn.unreachable(comment='Out of bounds')
wgn.local.get(tmp_var)
wgn.i32.const(_calculate_alloc_size(el_type)) wgn.i32.const(_calculate_alloc_size(el_type))
wgn.i32.mul() wgn.i32.mul()
wgn.i32.add() wgn.i32.add()
@ -804,9 +816,12 @@ def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3])
if typ.base is type3types.tuple: if typ.base is type3types.tuple:
size = 0 size = 0
for arg in typ.args: for arg in typ.args:
assert not isinstance(arg, type3types.IntType3)
if isinstance(arg, type3types.PlaceholderForType): if isinstance(arg, type3types.PlaceholderForType):
assert not arg.resolve_as is None assert not arg.resolve_as is None
arg = arg.resolve_as arg = arg.resolve_as
size += _calculate_alloc_size(arg) size += _calculate_alloc_size(arg)
return size return size

View File

@ -583,7 +583,7 @@ class OurVisitor:
if not isinstance(node.value, ast.Name): if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must be name') _raise_static_error(node, 'Must be name')
if isinstance(node.slice, ast.Slice): if isinstance(node.slice, ast.Slice):
_raise_static_error(node, 'Must subscript using an index') # FIXME: Do we use type level length? _raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.slice, ast.Constant): if not isinstance(node.slice, ast.Constant):
_raise_static_error(node, 'Must subscript using a constant index') _raise_static_error(node, 'Must subscript using a constant index')
if not isinstance(node.slice.value, int): if not isinstance(node.slice.value, int):
@ -596,7 +596,7 @@ class OurVisitor:
return type3types.AppliedType3( return type3types.AppliedType3(
type3types.static_array, type3types.static_array,
[self.visit_type(module, node.value)], [self.visit_type(module, node.value), type3types.IntType3(node.slice.value)],
) )
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):

View File

@ -77,7 +77,7 @@ class ConstraintBase:
This function can return None, if the constraint holds, but no new This function can return None, if the constraint holds, but no new
information was deduced from evaluating this constraint. information was deduced from evaluating this constraint.
""" """
raise NotImplementedError raise NotImplementedError(self.__class__, self.check)
def human_readable(self) -> HumanReadableRet: def human_readable(self) -> HumanReadableRet:
""" """
@ -104,6 +104,10 @@ class SameTypeConstraint(ConstraintBase):
placeholders = [] placeholders = []
do_applied_placeholder_check: bool = False do_applied_placeholder_check: bool = False
for typ in self.type_list: for typ in self.type_list:
if isinstance(typ, types.IntType3):
known_types.append(typ)
continue
if isinstance(typ, (types.PrimitiveType3, types.StructType3, )): if isinstance(typ, (types.PrimitiveType3, types.StructType3, )):
known_types.append(typ) known_types.append(typ)
continue continue
@ -177,6 +181,46 @@ class SameTypeConstraint(ConstraintBase):
return f'SameTypeConstraint({args}, comment={repr(self.comment)})' return f'SameTypeConstraint({args}, comment={repr(self.comment)})'
class IntegerCompareConstraint(ConstraintBase):
"""
Verifies that the given IntType3 are in order (<=)
"""
__slots__ = ('int_type3_list', )
int_type3_list: List[types.IntType3]
def __init__(self, *int_type3: types.IntType3, comment: Optional[str] = None) -> None:
super().__init__(comment=comment)
assert len(int_type3) > 1
self.int_type3_list = [*int_type3]
def check(self) -> CheckResult:
val_list = [x.value for x in self.int_type3_list]
prev_val = val_list.pop(0)
for next_val in val_list:
if prev_val > next_val:
return Error(f'{prev_val} must be less or equal than {next_val}')
prev_val = next_val
return None
def human_readable(self) -> HumanReadableRet:
return (
' <= '.join('{t' + str(idx) + '}' for idx in range(len(self.int_type3_list))),
{
't' + str(idx): typ
for idx, typ in enumerate(self.int_type3_list)
},
)
def __repr__(self) -> str:
args = ', '.join(repr(x) for x in self.int_type3_list)
return f'IntegerCompareConstraint({args}, comment={repr(self.comment)})'
class CastableConstraint(ConstraintBase): class CastableConstraint(ConstraintBase):
""" """
A type can be cast to another type A type can be cast to another type
@ -363,11 +407,11 @@ class LiteralFitsConstraint(ConstraintBase):
if not isinstance(self.literal, ourlang.ConstantTuple): if not isinstance(self.literal, ourlang.ConstantTuple):
return Error('Must be tuple') return Error('Must be tuple')
assert 1 == len(self.type3.args) assert 2 == len(self.type3.args)
assert isinstance(self.type3.args[1], types.IntType3)
# FIXME: How to store type level length? if self.type3.args[1].value != len(self.literal.value):
# if len(self.type3.args) != len(self.literal.value): return Error('Member count mismatch')
# return Error('Tuple element count mismatch')
res = [] res = []
@ -448,11 +492,24 @@ class CanBeSubscriptedConstraint(ConstraintBase):
if isinstance(self.type3, types.AppliedType3): if isinstance(self.type3, types.AppliedType3):
if self.type3.base == types.static_array: if self.type3.base == types.static_array:
return [ result: List[ConstraintBase] = [
SameTypeConstraint(types.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'), SameTypeConstraint(types.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'),
SameTypeConstraint(self.type3.args[0], self.ret_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'), SameTypeConstraint(self.type3.args[0], self.ret_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'),
] ]
if isinstance(self.index, ourlang.ConstantPrimitive):
assert isinstance(self.index.value, int)
assert isinstance(self.type3.args[1], types.IntType3)
result.append(
IntegerCompareConstraint(
types.IntType3(0), types.IntType3(self.index.value), types.IntType3(self.type3.args[1].value - 1),
comment='Subscript static array must fit the size of the array'
)
)
return result
if self.type3.base == types.tuple: if self.type3.base == types.tuple:
if not isinstance(self.index, ourlang.ConstantPrimitive): if not isinstance(self.index, ourlang.ConstantPrimitive):
return Error('Must index with literal') return Error('Must index with literal')

View File

@ -8,7 +8,7 @@ from .. import ourlang
from .constraints import ConstraintBase, Error, RequireTypeSubstitutes, SameTypeConstraint, SubstitutionMap from .constraints import ConstraintBase, Error, RequireTypeSubstitutes, SameTypeConstraint, SubstitutionMap
from .constraintsgenerator import phasm_type3_generate_constraints from .constraintsgenerator import phasm_type3_generate_constraints
from .types import AppliedType3, PlaceholderForType, PrimitiveType3, StructType3, Type3, Type3OrPlaceholder from .types import AppliedType3, IntType3, PlaceholderForType, PrimitiveType3, StructType3, Type3, Type3OrPlaceholder
MAX_RESTACK_COUNT = 100 MAX_RESTACK_COUNT = 100
@ -125,7 +125,7 @@ def print_constraint(placeholder_id_map: Dict[int, str], constraint: ConstraintB
print('- ' + txt.format(**act_fmt)) print('- ' + txt.format(**act_fmt))
def get_printable_type_name(inp: Type3OrPlaceholder, placeholder_id_map: Dict[int, str]) -> str: def get_printable_type_name(inp: Type3OrPlaceholder, placeholder_id_map: Dict[int, str]) -> str:
if isinstance(inp, (PrimitiveType3, StructType3, )): if isinstance(inp, (PrimitiveType3, StructType3, IntType3, )):
return inp.name return inp.name
if isinstance(inp, PlaceholderForType): if isinstance(inp, PlaceholderForType):

View File

@ -69,6 +69,30 @@ class PrimitiveType3(Type3):
__slots__ = () __slots__ = ()
class IntType3(Type3):
"""
Sometimes you can have an int as type, e.g. when using static arrays
"""
__slots__ = ('value', )
value: int
def __init__(self, value: int) -> None:
super().__init__(str(value))
assert 0 <= value
self.value = value
def __eq__(self, other: Any) -> bool:
if isinstance(other, IntType3):
return self.value == other.value
if isinstance(other, Type3):
return False
raise NotImplementedError
class PlaceholderForType: class PlaceholderForType:
""" """
A placeholder type, for when we don't know the final type yet A placeholder type, for when we don't know the final type yet

View File

@ -4,26 +4,8 @@ from ..helpers import Suite
@pytest.mark.slow_integration_test @pytest.mark.slow_integration_test
def test_fib(): def test_fib():
code_py = """ with open('./examples/fib.py', 'r', encoding='UTF-8') as fil:
def helper(n: i32, a: i32, b: i32) -> i32: code_py = "\n" + fil.read()
if n < 1:
return a + b
return helper(n - 1, a + b, a)
def fib(n: i32) -> i32:
if n == 0:
return 0
if n == 1:
return 1
return helper(n - 1, 0, 1)
@exported
def testEntry() -> i32:
return fib(40)
"""
result = Suite(code_py).run_code() result = Suite(code_py).run_code()

View File

@ -75,8 +75,7 @@ def testEntry() -> {type_}:
assert 32.125 == result.returned_value assert 32.125 == result.returned_value
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.skip('Awaiting result of Type3 experiment') def test_module_constant_type_failure():
def test_module_constant_entanglement():
code_py = """ code_py = """
CONSTANT: u8 = 1000 CONSTANT: u8 = 1000
@ -85,7 +84,7 @@ def testEntry() -> u32:
return 14 return 14
""" """
with pytest.raises(Type3Exception, match='u8.*1000'): with pytest.raises(Type3Exception, match=r'Must fit in 1 byte\(s\)'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test

View File

@ -37,26 +37,6 @@ def testEntry() -> {type_}:
assert 57 == result.returned_value assert 57 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value) assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.skip('To decide: What to do on out of index?')
@pytest.mark.parametrize('type_', COMPLETE_NUMERIC_TYPES)
def test_static_array_indexed(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT, 0, 1, 2)
def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}:
return array[i0] + array[i1] + array[i2]
"""
result = Suite(code_py).run_code()
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) @pytest.mark.parametrize('type_', COMPLETE_INT_TYPES)
def test_function_call_int(type_): def test_function_call_int(type_):
@ -146,7 +126,7 @@ def testEntry() -> u32:
return CONSTANT return CONSTANT
""" """
with pytest.raises(Type3Exception, match=r'static_array \(u8\) must be u32 instead'): with pytest.raises(Type3Exception, match=r'static_array \(u8\) \(3\) must be u32 instead'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
@ -163,7 +143,7 @@ def testEntry() -> u8:
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
def test_module_constant_type_mismatch_index_out_of_range(): def test_module_constant_type_mismatch_index_out_of_range_constant():
code_py = """ code_py = """
CONSTANT: u8[3] = (24, 57, 80, ) CONSTANT: u8[3] = (24, 57, 80, )
@ -172,20 +152,33 @@ def testEntry() -> u8:
return CONSTANT[3] return CONSTANT[3]
""" """
with pytest.raises(Type3Exception, match='Type cannot be subscripted with index 3:'): with pytest.raises(Type3Exception, match='3 must be less or equal than 2'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test
def test_module_constant_type_mismatch_index_out_of_range_variable():
code_py = """
CONSTANT: u8[3] = (24, 57, 80, )
@exported
def testEntry(x: u32) -> u8:
return CONSTANT[x]
"""
with pytest.raises(RuntimeError):
Suite(code_py).run_code(3)
@pytest.mark.integration_test @pytest.mark.integration_test
def test_static_array_constant_too_few_values(): def test_static_array_constant_too_few_values():
code_py = """ code_py = """
CONSTANT: u8[3] = (24, 57, ) CONSTANT: u8[4] = (24, 57, )
@exported @exported
def testEntry() -> i32: def testEntry() -> i32:
return 0 return 0
""" """
with pytest.raises(Type3Exception, match='Member count does not match'): with pytest.raises(Type3Exception, match='Member count mismatch'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
@ -198,22 +191,5 @@ def testEntry() -> i32:
return 0 return 0
""" """
with pytest.raises(Type3Exception, match='Member count does not match'): with pytest.raises(Type3Exception, match='Member count mismatch'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.skip('To decide: What to do on out of index? Should be a panic.')
def test_static_array_index_out_of_bounds():
code_py = """
CONSTANT0: u32[3] = (24, 57, 80, )
CONSTANT1: u32[16] = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, )
@exported
def testEntry() -> u32:
return CONSTANT0[16]
"""
result = Suite(code_py).run_code()
assert 0 == result.returned_value