Started on struct constants
This commit is contained in:
parent
d18f1c6956
commit
7ce3f0f11c
@ -67,6 +67,12 @@ def expression(inp: ourlang.Expression) -> str:
|
||||
for x in inp.value
|
||||
) + ', )'
|
||||
|
||||
if isinstance(inp, ourlang.ConstantStruct):
|
||||
return inp.struct_name + '(' + ', '.join(
|
||||
expression(x)
|
||||
for x in inp.value
|
||||
) + ')'
|
||||
|
||||
if isinstance(inp, ourlang.VariableReference):
|
||||
return str(inp.variable.name)
|
||||
|
||||
|
||||
@ -62,6 +62,23 @@ class ConstantTuple(Constant):
|
||||
def __repr__(self) -> str:
|
||||
return f'ConstantTuple({repr(self.value)})'
|
||||
|
||||
class ConstantStruct(Constant):
|
||||
"""
|
||||
A Struct constant value expression within a statement
|
||||
"""
|
||||
__slots__ = ('struct_name', 'value', )
|
||||
|
||||
struct_name: str
|
||||
value: List[ConstantPrimitive]
|
||||
|
||||
def __init__(self, struct_name: str, value: List[ConstantPrimitive]) -> None: # FIXME: Struct of structs?
|
||||
super().__init__()
|
||||
self.struct_name = struct_name
|
||||
self.value = value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'ConstantStruct({repr(self.struct_name)}, {repr(self.value)})'
|
||||
|
||||
class VariableReference(Expression):
|
||||
"""
|
||||
An variable reference expression within a statement
|
||||
|
||||
@ -16,7 +16,7 @@ from .ourlang import (
|
||||
|
||||
Expression,
|
||||
BinaryOp,
|
||||
ConstantPrimitive, ConstantTuple,
|
||||
ConstantPrimitive, ConstantTuple, ConstantStruct,
|
||||
|
||||
FunctionCall, AccessStructMember, Subscript,
|
||||
StructDefinition, StructConstructor,
|
||||
@ -184,9 +184,9 @@ class OurVisitor:
|
||||
|
||||
def pre_visit_Module_AnnAssign(self, module: Module, node: ast.AnnAssign) -> ModuleConstantDef:
|
||||
if not isinstance(node.target, ast.Name):
|
||||
_raise_static_error(node, 'Must be name')
|
||||
_raise_static_error(node.target, 'Must be name')
|
||||
if not isinstance(node.target.ctx, ast.Store):
|
||||
_raise_static_error(node, 'Must be load context')
|
||||
_raise_static_error(node.target, 'Must be store context')
|
||||
|
||||
if isinstance(node.value, ast.Constant):
|
||||
type3 = self.visit_type(module, node.annotation)
|
||||
@ -220,6 +220,45 @@ class OurVisitor:
|
||||
data_block,
|
||||
)
|
||||
|
||||
if isinstance(node.value, ast.Call):
|
||||
# Struct constant
|
||||
# Stored in memory like a tuple, so much of the code is the same
|
||||
|
||||
if not isinstance(node.value.func, ast.Name):
|
||||
_raise_static_error(node.value.func, 'Must be name')
|
||||
if not isinstance(node.value.func.ctx, ast.Load):
|
||||
_raise_static_error(node.value.func, 'Must be load context')
|
||||
|
||||
if not node.value.func.id in module.struct_definitions:
|
||||
_raise_static_error(node.value.func, 'Undefined struct')
|
||||
|
||||
if node.value.keywords:
|
||||
_raise_static_error(node.value.func, 'Cannot use keywords')
|
||||
|
||||
if not isinstance(node.annotation, ast.Name):
|
||||
_raise_static_error(node.annotation, 'Must be name')
|
||||
|
||||
struct_data = [
|
||||
self.visit_Module_Constant(module, arg_node)
|
||||
for arg_node in node.value.args
|
||||
if isinstance(arg_node, ast.Constant)
|
||||
]
|
||||
if len(node.value.args) != len(struct_data):
|
||||
_raise_static_error(node, 'Struct arguments must be constants')
|
||||
|
||||
# Allocate the data
|
||||
data_block = ModuleDataBlock(struct_data)
|
||||
module.data.blocks.append(data_block)
|
||||
|
||||
# Then return the constant as a pointer
|
||||
return ModuleConstantDef(
|
||||
node.target.id,
|
||||
node.lineno,
|
||||
self.visit_type(module, node.annotation),
|
||||
ConstantStruct(node.value.func.id, struct_data),
|
||||
data_block,
|
||||
)
|
||||
|
||||
raise NotImplementedError('TODO: Broken after new typing system')
|
||||
|
||||
# if isinstance(exp_type, TypeTuple):
|
||||
@ -549,31 +588,22 @@ class OurVisitor:
|
||||
return result
|
||||
|
||||
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
|
||||
del module
|
||||
del function
|
||||
|
||||
if not isinstance(node.value, ast.Name):
|
||||
_raise_static_error(node, 'Must reference a name')
|
||||
|
||||
if not isinstance(node.ctx, ast.Load):
|
||||
_raise_static_error(node, 'Must be load context')
|
||||
|
||||
if not node.value.id in our_locals:
|
||||
_raise_static_error(node, f'Undefined variable {node.value.id}')
|
||||
varref = self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value)
|
||||
if not isinstance(varref, VariableReference):
|
||||
_raise_static_error(node.value, 'Must refer to variable')
|
||||
|
||||
param = our_locals[node.value.id]
|
||||
|
||||
node_typ = param.type3
|
||||
if not isinstance(node_typ, type3types.StructType3):
|
||||
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
|
||||
|
||||
member = node_typ.members.get(node.attr)
|
||||
if member is None:
|
||||
_raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
|
||||
if not isinstance(varref.variable.type3, type3types.StructType3):
|
||||
_raise_static_error(node.value, 'Must refer to struct')
|
||||
|
||||
return AccessStructMember(
|
||||
VariableReference(param),
|
||||
node_typ,
|
||||
varref,
|
||||
varref.variable.type3,
|
||||
node.attr,
|
||||
)
|
||||
|
||||
|
||||
@ -203,9 +203,14 @@ class LiteralFitsConstraint(ConstraintBase):
|
||||
__slots__ = ('type3', 'literal', )
|
||||
|
||||
type3: types.Type3OrPlaceholder
|
||||
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple]
|
||||
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct]
|
||||
|
||||
def __init__(self, type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple], comment: Optional[str] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
type3: types.Type3OrPlaceholder,
|
||||
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct],
|
||||
comment: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__(comment=comment)
|
||||
|
||||
self.type3 = type3
|
||||
@ -226,7 +231,7 @@ class LiteralFitsConstraint(ConstraintBase):
|
||||
'f64': None,
|
||||
}
|
||||
|
||||
def _check(type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple]) -> CheckResult:
|
||||
def _check(type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct]) -> CheckResult:
|
||||
if isinstance(type3, types.PlaceholderForType):
|
||||
if type3 not in smap:
|
||||
return RequireTypeSubstitutes()
|
||||
@ -275,7 +280,23 @@ class LiteralFitsConstraint(ConstraintBase):
|
||||
|
||||
return None
|
||||
|
||||
raise NotImplementedError
|
||||
if isinstance(type3, types.StructType3):
|
||||
if not isinstance(literal, ourlang.ConstantStruct):
|
||||
return Error('Must be struct')
|
||||
|
||||
assert isinstance(val, list) # type hint
|
||||
|
||||
if len(type3.members) != len(val):
|
||||
return Error('Struct element count mismatch')
|
||||
|
||||
for elt_typ, elt_lit in zip(type3.members.values(), val):
|
||||
res = _check(elt_typ, elt_lit)
|
||||
if res is not None:
|
||||
return res
|
||||
|
||||
return None
|
||||
|
||||
raise NotImplementedError(type3, literal)
|
||||
|
||||
return _check(self.type3, self.literal)
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ def phasm_type3_generate_constraints(inp: ourlang.Module) -> List[ConstraintBase
|
||||
return [*module(ctx, inp)]
|
||||
|
||||
def constant(ctx: Context, inp: ourlang.Constant) -> Generator[ConstraintBase, None, None]:
|
||||
if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantTuple, )):
|
||||
if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct)):
|
||||
yield LiteralFitsConstraint(inp.type3, inp)
|
||||
return
|
||||
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from phasm.exceptions import StaticError
|
||||
from phasm.type3.entry import Type3Exception
|
||||
|
||||
from phasm.parser import phasm_parse
|
||||
|
||||
from ..constants import (
|
||||
ALL_INT_TYPES
|
||||
ALL_INT_TYPES, TYPE_MAP
|
||||
)
|
||||
from ..helpers import Suite
|
||||
|
||||
@ -88,6 +85,18 @@ def helper(shape1: Rectangle, shape2: Rectangle) -> i32:
|
||||
|
||||
assert 545 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_type_mismatch_arg_module_constant():
|
||||
code_py = """
|
||||
class Struct:
|
||||
param: f32
|
||||
|
||||
STRUCT: Struct = Struct(1)
|
||||
"""
|
||||
|
||||
with pytest.raises(Type3Exception, match='todo'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64'])
|
||||
def test_type_mismatch_struct_member(type_):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user