From 7ce3f0f11c3d8e2bd3ce325cbc142d7e35c8422c Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sun, 18 Dec 2022 15:05:41 +0100 Subject: [PATCH] Started on struct constants --- phasm/codestyle.py | 6 ++ phasm/ourlang.py | 17 ++++++ phasm/parser.py | 68 ++++++++++++++++------ phasm/type3/constraints.py | 29 +++++++-- phasm/type3/constraintsgenerator.py | 2 +- tests/integration/test_lang/test_struct.py | 17 ++++-- 6 files changed, 111 insertions(+), 28 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 9b4ea01..1fbc7e4 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -67,6 +67,12 @@ def expression(inp: ourlang.Expression) -> str: for x in inp.value ) + ', )' + if isinstance(inp, ourlang.ConstantStruct): + return inp.struct_name + '(' + ', '.join( + expression(x) + for x in inp.value + ) + ')' + if isinstance(inp, ourlang.VariableReference): return str(inp.variable.name) diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 696cf9e..eb57978 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -62,6 +62,23 @@ class ConstantTuple(Constant): def __repr__(self) -> str: return f'ConstantTuple({repr(self.value)})' +class ConstantStruct(Constant): + """ + A Struct constant value expression within a statement + """ + __slots__ = ('struct_name', 'value', ) + + struct_name: str + value: List[ConstantPrimitive] + + def __init__(self, struct_name: str, value: List[ConstantPrimitive]) -> None: # FIXME: Struct of structs? + super().__init__() + self.struct_name = struct_name + self.value = value + + def __repr__(self) -> str: + return f'ConstantStruct({repr(self.struct_name)}, {repr(self.value)})' + class VariableReference(Expression): """ An variable reference expression within a statement diff --git a/phasm/parser.py b/phasm/parser.py index 3f0afbb..9e57094 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -16,7 +16,7 @@ from .ourlang import ( Expression, BinaryOp, - ConstantPrimitive, ConstantTuple, + ConstantPrimitive, ConstantTuple, ConstantStruct, FunctionCall, AccessStructMember, Subscript, StructDefinition, StructConstructor, @@ -184,9 +184,9 @@ class OurVisitor: def pre_visit_Module_AnnAssign(self, module: Module, node: ast.AnnAssign) -> ModuleConstantDef: if not isinstance(node.target, ast.Name): - _raise_static_error(node, 'Must be name') + _raise_static_error(node.target, 'Must be name') if not isinstance(node.target.ctx, ast.Store): - _raise_static_error(node, 'Must be load context') + _raise_static_error(node.target, 'Must be store context') if isinstance(node.value, ast.Constant): type3 = self.visit_type(module, node.annotation) @@ -220,6 +220,45 @@ class OurVisitor: data_block, ) + if isinstance(node.value, ast.Call): + # Struct constant + # Stored in memory like a tuple, so much of the code is the same + + if not isinstance(node.value.func, ast.Name): + _raise_static_error(node.value.func, 'Must be name') + if not isinstance(node.value.func.ctx, ast.Load): + _raise_static_error(node.value.func, 'Must be load context') + + if not node.value.func.id in module.struct_definitions: + _raise_static_error(node.value.func, 'Undefined struct') + + if node.value.keywords: + _raise_static_error(node.value.func, 'Cannot use keywords') + + if not isinstance(node.annotation, ast.Name): + _raise_static_error(node.annotation, 'Must be name') + + struct_data = [ + self.visit_Module_Constant(module, arg_node) + for arg_node in node.value.args + if isinstance(arg_node, ast.Constant) + ] + if len(node.value.args) != len(struct_data): + _raise_static_error(node, 'Struct arguments must be constants') + + # Allocate the data + data_block = ModuleDataBlock(struct_data) + module.data.blocks.append(data_block) + + # Then return the constant as a pointer + return ModuleConstantDef( + node.target.id, + node.lineno, + self.visit_type(module, node.annotation), + ConstantStruct(node.value.func.id, struct_data), + data_block, + ) + raise NotImplementedError('TODO: Broken after new typing system') # if isinstance(exp_type, TypeTuple): @@ -549,31 +588,22 @@ class OurVisitor: return result def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: - del module - del function - if not isinstance(node.value, ast.Name): _raise_static_error(node, 'Must reference a name') if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if not node.value.id in our_locals: - _raise_static_error(node, f'Undefined variable {node.value.id}') + varref = self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value) + if not isinstance(varref, VariableReference): + _raise_static_error(node.value, 'Must refer to variable') - param = our_locals[node.value.id] - - node_typ = param.type3 - if not isinstance(node_typ, type3types.StructType3): - _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') - - member = node_typ.members.get(node.attr) - if member is None: - _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') + if not isinstance(varref.variable.type3, type3types.StructType3): + _raise_static_error(node.value, 'Must refer to struct') return AccessStructMember( - VariableReference(param), - node_typ, + varref, + varref.variable.type3, node.attr, ) diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py index c097771..9b7d73e 100644 --- a/phasm/type3/constraints.py +++ b/phasm/type3/constraints.py @@ -203,9 +203,14 @@ class LiteralFitsConstraint(ConstraintBase): __slots__ = ('type3', 'literal', ) type3: types.Type3OrPlaceholder - literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple] + literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct] - def __init__(self, type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple], comment: Optional[str] = None) -> None: + def __init__( + self, + type3: types.Type3OrPlaceholder, + literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct], + comment: Optional[str] = None, + ) -> None: super().__init__(comment=comment) self.type3 = type3 @@ -226,7 +231,7 @@ class LiteralFitsConstraint(ConstraintBase): 'f64': None, } - def _check(type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple]) -> CheckResult: + def _check(type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct]) -> CheckResult: if isinstance(type3, types.PlaceholderForType): if type3 not in smap: return RequireTypeSubstitutes() @@ -275,7 +280,23 @@ class LiteralFitsConstraint(ConstraintBase): return None - raise NotImplementedError + if isinstance(type3, types.StructType3): + if not isinstance(literal, ourlang.ConstantStruct): + return Error('Must be struct') + + assert isinstance(val, list) # type hint + + if len(type3.members) != len(val): + return Error('Struct element count mismatch') + + for elt_typ, elt_lit in zip(type3.members.values(), val): + res = _check(elt_typ, elt_lit) + if res is not None: + return res + + return None + + raise NotImplementedError(type3, literal) return _check(self.type3, self.literal) diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index 2878bd1..d28754a 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -23,7 +23,7 @@ def phasm_type3_generate_constraints(inp: ourlang.Module) -> List[ConstraintBase return [*module(ctx, inp)] def constant(ctx: Context, inp: ourlang.Constant) -> Generator[ConstraintBase, None, None]: - if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantTuple, )): + if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct)): yield LiteralFitsConstraint(inp.type3, inp) return diff --git a/tests/integration/test_lang/test_struct.py b/tests/integration/test_lang/test_struct.py index bb70323..cffd29e 100644 --- a/tests/integration/test_lang/test_struct.py +++ b/tests/integration/test_lang/test_struct.py @@ -1,12 +1,9 @@ import pytest -from phasm.exceptions import StaticError from phasm.type3.entry import Type3Exception -from phasm.parser import phasm_parse - from ..constants import ( - ALL_INT_TYPES + ALL_INT_TYPES, TYPE_MAP ) from ..helpers import Suite @@ -88,6 +85,18 @@ def helper(shape1: Rectangle, shape2: Rectangle) -> i32: assert 545 == result.returned_value +@pytest.mark.integration_test +def test_type_mismatch_arg_module_constant(): + code_py = """ +class Struct: + param: f32 + +STRUCT: Struct = Struct(1) +""" + + with pytest.raises(Type3Exception, match='todo'): + Suite(code_py).run_code() + @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) def test_type_mismatch_struct_member(type_):