diff --git a/phasm/codestyle.py b/phasm/codestyle.py index bbb7a05..fc40868 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -63,7 +63,7 @@ def expression(inp: ourlang.Expression) -> str: ) + ', )' if isinstance(inp, ourlang.ConstantStruct): - return inp.struct_name + '(' + ', '.join( + return inp.struct_type3.name + '(' + ', '.join( expression(x) for x in inp.value ) + ')' diff --git a/phasm/ourlang.py b/phasm/ourlang.py index f103555..df97b23 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -98,21 +98,21 @@ class ConstantStruct(ConstantMemoryStored): """ A Struct constant value expression within a statement """ - __slots__ = ('struct_name', 'value', ) + __slots__ = ('struct_type3', 'value', ) - struct_name: str + struct_type3: Type3 value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']] - def __init__(self, struct_name: str, value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']], data_block: 'ModuleDataBlock') -> None: + def __init__(self, struct_type3: Type3, value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']], data_block: 'ModuleDataBlock') -> None: super().__init__(data_block) - self.struct_name = struct_name + self.struct_type3 = struct_type3 self.value = value def __repr__(self) -> str: # Do not repr the whole ModuleDataBlock # As this has a reference back to this constant for its data # which it needs to compile the data into the program - return f'ConstantStruct({repr(self.struct_name)}, {repr(self.value)}, @{repr(self.data_block.address)})' + return f'ConstantStruct({self.struct_type3!r}, {self.value!r}, @{self.data_block.address!r})' class VariableReference(Expression): """ diff --git a/phasm/parser.py b/phasm/parser.py index e32563a..9944dc2 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -584,7 +584,8 @@ class OurVisitor: if not isinstance(node.func.ctx, ast.Load): _raise_static_error(node.func, 'Must be load context') - if node.func.id not in module.struct_definitions: + struct_def = module.struct_definitions.get(node.func.id) + if struct_def is None: _raise_static_error(node.func, 'Undefined struct') if node.keywords: @@ -600,7 +601,7 @@ class OurVisitor: data_block = ModuleDataBlock(struct_data) module.data.blocks.append(data_block) - return ConstantStruct(node.func.id, struct_data, data_block) + return ConstantStruct(struct_def.struct_type3, struct_data, data_block) _not_implemented(node.kind is None, 'Constant.kind') diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py index c5bf372..4a9541b 100644 --- a/phasm/type3/constraints.py +++ b/phasm/type3/constraints.py @@ -392,10 +392,16 @@ class LiteralFitsConstraint(ConstraintBase): # gets updated when we figure out the type of the # expression the literal is used in res.extend( - SameTypeConstraint(x_t, PlaceholderForType([y]), comment=f'{self.literal.struct_name}.{x_n}') + SameTypeConstraint(x_t, PlaceholderForType([y]), comment=f'{self.literal.struct_type3.name}.{x_n}') for (x_n, x_t, ), y in zip(st_args, self.literal.value, strict=True) ) + res.append(SameTypeConstraint( + self.literal.struct_type3, + self.type3, + comment='Struct types must match', + )) + return res def _generate_tuple(self, tp_args: tuple[Type3, ...]) -> CheckResult: diff --git a/tests/integration/test_lang/test_struct.py b/tests/integration/test_lang/test_struct.py index 6b211e9..a0396ab 100644 --- a/tests/integration/test_lang/test_struct.py +++ b/tests/integration/test_lang/test_struct.py @@ -64,6 +64,36 @@ def helper(shape1: Rectangle, shape2: Rectangle) -> i32: assert 545 == result.returned_value +@pytest.mark.integration_test +def test_type_mismatch_struct_call_root(): + code_py = """ +class CheckedValueBlue: + value: i32 + +class CheckedValueRed: + value: i32 + +CONST: CheckedValueBlue = CheckedValueRed(1) +""" + + with pytest.raises(Type3Exception, match='CheckedValueBlue must be CheckedValueRed instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_type_mismatch_struct_call_nested(): + code_py = """ +class CheckedValueBlue: + value: i32 + +class CheckedValueRed: + value: i32 + +CONST: (CheckedValueBlue, u32, ) = (CheckedValueRed(1), 16, ) +""" + + with pytest.raises(Type3Exception, match='CheckedValueBlue must be CheckedValueRed instead'): + Suite(code_py).run_code() + @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) def test_type_mismatch_struct_member(type_):