diff --git a/phasm/compiler.py b/phasm/compiler.py index fa34e2f..b7f45fa 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -171,7 +171,7 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> wgn.add_statement('nop', comment=f'{tmp_var.name} := ({comment_elements})') # Allocated the required amounts of bytes in memory - wgn.i32.const(_calculate_alloc_size(inp.type3)) + wgn.i32.const(_calculate_alloc_size(inp.type3, is_member=False)) wgn.call(stdlib_alloc.__alloc__) wgn.local.set(tmp_var) @@ -184,14 +184,19 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> assert element.type3 == exp_type3 - assert isinstance(exp_type3, type3types.PrimitiveType3), NotImplementedError('Tuple of applied types / structs') - mtyp = LOAD_STORE_TYPE_MAP[exp_type3.name] + if isinstance(exp_type3, type3types.AppliedType3) and exp_type3.base is type3types.tuple: + mtyp = 'i32' + else: + assert isinstance(exp_type3, type3types.PrimitiveType3), NotImplementedError('Tuple of applied types / structs') + mtyp = LOAD_STORE_TYPE_MAP[exp_type3.name] + wgn.add_statement('nop', comment='PRE') wgn.local.get(tmp_var) expression(wgn, element) wgn.add_statement(f'{mtyp}.store', 'offset=' + str(offset)) + wgn.add_statement('nop', comment='POST') - offset += _calculate_alloc_size(exp_type3) + offset += _calculate_alloc_size(exp_type3, is_member=True) # Return the allocated address wgn.local.get(tmp_var) @@ -445,8 +450,11 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: expression(wgn, inp.varref) - assert isinstance(el_type, type3types.PrimitiveType3), NotImplementedError('Tuple of applied types / structs') - mtyp = LOAD_STORE_TYPE_MAP[el_type.name] + if isinstance(el_type, type3types.AppliedType3) and el_type.base is type3types.tuple: + mtyp = 'i32' + else: + assert isinstance(el_type, type3types.PrimitiveType3), NotImplementedError('Tuple of applied types / structs') + mtyp = LOAD_STORE_TYPE_MAP[el_type.name] wgn.add_statement(f'{mtyp}.load', f'offset={offset}') return @@ -814,7 +822,7 @@ def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstruc # Return the allocated address wgn.local.get(tmp_var) -def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3]) -> int: +def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3], is_member: bool = False) -> int: if typ == type3types.u8: return 4 # FIXME: We allocate 4 bytes for every u8 since you load them into an i32 @@ -825,6 +833,10 @@ def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3]) return 8 if isinstance(typ, type3types.StructType3): + if is_member: + # Structs referred to by other structs or tuples are pointers + return 4 + return sum( _calculate_alloc_size(x) for x in typ.members.values() @@ -832,6 +844,10 @@ def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3]) if isinstance(typ, type3types.AppliedType3): if typ.base is type3types.tuple: + if is_member: + # tuples referred to by other structs or tuples are pointers + return 4 + size = 0 for arg in typ.args: assert not isinstance(arg, type3types.IntType3) @@ -840,7 +856,7 @@ def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3]) assert not arg.resolve_as is None arg = arg.resolve_as - size += _calculate_alloc_size(arg) + size += _calculate_alloc_size(arg, is_member=True) return size @@ -853,6 +869,6 @@ def _calculate_member_offset(struct_type3: type3types.StructType3, member: str) if member == mem: return result - result += _calculate_alloc_size(memtyp) + result += _calculate_alloc_size(memtyp, is_member=True) raise Exception(f'{member} not in {struct_type3}') diff --git a/phasm/parser.py b/phasm/parser.py index cfa2338..6e38628 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -442,7 +442,7 @@ class OurVisitor: arguments = [ self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_node) for arg_node in node.elts - if isinstance(arg_node, ast.Constant) + if isinstance(arg_node, (ast.Constant, ast.Tuple, )) ] if len(arguments) != len(node.elts): diff --git a/tests/integration/test_lang/test_tuple.py b/tests/integration/test_lang/test_tuple.py index a238932..bf8033d 100644 --- a/tests/integration/test_lang/test_tuple.py +++ b/tests/integration/test_lang/test_tuple.py @@ -65,6 +65,60 @@ def helper(x: u64) -> u64: assert 250000000 == result.returned_value +@pytest.mark.integration_test +def test_tuple(): + code_py = """ +def l1(c: (u64, )) -> u64: + return c[0] + +@exported +def testEntry() -> u64: + return l1((32, )) +""" + + result = Suite(code_py).run_code() + + assert 32 == result.returned_value + +@pytest.mark.integration_test +def test_tuple_of_tuple(): + code_py = """ +def l1(c: (u64, )) -> u64: + return c[0] + +def l2(c: ((u64, ), u64, )) -> u64: + return l1(c[0]) + +@exported +def testEntry() -> u64: + return l2(((64, ), 32, )) +""" + + result = Suite(code_py).run_code() + + assert 64 == result.returned_value + +@pytest.mark.integration_test +def test_tuple_of_tuple_of_tuple(): + code_py = """ +def l1(c: (u64, )) -> u64: + return c[0] + +def l2(c: ((u64, ), u64, )) -> u64: + return l1(c[0]) + +def l3(c: (((u64, ), u64, ), u64, )) -> u64: + return l2(c[0]) + +@exported +def testEntry() -> u64: + return l3((((128, ), 64, ), 32, )) +""" + + result = Suite(code_py).run_code() + + assert 128 == result.returned_value + @pytest.mark.integration_test def test_function_call_element_type_mismatch(): code_py = """