diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 04ee043..500b878 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -95,26 +95,23 @@ def expression(inp: ourlang.Expression) -> str: # return f'({args}, )' return f'{inp.function.name}({args})' - - if isinstance(inp, ourlang.AccessBytesIndex): - return f'{expression(inp.varref)}[{expression(inp.index)}]' - - if isinstance(inp, ourlang.AccessStructMember): - return f'{expression(inp.varref)}.{inp.member.name}' - - if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )): - if isinstance(inp.member, ourlang.Expression): - return f'{expression(inp.varref)}[{expression(inp.member)}]' - - return f'{expression(inp.varref)}[{inp.member.idx}]' + # + # if isinstance(inp, ourlang.AccessBytesIndex): + # return f'{expression(inp.varref)}[{expression(inp.index)}]' + # + # if isinstance(inp, ourlang.AccessStructMember): + # return f'{expression(inp.varref)}.{inp.member.name}' + # + # if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )): + # if isinstance(inp.member, ourlang.Expression): + # return f'{expression(inp.varref)}[{expression(inp.member)}]' + # + # return f'{expression(inp.varref)}[{inp.member.idx}]' if isinstance(inp, ourlang.Fold): fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' - if isinstance(inp, ourlang.ModuleConstantReference): - return inp.definition.name - raise NotImplementedError(expression, inp) def statement(inp: ourlang.Statement) -> Statements: diff --git a/phasm/compiler.py b/phasm/compiler.py index 18a548a..826f21e 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -156,8 +156,39 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: raise NotImplementedError(f'Constants with type {stp}') if isinstance(inp, ourlang.VariableReference): - wgn.add_statement('local.get', '${}'.format(inp.variable.name)) - return + if isinstance(inp.variable, ourlang.FunctionParam): + wgn.add_statement('local.get', '${}'.format(inp.variable.name)) + return + + if isinstance(inp.variable, ourlang.ModuleConstantDef): + # FIXME: Tuple / Static Array broken after new type system + # if isinstance(inp.type, typing.TypeTuple): + # assert isinstance(inp.definition.constant, ourlang.ConstantTuple) + # assert inp.definition.data_block is not None, 'Combined values are memory stored' + # assert inp.definition.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.definition.data_block.address) + # return + # + # if isinstance(inp.type, typing.TypeStaticArray): + # assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) + # assert inp.definition.data_block is not None, 'Combined values are memory stored' + # assert inp.definition.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.definition.data_block.address) + # return + + assert inp.variable.data_block is None, 'Primitives are not memory stored' + + assert inp.variable.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.variable.type_var) + if mtyp is None: + # In the future might extend this by having structs or tuples + # as members of struct or tuples + raise NotImplementedError(expression, inp, inp.type_var) + + expression(wgn, inp.variable.constant) + return + + raise NotImplementedError(expression, inp.variable) if isinstance(inp, ourlang.BinaryOp): expression(wgn, inp.left) @@ -301,34 +332,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: expression_fold(wgn, inp) return - if isinstance(inp, ourlang.ModuleConstantReference): - # FIXME: Tuple / Static Array broken after new type system - # if isinstance(inp.type, typing.TypeTuple): - # assert isinstance(inp.definition.constant, ourlang.ConstantTuple) - # assert inp.definition.data_block is not None, 'Combined values are memory stored' - # assert inp.definition.data_block.address is not None, 'Value not allocated' - # wgn.i32.const(inp.definition.data_block.address) - # return - # - # if isinstance(inp.type, typing.TypeStaticArray): - # assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) - # assert inp.definition.data_block is not None, 'Combined values are memory stored' - # assert inp.definition.data_block.address is not None, 'Value not allocated' - # wgn.i32.const(inp.definition.data_block.address) - # return - - assert inp.definition.data_block is None, 'Primitives are not memory stored' - - assert inp.type_var is not None, typing.ASSERTION_ERROR - mtyp = typing.simplify(inp.type_var) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.type_var) - - expression(wgn, inp.definition.constant) - return - raise NotImplementedError(expression, inp) def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: @@ -566,38 +569,46 @@ def module_data(inp: ourlang.ModuleData) -> bytes: data_list: List[bytes] = [] - raise NotImplementedError('Broken after new type system') - for constant in block.data: - if isinstance(constant, ourlang.ConstantUInt8): + assert constant.type_var is not None + mtyp = typing.simplify(constant.type_var) + + if mtyp == 'u8': + assert isinstance(constant.value, int) data_list.append(module_data_u8(constant.value)) continue - if isinstance(constant, ourlang.ConstantUInt32): + if mtyp == 'u32': + assert isinstance(constant.value, int) data_list.append(module_data_u32(constant.value)) continue - if isinstance(constant, ourlang.ConstantUInt64): + if mtyp == 'u64': + assert isinstance(constant.value, int) data_list.append(module_data_u64(constant.value)) continue - if isinstance(constant, ourlang.ConstantInt32): + if mtyp == 'i32': + assert isinstance(constant.value, int) data_list.append(module_data_i32(constant.value)) continue - if isinstance(constant, ourlang.ConstantInt64): + if mtyp == 'i64': + assert isinstance(constant.value, int) data_list.append(module_data_i64(constant.value)) continue - if isinstance(constant, ourlang.ConstantFloat32): + if mtyp == 'f32': + assert isinstance(constant.value, float) data_list.append(module_data_f32(constant.value)) continue - if isinstance(constant, ourlang.ConstantFloat64): + if mtyp == 'f64': + assert isinstance(constant.value, float) data_list.append(module_data_f64(constant.value)) continue - raise NotImplementedError(constant) + raise NotImplementedError(constant, mtyp) block_data = b''.join(data_list) diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 9e16e4b..9ffe2c4 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -11,10 +11,7 @@ WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) from .typing import ( - TypeBytes, - TypeTuple, TypeTupleMember, - TypeStaticArray, TypeStaticArrayMember, - TypeStruct, TypeStructMember, + TypeStruct, TypeVar, ) @@ -78,9 +75,9 @@ class VariableReference(Expression): """ __slots__ = ('variable', ) - variable: 'FunctionParam' # also possibly local + variable: Union['ModuleConstantDef', 'FunctionParam'] # also possibly local - def __init__(self, variable: 'FunctionParam') -> None: + def __init__(self, variable: Union['ModuleConstantDef', 'FunctionParam']) -> None: super().__init__() self.variable = variable @@ -131,9 +128,10 @@ class FunctionCall(Expression): self.function = function self.arguments = [] -class AccessBytesIndex(Expression): +class Subscript(Expression): """ - Access a bytes index for reading + A subscript, for example to refer to a static array or tuple + by index """ __slots__ = ('varref', 'index', ) @@ -146,53 +144,6 @@ class AccessBytesIndex(Expression): self.varref = varref self.index = index -class AccessStructMember(Expression): - """ - Access a struct member for reading of writing - """ - __slots__ = ('varref', 'member', ) - - varref: VariableReference - member: TypeStructMember - - def __init__(self, varref: VariableReference, member: TypeStructMember) -> None: - super().__init__() - - self.varref = varref - self.member = member - -class AccessTupleMember(Expression): - """ - Access a tuple member for reading of writing - """ - __slots__ = ('varref', 'member', ) - - varref: VariableReference - member: TypeTupleMember - - def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None: - super().__init__() - - self.varref = varref - self.member = member - -class AccessStaticArrayMember(Expression): - """ - Access a tuple member for reading of writing - """ - __slots__ = ('varref', 'static_array', 'member', ) - - varref: Union['ModuleConstantReference', VariableReference] - static_array: TypeStaticArray - member: Union[Expression, TypeStaticArrayMember] - - def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None: - super().__init__() - - self.varref = varref - self.static_array = static_array - self.member = member - class Fold(Expression): """ A (left or right) fold @@ -223,18 +174,6 @@ class Fold(Expression): self.base = base self.iter = iter_ -class ModuleConstantReference(Expression): - """ - An reference to a module constant expression within a statement - """ - __slots__ = ('definition', ) - - definition: 'ModuleConstantDef' - - def __init__(self, definition: 'ModuleConstantDef') -> None: - super().__init__() - self.definition = definition - class Statement: """ A statement within a function diff --git a/phasm/parser.py b/phasm/parser.py index 6fe96fc..00c4c16 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -22,15 +22,14 @@ from .ourlang import ( Function, Expression, - AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, BinaryOp, ConstantPrimitive, ConstantTuple, ConstantStaticArray, - FunctionCall, + FunctionCall, Subscript, # StructConstructor, TupleConstructor, UnaryOp, VariableReference, - Fold, ModuleConstantReference, + Fold, Statement, StatementIf, StatementPass, StatementReturn, @@ -206,6 +205,27 @@ class OurVisitor: None, ) + if isinstance(node.value, ast.Tuple): + tuple_data = [ + self.visit_Module_Constant(module, arg_node) + for arg_node in node.value.elts + if isinstance(arg_node, ast.Constant) + ] + if len(node.value.elts) != len(tuple_data): + _raise_static_error(node, 'Tuple arguments must be constants') + + # Allocate the data + data_block = ModuleDataBlock(tuple_data) + module.data.blocks.append(data_block) + + # Then return the constant as a pointer + return ModuleConstantDef( + node.target.id, + node.lineno, + ConstantTuple(tuple_data), + data_block, + ) + raise NotImplementedError('TODO: Broken after new typing system') # if isinstance(exp_type, TypeTuple): @@ -416,7 +436,7 @@ class OurVisitor: if node.id in module.constant_defs: cdef = module.constant_defs[node.id] - return ModuleConstantReference(cdef) + return VariableReference(cdef) _raise_static_error(node, f'Undefined variable {node.id}') @@ -454,13 +474,13 @@ class OurVisitor: if not isinstance(node.func.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.func.id in module.structs: - raise NotImplementedError('TODO: Broken after new type system') + # if node.func.id in module.structs: + # raise NotImplementedError('TODO: Broken after new type system') # struct = module.structs[node.func.id] # struct_constructor = StructConstructor(struct) # # func = module.functions[struct_constructor.name] - elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') @@ -533,61 +553,59 @@ class OurVisitor: def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: raise NotImplementedError('Broken after new type system') - del module - del function - - if not isinstance(node.value, ast.Name): - _raise_static_error(node, 'Must reference a name') - - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if not node.value.id in our_locals: - _raise_static_error(node, f'Undefined variable {node.value.id}') - - param = our_locals[node.value.id] - - node_typ = param.type - if not isinstance(node_typ, TypeStruct): - _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') - - member = node_typ.get_member(node.attr) - if member is None: - _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') - - return AccessStructMember( - VariableReference(param), - member, - ) - - def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: - raise NotImplementedError('TODO: Broken after new type system') - + # del module + # del function + # # if not isinstance(node.value, ast.Name): # _raise_static_error(node, 'Must reference a name') # - # if not isinstance(node.slice, ast.Index): - # _raise_static_error(node, 'Must subscript using an index') - # # if not isinstance(node.ctx, ast.Load): # _raise_static_error(node, 'Must be load context') # - # varref: Union[ModuleConstantReference, VariableReference] - # if node.value.id in our_locals: - # param = our_locals[node.value.id] - # node_typ = param.type - # varref = VariableReference(param) - # elif node.value.id in module.constant_defs: - # constant_def = module.constant_defs[node.value.id] - # node_typ = constant_def.type - # varref = ModuleConstantReference(constant_def) - # else: + # if not node.value.id in our_locals: # _raise_static_error(node, f'Undefined variable {node.value.id}') # - # slice_expr = self.visit_Module_FunctionDef_expr( - # module, function, our_locals, node.slice.value, - # ) + # param = our_locals[node.value.id] # + # node_typ = param.type + # if not isinstance(node_typ, TypeStruct): + # _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') + # + # member = node_typ.get_member(node.attr) + # if member is None: + # _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') + # + # return AccessStructMember( + # VariableReference(param), + # member, + # ) + + def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: + if not isinstance(node.value, ast.Name): + _raise_static_error(node, 'Must reference a name') + + if not isinstance(node.slice, ast.Index): + _raise_static_error(node, 'Must subscript using an index') + + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + varref: VariableReference + if node.value.id in our_locals: + param = our_locals[node.value.id] + varref = VariableReference(param) + elif node.value.id in module.constant_defs: + constant_def = module.constant_defs[node.value.id] + varref = VariableReference(constant_def) + else: + _raise_static_error(node, f'Undefined variable {node.value.id}') + + slice_expr = self.visit_Module_FunctionDef_expr( + module, function, our_locals, node.slice.value, + ) + + return Subscript(varref, slice_expr) + # if isinstance(node_typ, TypeBytes): # if isinstance(varref, ModuleConstantReference): # raise NotImplementedError(f'{node} from module constant') diff --git a/phasm/typer.py b/phasm/typer.py index 2b17f63..e57e043 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -62,7 +62,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': return constant(ctx, inp) if isinstance(inp, ourlang.VariableReference): - assert inp.variable.type_var is not None, inp + assert inp.variable.type_var is not None return inp.variable.type_var if isinstance(inp, ourlang.UnaryOp): @@ -112,13 +112,6 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': return inp.function.returns_type_var - if isinstance(inp, ourlang.ModuleConstantReference): - assert inp.definition.type_var is not None - - inp.type_var = inp.definition.type_var - - return inp.definition.type_var - raise NotImplementedError(expression, inp) def function(ctx: Context, inp: ourlang.Function) -> None: diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py index ced9277..68bfdd4 100644 --- a/tests/integration/test_lang/test_static_array.py +++ b/tests/integration/test_lang/test_static_array.py @@ -7,21 +7,18 @@ from ..helpers import Suite @pytest.mark.integration_test @pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) -def test_static_array_module_constant(type_): +def test_module_constant(type_): code_py = f""" CONSTANT: {type_}[3] = (24, 57, 80, ) @exported def testEntry() -> {type_}: - return helper(CONSTANT) - -def helper(array: {type_}[3]) -> {type_}: - return array[0] + array[1] + array[2] + return CONSTANT[0] """ result = Suite(code_py).run_code() - assert 161 == result.returned_value + assert 24 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test @@ -43,6 +40,25 @@ def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}: assert 161 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) +def test_function_call(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24, 57, 80, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT) + +def helper(array: {type_}[3]) -> {type_}: + return array[0] + array[1] + array[2] +""" + + result = Suite(code_py).run_code() + + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + @pytest.mark.integration_test def test_static_array_constant_too_few_values(): code_py = """