Compare commits

...

2 Commits

Author SHA1 Message Date
Johan B.W. de Vries
7ce3f0f11c Started on struct constants 2022-12-18 15:05:41 +01:00
Johan B.W. de Vries
d18f1c6956 Re-implemented allocation calculations 2022-12-18 14:31:17 +01:00
7 changed files with 136 additions and 30 deletions

View File

@ -67,6 +67,12 @@ def expression(inp: ourlang.Expression) -> str:
for x in inp.value
) + ', )'
if isinstance(inp, ourlang.ConstantStruct):
return inp.struct_name + '(' + ', '.join(
expression(x)
for x in inp.value
) + ')'
if isinstance(inp, ourlang.VariableReference):
return str(inp.variable.name)

View File

@ -771,7 +771,30 @@ def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstruc
wgn.local.get(tmp_var)
def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3]) -> int:
return 0 # FIXME: Stub
if typ is type3types.u8:
return 1
if typ is type3types.u32 or typ is type3types.i32 or typ is type3types.f32:
return 4
if typ is type3types.u64 or typ is type3types.i64 or typ is type3types.f64:
return 8
if isinstance(typ, type3types.StructType3):
return sum(
_calculate_alloc_size(x)
for x in typ.members.values()
)
raise NotImplementedError(_calculate_alloc_size, typ)
def _calculate_member_offset(struct_type3: type3types.StructType3, member: str) -> int:
return 0 # FIXME: Stub
result = 0
for mem, memtyp in struct_type3.members.items():
if member == mem:
return result
result += _calculate_alloc_size(memtyp)
raise Exception(f'{member} not in {struct_type3}')

View File

