Fix: You could assign structs to each other

As long as the arguments matched at least.
This commit is contained in:
Johan B.W. de Vries 2025-05-12 19:40:41 +02:00
parent 67af569448
commit ac4b46bbe7
5 changed files with 46 additions and 9 deletions

View File

@ -63,7 +63,7 @@ def expression(inp: ourlang.Expression) -> str:
) + ', )'
if isinstance(inp, ourlang.ConstantStruct):
return inp.struct_name + '(' + ', '.join(
return inp.struct_type3.name + '(' + ', '.join(
expression(x)
for x in inp.value
) + ')'

View File

@ -98,21 +98,21 @@ class ConstantStruct(ConstantMemoryStored):
"""
A Struct constant value expression within a statement
"""
__slots__ = ('struct_name', 'value', )
__slots__ = ('struct_type3', 'value', )
struct_name: str
struct_type3: Type3
value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']]
def __init__(self, struct_name: str, value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']], data_block: 'ModuleDataBlock') -> None:
def __init__(self, struct_type3: Type3, value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']], data_block: 'ModuleDataBlock') -> None:
super().__init__(data_block)
self.struct_name = struct_name
self.struct_type3 = struct_type3
self.value = value
def __repr__(self) -> str:
# Do not repr the whole ModuleDataBlock
# As this has a reference back to this constant for its data
# which it needs to compile the data into the program
return f'ConstantStruct({repr(self.struct_name)}, {repr(self.value)}, @{repr(self.data_block.address)})'
return f'ConstantStruct({self.struct_type3!r}, {self.value!r}, @{self.data_block.address!r})'
class VariableReference(Expression):
"""

View File

@ -584,7 +584,8 @@ class OurVisitor:
if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node.func, 'Must be load context')
if node.func.id not in module.struct_definitions:
struct_def = module.struct_definitions.get(node.func.id)
if struct_def is None:
_raise_static_error(node.func, 'Undefined struct')
if node.keywords:
@ -600,7 +601,7 @@ class OurVisitor:
data_block = ModuleDataBlock(struct_data)
module.data.blocks.append(data_block)
return ConstantStruct(node.func.id, struct_data, data_block)
return ConstantStruct(struct_def.struct_type3, struct_data, data_block)
_not_implemented(node.kind is None, 'Constant.kind')

View File

@ -392,10 +392,16 @@ class LiteralFitsConstraint(ConstraintBase):
# gets updated when we figure out the type of the
# expression the literal is used in
res.extend(
SameTypeConstraint(x_t, PlaceholderForType([y]), comment=f'{self.literal.struct_name}.{x_n}')
SameTypeConstraint(x_t, PlaceholderForType([y]), comment=f'{self.literal.struct_type3.name}.{x_n}')
for (x_n, x_t, ), y in zip(st_args, self.literal.value, strict=True)
)
res.append(SameTypeConstraint(
self.literal.struct_type3,
self.type3,
comment='Struct types must match',
))
return res
def _generate_tuple(self, tp_args: tuple[Type3, ...]) -> CheckResult:

View File

@ -64,6 +64,36 @@ def helper(shape1: Rectangle, shape2: Rectangle) -> i32:
assert 545 == result.returned_value
@pytest.mark.integration_test
def test_type_mismatch_struct_call_root():
code_py = """
class CheckedValueBlue:
value: i32
class CheckedValueRed:
value: i32
CONST: CheckedValueBlue = CheckedValueRed(1)
"""
with pytest.raises(Type3Exception, match='CheckedValueBlue must be CheckedValueRed instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_type_mismatch_struct_call_nested():
code_py = """
class CheckedValueBlue:
value: i32
class CheckedValueRed:
value: i32
CONST: (CheckedValueBlue, u32, ) = (CheckedValueRed(1), 16, )
"""
with pytest.raises(Type3Exception, match='CheckedValueBlue must be CheckedValueRed instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64'])
def test_type_mismatch_struct_member(type_):