Started on struct constants
This commit is contained in:
parent
d18f1c6956
commit
7ce3f0f11c
@ -67,6 +67,12 @@ def expression(inp: ourlang.Expression) -> str:
|
|||||||
for x in inp.value
|
for x in inp.value
|
||||||
) + ', )'
|
) + ', )'
|
||||||
|
|
||||||
|
if isinstance(inp, ourlang.ConstantStruct):
|
||||||
|
return inp.struct_name + '(' + ', '.join(
|
||||||
|
expression(x)
|
||||||
|
for x in inp.value
|
||||||
|
) + ')'
|
||||||
|
|
||||||
if isinstance(inp, ourlang.VariableReference):
|
if isinstance(inp, ourlang.VariableReference):
|
||||||
return str(inp.variable.name)
|
return str(inp.variable.name)
|
||||||
|
|
||||||
|
|||||||
@ -62,6 +62,23 @@ class ConstantTuple(Constant):
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'ConstantTuple({repr(self.value)})'
|
return f'ConstantTuple({repr(self.value)})'
|
||||||
|
|
||||||
|
class ConstantStruct(Constant):
|
||||||
|
"""
|
||||||
|
A Struct constant value expression within a statement
|
||||||
|
"""
|
||||||
|
__slots__ = ('struct_name', 'value', )
|
||||||
|
|
||||||
|
struct_name: str
|
||||||
|
value: List[ConstantPrimitive]
|
||||||
|
|
||||||
|
def __init__(self, struct_name: str, value: List[ConstantPrimitive]) -> None: # FIXME: Struct of structs?
|
||||||
|
super().__init__()
|
||||||
|
self.struct_name = struct_name
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'ConstantStruct({repr(self.struct_name)}, {repr(self.value)})'
|
||||||
|
|
||||||
class VariableReference(Expression):
|
class VariableReference(Expression):
|
||||||
"""
|
"""
|
||||||
An variable reference expression within a statement
|
An variable reference expression within a statement
|
||||||
|
|||||||
@ -16,7 +16,7 @@ from .ourlang import (
|
|||||||
|
|
||||||
Expression,
|
Expression,
|
||||||
BinaryOp,
|
BinaryOp,
|
||||||
ConstantPrimitive, ConstantTuple,
|
ConstantPrimitive, ConstantTuple, ConstantStruct,
|
||||||
|
|
||||||
FunctionCall, AccessStructMember, Subscript,
|
FunctionCall, AccessStructMember, Subscript,
|
||||||
StructDefinition, StructConstructor,
|
StructDefinition, StructConstructor,
|
||||||
@ -184,9 +184,9 @@ class OurVisitor:
|
|||||||
|
|
||||||
def pre_visit_Module_AnnAssign(self, module: Module, node: ast.AnnAssign) -> ModuleConstantDef:
|
def pre_visit_Module_AnnAssign(self, module: Module, node: ast.AnnAssign) -> ModuleConstantDef:
|
||||||
if not isinstance(node.target, ast.Name):
|
if not isinstance(node.target, ast.Name):
|
||||||
_raise_static_error(node, 'Must be name')
|
_raise_static_error(node.target, 'Must be name')
|
||||||
if not isinstance(node.target.ctx, ast.Store):
|
if not isinstance(node.target.ctx, ast.Store):
|
||||||
_raise_static_error(node, 'Must be load context')
|
_raise_static_error(node.target, 'Must be store context')
|
||||||
|
|
||||||
if isinstance(node.value, ast.Constant):
|
if isinstance(node.value, ast.Constant):
|
||||||
type3 = self.visit_type(module, node.annotation)
|
type3 = self.visit_type(module, node.annotation)
|
||||||
@ -220,6 +220,45 @@ class OurVisitor:
|
|||||||
data_block,
|
data_block,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if isinstance(node.value, ast.Call):
|
||||||
|
# Struct constant
|
||||||
|
# Stored in memory like a tuple, so much of the code is the same
|
||||||
|
|
||||||
|
if not isinstance(node.value.func, ast.Name):
|
||||||
|
_raise_static_error(node.value.func, 'Must be name')
|
||||||
|
if not isinstance(node.value.func.ctx, ast.Load):
|
||||||
|
_raise_static_error(node.value.func, 'Must be load context')
|
||||||
|
|
||||||
|
if not node.value.func.id in module.struct_definitions:
|
||||||
|
_raise_static_error(node.value.func, 'Undefined struct')
|
||||||
|
|
||||||
|
if node.value.keywords:
|
||||||
|
_raise_static_error(node.value.func, 'Cannot use keywords')
|
||||||
|
|
||||||
|
if not isinstance(node.annotation, ast.Name):
|
||||||
|
_raise_static_error(node.annotation, 'Must be name')
|
||||||
|
|
||||||
|
struct_data = [
|
||||||
|
self.visit_Module_Constant(module, arg_node)
|
||||||
|
for arg_node in node.value.args
|
||||||
|
if isinstance(arg_node, ast.Constant)
|
||||||
|
]
|
||||||
|
if len(node.value.args) != len(struct_data):
|
||||||
|
_raise_static_error(node, 'Struct arguments must be constants')
|
||||||
|
|
||||||
|
# Allocate the data
|
||||||
|
data_block = ModuleDataBlock(struct_data)
|
||||||
|
module.data.blocks.append(data_block)
|
||||||
|
|
||||||
|
# Then return the constant as a pointer
|
||||||
|
return ModuleConstantDef(
|
||||||
|
node.target.id,
|
||||||
|
node.lineno,
|
||||||
|
self.visit_type(module, node.annotation),
|
||||||
|
ConstantStruct(node.value.func.id, struct_data),
|
||||||
|
data_block,
|
||||||
|
)
|
||||||
|
|
||||||
raise NotImplementedError('TODO: Broken after new typing system')
|
raise NotImplementedError('TODO: Broken after new typing system')
|
||||||
|
|
||||||
# if isinstance(exp_type, TypeTuple):
|
# if isinstance(exp_type, TypeTuple):
|
||||||
@ -549,31 +588,22 @@ class OurVisitor:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
|
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
|
||||||
del module
|
|
||||||
del function
|
|
||||||
|
|
||||||
if not isinstance(node.value, ast.Name):
|
if not isinstance(node.value, ast.Name):
|
||||||
_raise_static_error(node, 'Must reference a name')
|
_raise_static_error(node, 'Must reference a name')
|
||||||
|
|
||||||
if not isinstance(node.ctx, ast.Load):
|
if not isinstance(node.ctx, ast.Load):
|
||||||
_raise_static_error(node, 'Must be load context')
|
_raise_static_error(node, 'Must be load context')
|
||||||
|
|
||||||
if not node.value.id in our_locals:
|
varref = self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value)
|
||||||
_raise_static_error(node, f'Undefined variable {node.value.id}')
|
if not isinstance(varref, VariableReference):
|
||||||
|
_raise_static_error(node.value, 'Must refer to variable')
|
||||||
|
|
||||||
param = our_locals[node.value.id]
|
if not isinstance(varref.variable.type3, type3types.StructType3):
|
||||||
|
_raise_static_error(node.value, 'Must refer to struct')
|
||||||
node_typ = param.type3
|
|
||||||
if not isinstance(node_typ, type3types.StructType3):
|
|
||||||
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
|
|
||||||
|
|
||||||
member = node_typ.members.get(node.attr)
|
|
||||||
if member is None:
|
|
||||||
_raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
|
|
||||||
|
|
||||||
return AccessStructMember(
|
return AccessStructMember(
|
||||||
VariableReference(param),
|
varref,
|
||||||
node_typ,
|
varref.variable.type3,
|
||||||
node.attr,
|
node.attr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -203,9 +203,14 @@ class LiteralFitsConstraint(ConstraintBase):
|
|||||||
__slots__ = ('type3', 'literal', )
|
__slots__ = ('type3', 'literal', )
|
||||||
|
|
||||||
type3: types.Type3OrPlaceholder
|
type3: types.Type3OrPlaceholder
|
||||||
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple]
|
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct]
|
||||||
|
|
||||||
def __init__(self, type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple], comment: Optional[str] = None) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
type3: types.Type3OrPlaceholder,
|
||||||
|
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct],
|
||||||
|
comment: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
super().__init__(comment=comment)
|
super().__init__(comment=comment)
|
||||||
|
|
||||||
self.type3 = type3
|
self.type3 = type3
|
||||||
@ -226,7 +231,7 @@ class LiteralFitsConstraint(ConstraintBase):
|
|||||||
'f64': None,
|
'f64': None,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _check(type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple]) -> CheckResult:
|
def _check(type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct]) -> CheckResult:
|
||||||
if isinstance(type3, types.PlaceholderForType):
|
if isinstance(type3, types.PlaceholderForType):
|
||||||
if type3 not in smap:
|
if type3 not in smap:
|
||||||
return RequireTypeSubstitutes()
|
return RequireTypeSubstitutes()
|
||||||
@ -275,7 +280,23 @@ class LiteralFitsConstraint(ConstraintBase):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
raise NotImplementedError
|
if isinstance(type3, types.StructType3):
|
||||||
|
if not isinstance(literal, ourlang.ConstantStruct):
|
||||||
|
return Error('Must be struct')
|
||||||
|
|
||||||
|
assert isinstance(val, list) # type hint
|
||||||
|
|
||||||
|
if len(type3.members) != len(val):
|
||||||
|
return Error('Struct element count mismatch')
|
||||||
|
|
||||||
|
for elt_typ, elt_lit in zip(type3.members.values(), val):
|
||||||
|
res = _check(elt_typ, elt_lit)
|
||||||
|
if res is not None:
|
||||||
|
return res
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
raise NotImplementedError(type3, literal)
|
||||||
|
|
||||||
return _check(self.type3, self.literal)
|
return _check(self.type3, self.literal)
|
||||||
|
|
||||||
|
|||||||
@ -23,7 +23,7 @@ def phasm_type3_generate_constraints(inp: ourlang.Module) -> List[ConstraintBase
|
|||||||
return [*module(ctx, inp)]
|
return [*module(ctx, inp)]
|
||||||
|
|
||||||
def constant(ctx: Context, inp: ourlang.Constant) -> Generator[ConstraintBase, None, None]:
|
def constant(ctx: Context, inp: ourlang.Constant) -> Generator[ConstraintBase, None, None]:
|
||||||
if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantTuple, )):
|
if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct)):
|
||||||
yield LiteralFitsConstraint(inp.type3, inp)
|
yield LiteralFitsConstraint(inp.type3, inp)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from phasm.exceptions import StaticError
|
|
||||||
from phasm.type3.entry import Type3Exception
|
from phasm.type3.entry import Type3Exception
|
||||||
|
|
||||||
from phasm.parser import phasm_parse
|
|
||||||
|
|
||||||
from ..constants import (
|
from ..constants import (
|
||||||
ALL_INT_TYPES
|
ALL_INT_TYPES, TYPE_MAP
|
||||||
)
|
)
|
||||||
from ..helpers import Suite
|
from ..helpers import Suite
|
||||||
|
|
||||||
@ -88,6 +85,18 @@ def helper(shape1: Rectangle, shape2: Rectangle) -> i32:
|
|||||||
|
|
||||||
assert 545 == result.returned_value
|
assert 545 == result.returned_value
|
||||||
|
|
||||||
|
@pytest.mark.integration_test
|
||||||
|
def test_type_mismatch_arg_module_constant():
|
||||||
|
code_py = """
|
||||||
|
class Struct:
|
||||||
|
param: f32
|
||||||
|
|
||||||
|
STRUCT: Struct = Struct(1)
|
||||||
|
"""
|
||||||
|
|
||||||
|
with pytest.raises(Type3Exception, match='todo'):
|
||||||
|
Suite(code_py).run_code()
|
||||||
|
|
||||||
@pytest.mark.integration_test
|
@pytest.mark.integration_test
|
||||||
@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64'])
|
@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64'])
|
||||||
def test_type_mismatch_struct_member(type_):
|
def test_type_mismatch_struct_member(type_):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user