@ -62,6 +62,23 @@ class ConstantTuple(Constant):
def __repr__(self) -> str:
return f'ConstantTuple({repr(self.value)})'
class ConstantStruct(Constant):
"""
A Struct constant value expression within a statement
"""
__slots__ = ('struct_name', 'value', )
struct_name: str
value: List[ConstantPrimitive]
def __init__(self, struct_name: str, value: List[ConstantPrimitive]) -> None: # FIXME: Struct of structs?
super().__init__()
self.struct_name = struct_name
self.value = value
def __repr__(self) -> str:
return f'ConstantStruct({repr(self.struct_name)}, {repr(self.value)})'
class VariableReference(Expression):
"""
An variable reference expression within a statement

View File

@ -16,7 +16,7 @@ from .ourlang import (
Expression,
BinaryOp,
ConstantPrimitive, ConstantTuple,
ConstantPrimitive, ConstantTuple, ConstantStruct,
FunctionCall, AccessStructMember, Subscript,
StructDefinition, StructConstructor,
@ -184,9 +184,9 @@ class OurVisitor:
def pre_visit_Module_AnnAssign(self, module: Module, node: ast.AnnAssign) -> ModuleConstantDef:
if not isinstance(node.target, ast.Name):
_raise_static_error(node, 'Must be name')
_raise_static_error(node.target, 'Must be name')
if not isinstance(node.target.ctx, ast.Store):
_raise_static_error(node, 'Must be load context')
_raise_static_error(node.target, 'Must be store context')
if isinstance(node.value, ast.Constant):
type3 = self.visit_type(module, node.annotation)
@ -220,6 +220,45 @@ class OurVisitor:
data_block,
)
if isinstance(node.value, ast.Call):
# Struct constant
# Stored in memory like a tuple, so much of the code is the same
if not isinstance(node.value.func, ast.Name):
_raise_static_error(node.value.func, 'Must be name')
if not isinstance(node.value.func.ctx, ast.Load):
_raise_static_error(node.value.func, 'Must be load context')
if not node.value.func.id in module.struct_definitions:
_raise_static_error(node.value.func, 'Undefined struct')
if node.value.keywords:
_raise_static_error(node.value.func, 'Cannot use keywords')
if not isinstance(node.annotation, ast.Name):
_raise_static_error(node.annotation, 'Must be name')
struct_data = [
self.visit_Module_Constant(module, arg_node)
for arg_node in node.value.args
if isinstance(arg_node, ast.Constant)
]
if len(node.value.args) != len(struct_data):
_raise_static_error(node, 'Struct arguments must be constants')
# Allocate the data
data_block = ModuleDataBlock(struct_data)
module.data.blocks.append(data_block)
# Then return the constant as a pointer
return ModuleConstantDef(
node.target.id,
node.lineno,
self.visit_type(module, node.annotation),
ConstantStruct(node.value.func.id, struct_data),
data_block,
)
raise NotImplementedError('TODO: Broken after new typing system')
# if isinstance(exp_type, TypeTuple):
@ -549,31 +588,22 @@ class OurVisitor:
return result
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
del module
del function
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not node.value.id in our_locals:
_raise_static_error(node, f'Undefined variable {node.value.id}')
varref = self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value)
if not isinstance(varref, VariableReference):
_raise_static_error(node.value, 'Must refer to variable')
param = our_locals[node.value.id]
node_typ = param.type3
if not isinstance(node_typ, type3types.StructType3):
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
member = node_typ.members.get(node.attr)
if member is None:
_raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
if not isinstance(varref.variable.type3, type3types.StructType3):
_raise_static_error(node.value, 'Must refer to struct')
return AccessStructMember(
VariableReference(param),
node_typ,
varref,
varref.variable.type3,
node.attr,
)

View File

@ -203,9 +203,14 @@ class LiteralFitsConstraint(ConstraintBase):
__slots__ = ('type3', 'literal', )
type3: types.Type3OrPlaceholder
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple]
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct]
def __init__(self, type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple], comment: Optional[str] = None) -> None:
def __init__(
self,
type3: types.Type3OrPlaceholder,
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct],
comment: Optional[str] = None,
) -> None:
super().__init__(comment=comment)
self.type3 = type3
@ -226,7 +231,7 @@ class LiteralFitsConstraint(ConstraintBase):
'f64': None,
}
def _check(type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple]) -> CheckResult:
def _check(type3: types.Type3OrPlaceholder, literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct]) -> CheckResult:
if isinstance(type3, types.PlaceholderForType):
if type3 not in smap:
return RequireTypeSubstitutes()
@ -275,7 +280,23 @@ class LiteralFitsConstraint(ConstraintBase):
return None
raise NotImplementedError
if isinstance(type3, types.StructType3):
if not isinstance(literal, ourlang.ConstantStruct):
return Error('Must be struct')
assert isinstance(val, list) # type hint
if len(type3.members) != len(val):
return Error('Struct element count mismatch')
for elt_typ, elt_lit in zip(type3.members.values(), val):
res = _check(elt_typ, elt_lit)
if res is not None:
return res
return None
raise NotImplementedError(type3, literal)
return _check(self.type3, self.literal)

View File

@ -23,7 +23,7 @@ def phasm_type3_generate_constraints(inp: ourlang.Module) -> List[ConstraintBase
return [*module(ctx, inp)]
def constant(ctx: Context, inp: ourlang.Constant) -> Generator[ConstraintBase, None, None]:
if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantTuple, )):
if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantTuple, ourlang.ConstantStruct)):
yield LiteralFitsConstraint(inp.type3, inp)
return

View File

@ -1,12 +1,9 @@
import pytest
from phasm.exceptions import StaticError
from phasm.type3.entry import Type3Exception
from phasm.parser import phasm_parse
from ..constants import (
ALL_INT_TYPES
ALL_INT_TYPES, TYPE_MAP
)
from ..helpers import Suite
@ -88,6 +85,18 @@ def helper(shape1: Rectangle, shape2: Rectangle) -> i32:
assert 545 == result.returned_value
@pytest.mark.integration_test
def test_type_mismatch_arg_module_constant():
code_py = """
class Struct:
param: f32
STRUCT: Struct = Struct(1)
"""
with pytest.raises(Type3Exception, match='todo'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64'])
def test_type_mismatch_struct_member(type_):