Fixed a missing constraint

This commit is contained in:
Johan B.W. de Vries 2022-12-24 19:25:30 +01:00
parent e456f55bb0
commit 17f538d8cc
3 changed files with 25 additions and 5 deletions

View File

@ -350,15 +350,17 @@ class CanBeSubscriptedConstraint(ConstraintBase):
"""
A value that is subscipted, i.e. a[0] (tuple) or a[b] (static array)
"""
__slots__ = ('type3', 'index', 'index_type3', )
__slots__ = ('ret_type3', 'type3', 'index', 'index_type3', )
ret_type3: types.Type3OrPlaceholder
type3: types.Type3OrPlaceholder
index: ourlang.Expression
index_type3: types.Type3OrPlaceholder
def __init__(self, type3: types.Type3OrPlaceholder, index: ourlang.Expression, comment: Optional[str] = None) -> None:
def __init__(self, ret_type3: types.Type3OrPlaceholder, type3: types.Type3OrPlaceholder, index: ourlang.Expression, comment: Optional[str] = None) -> None:
super().__init__(comment=comment)
self.ret_type3 = ret_type3
self.type3 = type3
self.index = index
self.index_type3 = index.type3
@ -373,7 +375,8 @@ class CanBeSubscriptedConstraint(ConstraintBase):
if isinstance(self.type3, types.AppliedType3):
if self.type3.base == types.static_array:
return [
SameTypeConstraint(types.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b')
SameTypeConstraint(types.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'),
SameTypeConstraint(self.type3.args[0], self.ret_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'),
]
# FIXME: bytes

View File

@ -103,7 +103,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> Generator[ConstraintBas
yield from expression(ctx, inp.varref)
yield from expression(ctx, inp.index)
yield CanBeSubscriptedConstraint(inp.varref.type3, inp.index)
yield CanBeSubscriptedConstraint(inp.type3, inp.varref.type3, inp.index)
return
if isinstance(inp, ourlang.AccessStructMember):

View File

@ -96,7 +96,24 @@ def helper(array: {type_}[3]) -> {type_}:
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_function_call_element():
def test_function_call_element_ok():
code_py = """
CONSTANT: u64[3] = (250, 250000, 250000000, )
@exported
def testEntry() -> u64:
return helper(CONSTANT[0])
def helper(x: u64) -> u64:
return x
"""
result = Suite(code_py).run_code()
assert 250 == result.returned_value
@pytest.mark.integration_test
def test_function_call_element_type_mismatch():
code_py = """
CONSTANT: u64[3] = (250, 250000, 250000000, )