diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 4b38e32..1af79ca 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -3,7 +3,7 @@ This module generates source code based on the parsed AST It's intented to be a "any color, as long as it's black" kind of renderer """ -from typing import Generator +from typing import Generator, Optional from . import ourlang from . import typing @@ -16,63 +16,26 @@ def phasm_render(inp: ourlang.Module) -> str: Statements = Generator[str, None, None] -def type_(inp: typing.TypeBase) -> str: +def type_var(inp: Optional[typing.TypeVar]) -> str: """ - Render: Type (name) + Render: type's name """ - if isinstance(inp, typing.TypeNone): - return 'None' + assert inp is not None, typing.ASSERTION_ERROR - if isinstance(inp, typing.TypeBool): - return 'bool' + mtyp = typing.simplify(inp) + if mtyp is None: + raise NotImplementedError(f'Rendering type {inp}') - if isinstance(inp, typing.TypeUInt8): - return 'u8' - - if isinstance(inp, typing.TypeUInt32): - return 'u32' - - if isinstance(inp, typing.TypeUInt64): - return 'u64' - - if isinstance(inp, typing.TypeInt32): - return 'i32' - - if isinstance(inp, typing.TypeInt64): - return 'i64' - - if isinstance(inp, typing.TypeFloat32): - return 'f32' - - if isinstance(inp, typing.TypeFloat64): - return 'f64' - - if isinstance(inp, typing.TypeBytes): - return 'bytes' - - if isinstance(inp, typing.TypeTuple): - mems = ', '.join( - type_(x.type) - for x in inp.members - ) - - return f'({mems}, )' - - if isinstance(inp, typing.TypeStaticArray): - return f'{type_(inp.member_type)}[{len(inp.members)}]' - - if isinstance(inp, typing.TypeStruct): - return inp.name - - raise NotImplementedError(type_, inp) + return mtyp def struct_definition(inp: typing.TypeStruct) -> str: """ Render: TypeStruct's definition """ result = f'class {inp.name}:\n' - for mem in inp.members: - result += f' {mem.name}: {type_(mem.type)}\n' + for mem in inp.members: # TODO: Broken after new type system + raise NotImplementedError('Structs broken after new type system') + # result += f' {mem.name}: {type_(mem.type)}\n' return result @@ -80,40 +43,38 @@ def constant_definition(inp: ourlang.ModuleConstantDef) -> str: """ Render: Module Constant's definition """ - return f'{inp.name}: {type_(inp.type)} = {expression(inp.constant)}\n' + return f'{inp.name}: {type_var(inp.type_var)} = {expression(inp.constant)}\n' def expression(inp: ourlang.Expression) -> str: """ Render: A Phasm expression """ - if isinstance(inp, ( - ourlang.ConstantUInt8, ourlang.ConstantUInt32, ourlang.ConstantUInt64, - ourlang.ConstantInt32, ourlang.ConstantInt64, - )): - return str(inp.value) - - if isinstance(inp, (ourlang.ConstantFloat32, ourlang.ConstantFloat64, )): - # These might not round trip if the original constant + if isinstance(inp, ourlang.ConstantPrimitive): + # Floats might not round trip if the original constant # could not fit in the given float type return str(inp.value) - if isinstance(inp, (ourlang.ConstantTuple, ourlang.ConstantStaticArray, )): + if isinstance(inp, ourlang.ConstantTuple): return '(' + ', '.join( expression(x) for x in inp.value ) + ', )' if isinstance(inp, ourlang.VariableReference): - return str(inp.name) + return str(inp.variable.name) if isinstance(inp, ourlang.UnaryOp): if ( - inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS - or inp.operator in ourlang.WEBASSEMBLY_BUILDIN_BYTES_OPS): + inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS + or inp.operator in ourlang.WEBASSEMBLY_BUILTIN_BYTES_OPS): return f'{inp.operator}({expression(inp.right)})' if inp.operator == 'cast': - return f'{type_(inp.type)}({expression(inp.right)})' + mtyp = type_var(inp.type_var) + if mtyp is None: + raise NotImplementedError(f'Casting to type {inp.type_var}') + + return f'{mtyp}({expression(inp.right)})' return f'{inp.operator}{expression(inp.right)}' @@ -126,33 +87,29 @@ def expression(inp: ourlang.Expression) -> str: for arg in inp.arguments ) - if isinstance(inp.function, ourlang.StructConstructor): - return f'{inp.function.struct.name}({args})' - - if isinstance(inp.function, ourlang.TupleConstructor): - return f'({args}, )' + # TODO: Broken after new type system + # if isinstance(inp.function, ourlang.StructConstructor): + # return f'{inp.function.struct.name}({args})' + # + # if isinstance(inp.function, ourlang.TupleConstructor): + # return f'({args}, )' return f'{inp.function.name}({args})' - if isinstance(inp, ourlang.AccessBytesIndex): - return f'{expression(inp.varref)}[{expression(inp.index)}]' + if isinstance(inp, ourlang.Subscript): + varref = expression(inp.varref) + index = expression(inp.index) - if isinstance(inp, ourlang.AccessStructMember): - return f'{expression(inp.varref)}.{inp.member.name}' + return f'{varref}[{index}]' - if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )): - if isinstance(inp.member, ourlang.Expression): - return f'{expression(inp.varref)}[{expression(inp.member)}]' - - return f'{expression(inp.varref)}[{inp.member.idx}]' + # TODO: Broken after new type system + # if isinstance(inp, ourlang.AccessStructMember): + # return f'{expression(inp.varref)}.{inp.member.name}' if isinstance(inp, ourlang.Fold): fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' - if isinstance(inp, ourlang.ModuleConstantReference): - return inp.definition.name - raise NotImplementedError(expression, inp) def statement(inp: ourlang.Statement) -> Statements: @@ -193,11 +150,11 @@ def function(inp: ourlang.Function) -> str: result += '@imported\n' args = ', '.join( - f'{x}: {type_(y)}' - for x, y in inp.posonlyargs + f'{p.name}: {type_var(p.type_var)}' + for p in inp.posonlyargs ) - result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' + result += f'def {inp.name}({args}) -> {type_var(inp.returns_type_var)}:\n' if inp.imported: result += ' pass\n' @@ -227,7 +184,7 @@ def module(inp: ourlang.Module) -> str: for func in inp.functions.values(): if func.lineno < 0: - # Buildin (-2) or auto generated (-1) + # Builtin (-2) or auto generated (-1) continue if result: diff --git a/phasm/compiler.py b/phasm/compiler.py index 4ba5d5a..a83c2cf 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -1,6 +1,8 @@ """ This module contains the code to convert parsed Ourlang into WebAssembly code """ +from typing import List, Optional + import struct from . import codestyle @@ -12,19 +14,6 @@ from .stdlib import alloc as stdlib_alloc from .stdlib import types as stdlib_types from .wasmgenerator import Generator as WasmGenerator -LOAD_STORE_TYPE_MAP = { - typing.TypeUInt8: 'i32', - typing.TypeUInt32: 'i32', - typing.TypeUInt64: 'i64', - typing.TypeInt32: 'i32', - typing.TypeInt64: 'i64', - typing.TypeFloat32: 'f32', - typing.TypeFloat64: 'f64', -} -""" -When generating code, we sometimes need to load or store simple values -""" - def phasm_compile(inp: ourlang.Module) -> wasm.Module: """ Public method for compiling a parsed Phasm module into @@ -32,42 +21,57 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module: """ return module(inp) -def type_(inp: typing.TypeBase) -> wasm.WasmType: +def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType: """ Compile: type - """ - if isinstance(inp, typing.TypeNone): - return wasm.WasmTypeNone() - if isinstance(inp, typing.TypeUInt8): + Types are used for example in WebAssembly function parameters + and return types. + """ + assert inp is not None, typing.ASSERTION_ERROR + + mtyp = typing.simplify(inp) + + if mtyp == 'u8': # WebAssembly has only support for 32 and 64 bits # So we need to store more memory per byte return wasm.WasmTypeInt32() - if isinstance(inp, typing.TypeUInt32): + if mtyp == 'u32': return wasm.WasmTypeInt32() - if isinstance(inp, typing.TypeUInt64): + if mtyp == 'u64': return wasm.WasmTypeInt64() - if isinstance(inp, typing.TypeInt32): + if mtyp == 'i32': return wasm.WasmTypeInt32() - if isinstance(inp, typing.TypeInt64): + if mtyp == 'i64': return wasm.WasmTypeInt64() - if isinstance(inp, typing.TypeFloat32): + if mtyp == 'f32': return wasm.WasmTypeFloat32() - if isinstance(inp, typing.TypeFloat64): + if mtyp == 'f64': return wasm.WasmTypeFloat64() - if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): - # Structs and tuples are passed as pointer + assert inp is not None, typing.ASSERTION_ERROR + tc_prim = inp.get_constraint(typing.TypeConstraintPrimitive) + if tc_prim is None: + raise NotImplementedError(type_var, inp) + + if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + # StaticArray, Tuples and Structs are passed as pointer # And pointers are i32 return wasm.WasmTypeInt32() - raise NotImplementedError(type_, inp) + # TODO: Broken after new type system + # if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): + # # Structs and tuples are passed as pointer + # # And pointers are i32 + # return wasm.WasmTypeInt32() + + raise NotImplementedError(inp, mtyp) # Operators that work for i32, i64, f32, f64 OPERATOR_MAP = { @@ -81,8 +85,6 @@ U8_OPERATOR_MAP = { # Under the hood, this is an i32 # Implementing Right Shift XOR, OR, AND is fine since the 3 remaining # bytes stay zero after this operation - # Since it's unsigned an unsigned value, Logical or Arithmetic shift right - # are the same operation '>>': 'shr_u', '^': 'xor', '|': 'or', @@ -131,109 +133,159 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: """ Compile: Any expression """ - if isinstance(inp, ourlang.ConstantUInt8): - wgn.i32.const(inp.value) - return + if isinstance(inp, ourlang.ConstantPrimitive): + assert inp.type_var is not None, typing.ASSERTION_ERROR - if isinstance(inp, ourlang.ConstantUInt32): - wgn.i32.const(inp.value) - return + stp = typing.simplify(inp.type_var) + if stp is None: + raise NotImplementedError(f'Constants with type {inp.type_var}') - if isinstance(inp, ourlang.ConstantUInt64): - wgn.i64.const(inp.value) - return + if stp == 'u8': + # No native u8 type - treat as i32, with caution + assert isinstance(inp.value, int) + wgn.i32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantInt32): - wgn.i32.const(inp.value) - return + if stp in ('i32', 'u32'): + assert isinstance(inp.value, int) + wgn.i32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantInt64): - wgn.i64.const(inp.value) - return + if stp in ('i64', 'u64'): + assert isinstance(inp.value, int) + wgn.i64.const(inp.value) + return - if isinstance(inp, ourlang.ConstantFloat32): - wgn.f32.const(inp.value) - return + if stp == 'f32': + assert isinstance(inp.value, float) + wgn.f32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantFloat64): - wgn.f64.const(inp.value) - return + if stp == 'f64': + assert isinstance(inp.value, float) + wgn.f64.const(inp.value) + return + + raise NotImplementedError(f'Constants with type {stp}') if isinstance(inp, ourlang.VariableReference): - wgn.add_statement('local.get', '${}'.format(inp.name)) - return + if isinstance(inp.variable, ourlang.FunctionParam): + wgn.add_statement('local.get', '${}'.format(inp.variable.name)) + return + + if isinstance(inp.variable, ourlang.ModuleConstantDef): + assert inp.variable.type_var is not None, typing.ASSERTION_ERROR + tc_prim = inp.variable.type_var.get_constraint(typing.TypeConstraintPrimitive) + if tc_prim is None: + raise NotImplementedError(expression, inp, inp.variable.type_var) + + # TODO: Broken after new type system + # if isinstance(inp.type, typing.TypeTuple): + # assert isinstance(inp.definition.constant, ourlang.ConstantTuple) + # assert inp.definition.data_block is not None, 'Combined values are memory stored' + # assert inp.definition.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.definition.data_block.address) + # return + # + + if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + assert inp.variable.data_block is not None, 'Combined values are memory stored' + assert inp.variable.data_block.address is not None, 'Value not allocated' + wgn.i32.const(inp.variable.data_block.address) + return + + assert inp.variable.data_block is None, 'Primitives are not memory stored' + + assert inp.variable.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.variable.type_var) + if mtyp is None: + # In the future might extend this by having structs or tuples + # as members of struct or tuples + raise NotImplementedError(expression, inp, inp.type_var) + + expression(wgn, inp.variable.constant) + return + + raise NotImplementedError(expression, inp.variable) if isinstance(inp, ourlang.BinaryOp): expression(wgn, inp.left) expression(wgn, inp.right) - if isinstance(inp.type, typing.TypeUInt8): + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) + + if mtyp == 'u8': if operator := U8_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeUInt32): + if mtyp == 'u32': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := U32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeUInt64): + if mtyp == 'u64': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := U64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if isinstance(inp.type, typing.TypeInt32): + if mtyp == 'i32': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := I32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeInt64): + if mtyp == 'i64': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := I64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if isinstance(inp.type, typing.TypeFloat32): + if mtyp == 'f32': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f32.{operator}') return - if isinstance(inp.type, typing.TypeFloat64): + if mtyp == 'f64': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f64.{operator}') return - raise NotImplementedError(expression, inp.type, inp.operator) + raise NotImplementedError(expression, inp.type_var, inp.operator) if isinstance(inp, ourlang.UnaryOp): expression(wgn, inp.right) - if isinstance(inp.type, typing.TypeFloat32): - if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) + + if mtyp == 'f32': + if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f32.{inp.operator}') return - if isinstance(inp.type, typing.TypeFloat64): - if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if mtyp == 'f64': + if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f64.{inp.operator}') return - if isinstance(inp.type, typing.TypeInt32): - if inp.operator == 'len': - if isinstance(inp.right.type, typing.TypeBytes): - wgn.i32.load() - return + # TODO: Broken after new type system + # if isinstance(inp.type, typing.TypeInt32): + # if inp.operator == 'len': + # if isinstance(inp.right.type, typing.TypeBytes): + # wgn.i32.load() + # return - if inp.operator == 'cast': - if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): - # Nothing to do, you can use an u8 value as a u32 no problem - return + # if inp.operator == 'cast': + # if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): + # # Nothing to do, you can use an u8 value as a u32 no problem + # return - raise NotImplementedError(expression, inp.type, inp.operator) + raise NotImplementedError(expression, inp.type_var, inp.operator) if isinstance(inp, ourlang.FunctionCall): for arg in inp.arguments: @@ -242,99 +294,119 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.add_statement('call', '${}'.format(inp.function.name)) return - if isinstance(inp, ourlang.AccessBytesIndex): - if not isinstance(inp.type, typing.TypeUInt8): - raise NotImplementedError(inp, inp.type) + if isinstance(inp, ourlang.Subscript): + assert inp.varref.type_var is not None, typing.ASSERTION_ERROR + tc_prim = inp.varref.type_var.get_constraint(typing.TypeConstraintPrimitive) + if tc_prim is None: + raise NotImplementedError(expression, inp, inp.varref.type_var) - expression(wgn, inp.varref) - expression(wgn, inp.index) - wgn.call(stdlib_types.__subscript_bytes__) - return + if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + if not isinstance(inp.index, ourlang.ConstantPrimitive): + raise NotImplementedError(expression, inp, inp.index) + if not isinstance(inp.index.value, int): + raise NotImplementedError(expression, inp, inp.index.value) - if isinstance(inp, ourlang.AccessStructMember): - mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.member) + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) + if mtyp is None: + raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp) - expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) - return + if mtyp == 'u8': + # u8 operations are done using i32, since WASM does not have u8 operations + mtyp = 'i32' + elif mtyp == 'u32': + # u32 operations are done using i32, using _u operations + mtyp = 'i32' + elif mtyp == 'u64': + # u64 operations are done using i64, using _u operations + mtyp = 'i64' - if isinstance(inp, ourlang.AccessTupleMember): - mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.member) + tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript) + if tc_subs is None: + raise NotImplementedError(expression, inp, inp.varref.type_var) - expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) - return + assert 0 < len(tc_subs.members) + tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth) + if tc_bits is None or len(tc_bits.oneof) > 1: + raise NotImplementedError(expression, inp, inp.varref.type_var) - if isinstance(inp, ourlang.AccessStaticArrayMember): - mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of static arrays - raise NotImplementedError(expression, inp, inp.member) + bitwidth = next(iter(tc_bits.oneof)) + if bitwidth % 8 != 0: + raise NotImplementedError(expression, inp, inp.varref.type_var) - if isinstance(inp.member, typing.TypeStaticArrayMember): expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) + wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value)) return - expression(wgn, inp.varref) - expression(wgn, inp.member) - wgn.i32.const(inp.static_array.member_type.alloc_size()) - wgn.i32.mul() - wgn.i32.add() - wgn.add_statement(f'{mtyp}.load') - return + raise NotImplementedError(expression, inp, inp.varref.type_var) + + + # TODO: Broken after new type system + # if isinstance(inp, ourlang.AccessBytesIndex): + # if not isinstance(inp.type, typing.TypeUInt8): + # raise NotImplementedError(inp, inp.type) + # + # expression(wgn, inp.varref) + # expression(wgn, inp.index) + # wgn.call(stdlib_types.__subscript_bytes__) + # return + + # if isinstance(inp, ourlang.AccessStructMember): + # mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) + # if mtyp is None: + # # In the future might extend this by having structs or tuples + # # as members of struct or tuples + # raise NotImplementedError(expression, inp, inp.member) + # + # expression(wgn, inp.varref) + # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) + # return + # + # if isinstance(inp, ourlang.AccessTupleMember): + # mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) + # if mtyp is None: + # # In the future might extend this by having structs or tuples + # # as members of struct or tuples + # raise NotImplementedError(expression, inp, inp.member) + # + # expression(wgn, inp.varref) + # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) + # return + # + # if isinstance(inp, ourlang.AccessStaticArrayMember): + # mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__) + # if mtyp is None: + # # In the future might extend this by having structs or tuples + # # as members of static arrays + # raise NotImplementedError(expression, inp, inp.member) + # + # expression(wgn, inp.varref) + # expression(wgn, inp.member) + # wgn.i32.const(inp.static_array.member_type.alloc_size()) + # wgn.i32.mul() + # wgn.i32.add() + # wgn.add_statement(f'{mtyp}.load') + # return if isinstance(inp, ourlang.Fold): expression_fold(wgn, inp) return - if isinstance(inp, ourlang.ModuleConstantReference): - if isinstance(inp.type, typing.TypeTuple): - assert isinstance(inp.definition.constant, ourlang.ConstantTuple) - assert inp.definition.data_block is not None, 'Combined values are memory stored' - assert inp.definition.data_block.address is not None, 'Value not allocated' - wgn.i32.const(inp.definition.data_block.address) - return - - if isinstance(inp.type, typing.TypeStaticArray): - assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) - assert inp.definition.data_block is not None, 'Combined values are memory stored' - assert inp.definition.data_block.address is not None, 'Value not allocated' - wgn.i32.const(inp.definition.data_block.address) - return - - assert inp.definition.data_block is None, 'Primitives are not memory stored' - - mtyp = LOAD_STORE_TYPE_MAP.get(inp.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.type) - - expression(wgn, inp.definition.constant) - return - raise NotImplementedError(expression, inp) def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: """ Compile: Fold expression """ - mtyp = LOAD_STORE_TYPE_MAP.get(inp.base.type.__class__) + assert inp.base.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.base.type_var) if mtyp is None: # In the future might extend this by having structs or tuples # as members of struct or tuples raise NotImplementedError(expression, inp, inp.base) + raise NotImplementedError('TODO: Broken after new type system') + if inp.iter.type.__class__.__name__ != 'TypeBytes': raise NotImplementedError(expression, inp, inp.iter.type) @@ -450,7 +522,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param: """ Compile: function argument """ - return (inp[0], type_(inp[1]), ) + return (inp.name, type_var(inp.type_var), ) def import_(inp: ourlang.Function) -> wasm.Import: """ @@ -466,7 +538,7 @@ def import_(inp: ourlang.Function) -> wasm.Import: function_argument(x) for x in inp.posonlyargs ], - type_(inp.returns) + type_var(inp.returns_type_var) ) def function(inp: ourlang.Function) -> wasm.Function: @@ -477,10 +549,10 @@ def function(inp: ourlang.Function) -> wasm.Function: wgn = WasmGenerator() - if isinstance(inp, ourlang.TupleConstructor): - _generate_tuple_constructor(wgn, inp) - elif isinstance(inp, ourlang.StructConstructor): - _generate_struct_constructor(wgn, inp) + if False: # TODO: isinstance(inp, ourlang.TupleConstructor): + pass # _generate_tuple_constructor(wgn, inp) + elif False: # TODO: isinstance(inp, ourlang.StructConstructor): + pass # _generate_struct_constructor(wgn, inp) else: for stat in inp.statements: statement(wgn, stat) @@ -496,7 +568,7 @@ def function(inp: ourlang.Function) -> wasm.Function: (k, v.wasm_type(), ) for k, v in wgn.locals.items() ], - type_(inp.returns), + type_var(inp.returns_type_var), wgn.statements ) @@ -555,38 +627,48 @@ def module_data(inp: ourlang.ModuleData) -> bytes: for block in inp.blocks: block.address = unalloc_ptr + 4 # 4 bytes for allocator header - data_list = [] + data_list: List[bytes] = [] for constant in block.data: - if isinstance(constant, ourlang.ConstantUInt8): + assert constant.type_var is not None + mtyp = typing.simplify(constant.type_var) + + if mtyp == 'u8': + assert isinstance(constant.value, int) data_list.append(module_data_u8(constant.value)) continue - if isinstance(constant, ourlang.ConstantUInt32): + if mtyp == 'u32': + assert isinstance(constant.value, int) data_list.append(module_data_u32(constant.value)) continue - if isinstance(constant, ourlang.ConstantUInt64): + if mtyp == 'u64': + assert isinstance(constant.value, int) data_list.append(module_data_u64(constant.value)) continue - if isinstance(constant, ourlang.ConstantInt32): + if mtyp == 'i32': + assert isinstance(constant.value, int) data_list.append(module_data_i32(constant.value)) continue - if isinstance(constant, ourlang.ConstantInt64): + if mtyp == 'i64': + assert isinstance(constant.value, int) data_list.append(module_data_i64(constant.value)) continue - if isinstance(constant, ourlang.ConstantFloat32): + if mtyp == 'f32': + assert isinstance(constant.value, float) data_list.append(module_data_f32(constant.value)) continue - if isinstance(constant, ourlang.ConstantFloat64): + if mtyp == 'f64': + assert isinstance(constant.value, float) data_list.append(module_data_f64(constant.value)) continue - raise NotImplementedError(constant) + raise NotImplementedError(constant, constant.type_var, mtyp) block_data = b''.join(data_list) @@ -636,48 +718,49 @@ def module(inp: ourlang.Module) -> wasm.Module: return result -def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None: - tmp_var = wgn.temp_var_i32('tuple_adr') - - # Allocated the required amounts of bytes in memory - wgn.i32.const(inp.tuple.alloc_size()) - wgn.call(stdlib_alloc.__alloc__) - wgn.local.set(tmp_var) - - # Store each member individually - for member in inp.tuple.members: - mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, member) - - wgn.local.get(tmp_var) - wgn.add_statement('local.get', f'$arg{member.idx}') - wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) - - # Return the allocated address - wgn.local.get(tmp_var) - -def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: - tmp_var = wgn.temp_var_i32('struct_adr') - - # Allocated the required amounts of bytes in memory - wgn.i32.const(inp.struct.alloc_size()) - wgn.call(stdlib_alloc.__alloc__) - wgn.local.set(tmp_var) - - # Store each member individually - for member in inp.struct.members: - mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, member) - - wgn.local.get(tmp_var) - wgn.add_statement('local.get', f'${member.name}') - wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) - - # Return the allocated address - wgn.local.get(tmp_var) +# TODO: Broken after new type system +# def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None: +# tmp_var = wgn.temp_var_i32('tuple_adr') +# +# # Allocated the required amounts of bytes in memory +# wgn.i32.const(inp.tuple.alloc_size()) +# wgn.call(stdlib_alloc.__alloc__) +# wgn.local.set(tmp_var) +# +# # Store each member individually +# for member in inp.tuple.members: +# mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) +# if mtyp is None: +# # In the future might extend this by having structs or tuples +# # as members of struct or tuples +# raise NotImplementedError(expression, inp, member) +# +# wgn.local.get(tmp_var) +# wgn.add_statement('local.get', f'$arg{member.idx}') +# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) +# +# # Return the allocated address +# wgn.local.get(tmp_var) +# +# def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: +# tmp_var = wgn.temp_var_i32('struct_adr') +# +# # Allocated the required amounts of bytes in memory +# wgn.i32.const(inp.struct.alloc_size()) +# wgn.call(stdlib_alloc.__alloc__) +# wgn.local.set(tmp_var) +# +# # Store each member individually +# for member in inp.struct.members: +# mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) +# if mtyp is None: +# # In the future might extend this by having structs or tuples +# # as members of struct or tuples +# raise NotImplementedError(expression, inp, member) +# +# wgn.local.get(tmp_var) +# wgn.add_statement('local.get', f'${member.name}') +# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) +# +# # Return the allocated address +# wgn.local.get(tmp_var) diff --git a/phasm/exceptions.py b/phasm/exceptions.py index b459c22..77c75e7 100644 --- a/phasm/exceptions.py +++ b/phasm/exceptions.py @@ -6,3 +6,8 @@ class StaticError(Exception): """ An error found during static analysis """ + +class TypingError(Exception): + """ + An error found during the typing phase + """ diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 5f19d2e..b0e605d 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -1,38 +1,31 @@ """ Contains the syntax tree for ourlang """ -from typing import Dict, List, Tuple, Optional, Union +from typing import Dict, List, Optional, Union import enum from typing_extensions import Final -WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) -WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) +WEBASSEMBLY_BUILTIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) +WEBASSEMBLY_BUILTIN_BYTES_OPS: Final = ('len', ) from .typing import ( - TypeBase, - TypeNone, - TypeBool, - TypeUInt8, TypeUInt32, TypeUInt64, - TypeInt32, TypeInt64, - TypeFloat32, TypeFloat64, - TypeBytes, - TypeTuple, TypeTupleMember, - TypeStaticArray, TypeStaticArrayMember, - TypeStruct, TypeStructMember, + TypeStruct, + + TypeVar, ) class Expression: """ An expression within a statement """ - __slots__ = ('type', ) + __slots__ = ('type_var', ) - type: TypeBase + type_var: Optional[TypeVar] - def __init__(self, type_: TypeBase) -> None: - self.type = type_ + def __init__(self) -> None: + self.type_var = None class Constant(Expression): """ @@ -40,88 +33,16 @@ class Constant(Expression): """ __slots__ = () -class ConstantUInt8(Constant): +class ConstantPrimitive(Constant): """ - An UInt8 constant value expression within a statement + An primitive constant value expression within a statement """ __slots__ = ('value', ) - value: int + value: Union[int, float] - def __init__(self, type_: TypeUInt8, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantUInt32(Constant): - """ - An UInt32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeUInt32, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantUInt64(Constant): - """ - An UInt64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeUInt64, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantInt32(Constant): - """ - An Int32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeInt32, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantInt64(Constant): - """ - An Int64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeInt64, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantFloat32(Constant): - """ - An Float32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: float - - def __init__(self, type_: TypeFloat32, value: float) -> None: - super().__init__(type_) - self.value = value - -class ConstantFloat64(Constant): - """ - An Float64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: float - - def __init__(self, type_: TypeFloat64, value: float) -> None: - super().__init__(type_) + def __init__(self, value: Union[int, float]) -> None: + super().__init__() self.value = value class ConstantTuple(Constant): @@ -130,35 +51,23 @@ class ConstantTuple(Constant): """ __slots__ = ('value', ) - value: List[Constant] + value: List[ConstantPrimitive] - def __init__(self, type_: TypeTuple, value: List[Constant]) -> None: - super().__init__(type_) - self.value = value - -class ConstantStaticArray(Constant): - """ - A StaticArray constant value expression within a statement - """ - __slots__ = ('value', ) - - value: List[Constant] - - def __init__(self, type_: TypeStaticArray, value: List[Constant]) -> None: - super().__init__(type_) + def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples? + super().__init__() self.value = value class VariableReference(Expression): """ An variable reference expression within a statement """ - __slots__ = ('name', ) + __slots__ = ('variable', ) - name: str + variable: Union['ModuleConstantDef', 'FunctionParam'] # also possibly local - def __init__(self, type_: TypeBase, name: str) -> None: - super().__init__(type_) - self.name = name + def __init__(self, variable: Union['ModuleConstantDef', 'FunctionParam']) -> None: + super().__init__() + self.variable = variable class UnaryOp(Expression): """ @@ -169,8 +78,8 @@ class UnaryOp(Expression): operator: str right: Expression - def __init__(self, type_: TypeBase, operator: str, right: Expression) -> None: - super().__init__(type_) + def __init__(self, operator: str, right: Expression) -> None: + super().__init__() self.operator = operator self.right = right @@ -185,8 +94,8 @@ class BinaryOp(Expression): left: Expression right: Expression - def __init__(self, type_: TypeBase, operator: str, left: Expression, right: Expression) -> None: - super().__init__(type_) + def __init__(self, operator: str, left: Expression, right: Expression) -> None: + super().__init__() self.operator = operator self.left = left @@ -202,73 +111,27 @@ class FunctionCall(Expression): arguments: List[Expression] def __init__(self, function: 'Function') -> None: - super().__init__(function.returns) + super().__init__() self.function = function self.arguments = [] -class AccessBytesIndex(Expression): +class Subscript(Expression): """ - Access a bytes index for reading + A subscript, for example to refer to a static array or tuple + by index """ __slots__ = ('varref', 'index', ) varref: VariableReference index: Expression - def __init__(self, type_: TypeBase, varref: VariableReference, index: Expression) -> None: - super().__init__(type_) + def __init__(self, varref: VariableReference, index: Expression) -> None: + super().__init__() self.varref = varref self.index = index -class AccessStructMember(Expression): - """ - Access a struct member for reading of writing - """ - __slots__ = ('varref', 'member', ) - - varref: VariableReference - member: TypeStructMember - - def __init__(self, varref: VariableReference, member: TypeStructMember) -> None: - super().__init__(member.type) - - self.varref = varref - self.member = member - -class AccessTupleMember(Expression): - """ - Access a tuple member for reading of writing - """ - __slots__ = ('varref', 'member', ) - - varref: VariableReference - member: TypeTupleMember - - def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None: - super().__init__(member.type) - - self.varref = varref - self.member = member - -class AccessStaticArrayMember(Expression): - """ - Access a tuple member for reading of writing - """ - __slots__ = ('varref', 'static_array', 'member', ) - - varref: Union['ModuleConstantReference', VariableReference] - static_array: TypeStaticArray - member: Union[Expression, TypeStaticArrayMember] - - def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None: - super().__init__(static_array.member_type) - - self.varref = varref - self.static_array = static_array - self.member = member - class Fold(Expression): """ A (left or right) fold @@ -287,31 +150,18 @@ class Fold(Expression): def __init__( self, - type_: TypeBase, dir_: Direction, func: 'Function', base: Expression, iter_: Expression, ) -> None: - super().__init__(type_) + super().__init__() self.dir = dir_ self.func = func self.base = base self.iter = iter_ -class ModuleConstantReference(Expression): - """ - An reference to a module constant expression within a statement - """ - __slots__ = ('definition', ) - - definition: 'ModuleConstantDef' - - def __init__(self, type_: TypeBase, definition: 'ModuleConstantDef') -> None: - super().__init__(type_) - self.definition = definition - class Statement: """ A statement within a function @@ -348,20 +198,34 @@ class StatementIf(Statement): self.statements = [] self.else_statements = [] -FunctionParam = Tuple[str, TypeBase] +class FunctionParam: + """ + A parameter for a Function + """ + __slots__ = ('name', 'type_str', 'type_var', ) + + name: str + type_str: str + type_var: Optional[TypeVar] + + def __init__(self, name: str, type_str: str) -> None: + self.name = name + self.type_str = type_str + self.type_var = None class Function: """ A function processes input and produces output """ - __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'posonlyargs', ) + __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns_str', 'returns_type_var', 'posonlyargs', ) name: str lineno: int exported: bool imported: bool statements: List[Statement] - returns: TypeBase + returns_str: str + returns_type_var: Optional[TypeVar] posonlyargs: List[FunctionParam] def __init__(self, name: str, lineno: int) -> None: @@ -370,66 +234,70 @@ class Function: self.exported = False self.imported = False self.statements = [] - self.returns = TypeNone() + self.returns_str = 'None' + self.returns_type_var = None self.posonlyargs = [] -class StructConstructor(Function): - """ - The constructor method for a struct - - A function will generated to instantiate a struct. The arguments - will be the defaults - """ - __slots__ = ('struct', ) - - struct: TypeStruct - - def __init__(self, struct: TypeStruct) -> None: - super().__init__(f'@{struct.name}@__init___@', -1) - - self.returns = struct - - for mem in struct.members: - self.posonlyargs.append((mem.name, mem.type, )) - - self.struct = struct - -class TupleConstructor(Function): - """ - The constructor method for a tuple - """ - __slots__ = ('tuple', ) - - tuple: TypeTuple - - def __init__(self, tuple_: TypeTuple) -> None: - name = tuple_.render_internal_name() - - super().__init__(f'@{name}@__init___@', -1) - - self.returns = tuple_ - - for mem in tuple_.members: - self.posonlyargs.append((f'arg{mem.idx}', mem.type, )) - - self.tuple = tuple_ +# TODO: Broken after new type system +# class StructConstructor(Function): +# """ +# The constructor method for a struct +# +# A function will generated to instantiate a struct. The arguments +# will be the defaults +# """ +# __slots__ = ('struct', ) +# +# struct: TypeStruct +# +# def __init__(self, struct: TypeStruct) -> None: +# super().__init__(f'@{struct.name}@__init___@', -1) +# +# self.returns = struct +# +# for mem in struct.members: +# self.posonlyargs.append(FunctionParam(mem.name, mem.type, )) +# +# self.struct = struct +# +# class TupleConstructor(Function): +# """ +# The constructor method for a tuple +# """ +# __slots__ = ('tuple', ) +# +# tuple: TypeTuple +# +# def __init__(self, tuple_: TypeTuple) -> None: +# name = tuple_.render_internal_name() +# +# super().__init__(f'@{name}@__init___@', -1) +# +# self.returns = tuple_ +# +# for mem in tuple_.members: +# self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, )) +# +# self.tuple = tuple_ class ModuleConstantDef: """ A constant definition within a module """ - __slots__ = ('name', 'lineno', 'type', 'constant', 'data_block', ) + __slots__ = ('name', 'lineno', 'type_str', 'type_var', 'constant', 'data_block', ) name: str lineno: int - type: TypeBase + type_str: str + type_var: Optional[TypeVar] constant: Constant data_block: Optional['ModuleDataBlock'] - def __init__(self, name: str, lineno: int, type_: TypeBase, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: + def __init__(self, name: str, lineno: int, type_str: str, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: self.name = name self.lineno = lineno - self.type = type_ + self.type_str = type_str + self.type_var = None self.constant = constant self.data_block = data_block @@ -439,10 +307,10 @@ class ModuleDataBlock: """ __slots__ = ('data', 'address', ) - data: List[Constant] + data: List[ConstantPrimitive] address: Optional[int] - def __init__(self, data: List[Constant]) -> None: + def __init__(self, data: List[ConstantPrimitive]) -> None: self.data = data self.address = None @@ -464,23 +332,11 @@ class Module: __slots__ = ('data', 'types', 'structs', 'constant_defs', 'functions',) data: ModuleData - types: Dict[str, TypeBase] structs: Dict[str, TypeStruct] constant_defs: Dict[str, ModuleConstantDef] functions: Dict[str, Function] def __init__(self) -> None: - self.types = { - 'None': TypeNone(), - 'u8': TypeUInt8(), - 'u32': TypeUInt32(), - 'u64': TypeUInt64(), - 'i32': TypeInt32(), - 'i64': TypeInt64(), - 'f32': TypeFloat32(), - 'f64': TypeFloat64(), - 'bytes': TypeBytes(), - } self.data = ModuleData() self.structs = {} self.constant_defs = {} diff --git a/phasm/parser.py b/phasm/parser.py index d95bfce..51019be 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -6,48 +6,33 @@ from typing import Any, Dict, NoReturn, Union import ast from .typing import ( - TypeBase, - TypeUInt8, - TypeUInt32, - TypeUInt64, - TypeInt32, - TypeInt64, - TypeFloat32, - TypeFloat64, - TypeBytes, + BUILTIN_TYPES, + TypeStruct, TypeStructMember, - TypeTuple, - TypeTupleMember, - TypeStaticArray, - TypeStaticArrayMember, ) -from . import codestyle from .exceptions import StaticError from .ourlang import ( - WEBASSEMBLY_BUILDIN_FLOAT_OPS, + WEBASSEMBLY_BUILTIN_FLOAT_OPS, Module, ModuleDataBlock, Function, Expression, - AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, BinaryOp, - Constant, - ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64, - ConstantUInt8, ConstantUInt32, ConstantUInt64, - ConstantTuple, ConstantStaticArray, + ConstantPrimitive, ConstantTuple, - FunctionCall, - StructConstructor, TupleConstructor, + FunctionCall, Subscript, + # StructConstructor, TupleConstructor, UnaryOp, VariableReference, - Fold, ModuleConstantReference, + Fold, Statement, StatementIf, StatementPass, StatementReturn, + FunctionParam, ModuleConstantDef, ) @@ -60,7 +45,7 @@ def phasm_parse(source: str) -> Module: our_visitor = OurVisitor() return our_visitor.visit_Module(res) -OurLocals = Dict[str, TypeBase] +OurLocals = Dict[str, Union[FunctionParam]] # Also local variable and module constants? class OurVisitor: """ @@ -95,15 +80,16 @@ class OurVisitor: module.constant_defs[res.name] = res - if isinstance(res, TypeStruct): - if res.name in module.structs: - raise StaticError( - f'{res.name} already defined on line {module.structs[res.name].lineno}' - ) - - module.structs[res.name] = res - constructor = StructConstructor(res) - module.functions[constructor.name] = constructor + # TODO: Broken after type system + # if isinstance(res, TypeStruct): + # if res.name in module.structs: + # raise StaticError( + # f'{res.name} already defined on line {module.structs[res.name].lineno}' + # ) + # + # module.structs[res.name] = res + # constructor = StructConstructor(res) + # module.functions[constructor.name] = constructor if isinstance(res, Function): if res.name in module.functions: @@ -141,7 +127,7 @@ class OurVisitor: if not arg.annotation: _raise_static_error(node, 'Type is required') - function.posonlyargs.append(( + function.posonlyargs.append(FunctionParam( arg.arg, self.visit_type(module, arg.annotation), )) @@ -167,7 +153,7 @@ class OurVisitor: function.imported = True if node.returns: - function.returns = self.visit_type(module, node.returns) + function.returns_str = self.visit_type(module, node.returns) _not_implemented(not node.type_comment, 'FunctionDef.type_comment') @@ -195,6 +181,7 @@ class OurVisitor: if stmt.simple != 1: raise NotImplementedError('Class with non-simple arguments') + raise NotImplementedError('TODO: Broken after new type system') member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset) struct.members.append(member) @@ -208,34 +195,22 @@ class OurVisitor: if not isinstance(node.target.ctx, ast.Store): _raise_static_error(node, 'Must be load context') - exp_type = self.visit_type(module, node.annotation) - - if isinstance(exp_type, TypeInt32): - if not isinstance(node.value, ast.Constant): - _raise_static_error(node, 'Must be constant') - - constant = ModuleConstantDef( + if isinstance(node.value, ast.Constant): + return ModuleConstantDef( node.target.id, node.lineno, - exp_type, - self.visit_Module_Constant(module, exp_type, node.value), + self.visit_type(module, node.annotation), + self.visit_Module_Constant(module, node.value), None, ) - return constant - - if isinstance(exp_type, TypeTuple): - if not isinstance(node.value, ast.Tuple): - _raise_static_error(node, 'Must be tuple') - - if len(exp_type.members) != len(node.value.elts): - _raise_static_error(node, 'Invalid number of tuple values') + if isinstance(node.value, ast.Tuple): tuple_data = [ - self.visit_Module_Constant(module, mem.type, arg_node) - for arg_node, mem in zip(node.value.elts, exp_type.members) + self.visit_Module_Constant(module, arg_node) + for arg_node in node.value.elts if isinstance(arg_node, ast.Constant) ] - if len(exp_type.members) != len(tuple_data): + if len(node.value.elts) != len(tuple_data): _raise_static_error(node, 'Tuple arguments must be constants') # Allocate the data @@ -246,40 +221,69 @@ class OurVisitor: return ModuleConstantDef( node.target.id, node.lineno, - exp_type, - ConstantTuple(exp_type, tuple_data), + self.visit_type(module, node.annotation), + ConstantTuple(tuple_data), data_block, ) - if isinstance(exp_type, TypeStaticArray): - if not isinstance(node.value, ast.Tuple): - _raise_static_error(node, 'Must be static array') + raise NotImplementedError('TODO: Broken after new typing system') - if len(exp_type.members) != len(node.value.elts): - _raise_static_error(node, 'Invalid number of static array values') - - static_array_data = [ - self.visit_Module_Constant(module, exp_type.member_type, arg_node) - for arg_node in node.value.elts - if isinstance(arg_node, ast.Constant) - ] - if len(exp_type.members) != len(static_array_data): - _raise_static_error(node, 'Static array arguments must be constants') - - # Allocate the data - data_block = ModuleDataBlock(static_array_data) - module.data.blocks.append(data_block) - - # Then return the constant as a pointer - return ModuleConstantDef( - node.target.id, - node.lineno, - exp_type, - ConstantStaticArray(exp_type, static_array_data), - data_block, - ) - - raise NotImplementedError(f'{node} on Module AnnAssign') + # if isinstance(exp_type, TypeTuple): + # if not isinstance(node.value, ast.Tuple): + # _raise_static_error(node, 'Must be tuple') + # + # if len(exp_type.members) != len(node.value.elts): + # _raise_static_error(node, 'Invalid number of tuple values') + # + # tuple_data = [ + # self.visit_Module_Constant(module, arg_node) + # for arg_node, mem in zip(node.value.elts, exp_type.members) + # if isinstance(arg_node, ast.Constant) + # ] + # if len(exp_type.members) != len(tuple_data): + # _raise_static_error(node, 'Tuple arguments must be constants') + # + # # Allocate the data + # data_block = ModuleDataBlock(tuple_data) + # module.data.blocks.append(data_block) + # + # # Then return the constant as a pointer + # return ModuleConstantDef( + # node.target.id, + # node.lineno, + # exp_type, + # ConstantTuple(tuple_data), + # data_block, + # ) + # + # if isinstance(exp_type, TypeStaticArray): + # if not isinstance(node.value, ast.Tuple): + # _raise_static_error(node, 'Must be static array') + # + # if len(exp_type.members) != len(node.value.elts): + # _raise_static_error(node, 'Invalid number of static array values') + # + # static_array_data = [ + # self.visit_Module_Constant(module, arg_node) + # for arg_node in node.value.elts + # if isinstance(arg_node, ast.Constant) + # ] + # if len(exp_type.members) != len(static_array_data): + # _raise_static_error(node, 'Static array arguments must be constants') + # + # # Allocate the data + # data_block = ModuleDataBlock(static_array_data) + # module.data.blocks.append(data_block) + # + # # Then return the constant as a pointer + # return ModuleConstantDef( + # node.target.id, + # node.lineno, + # ConstantStaticArray(static_array_data), + # data_block, + # ) + # + # raise NotImplementedError(f'{node} on Module AnnAssign') def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None: if isinstance(node, ast.FunctionDef): @@ -297,7 +301,10 @@ class OurVisitor: def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: function = module.functions[node.name] - our_locals = dict(function.posonlyargs) + our_locals: OurLocals = { + x.name: x + for x in function.posonlyargs + } for stmt in node.body: function.statements.append( @@ -311,12 +318,12 @@ class OurVisitor: _raise_static_error(node, 'Return must have an argument') return StatementReturn( - self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value) + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value) ) if isinstance(node, ast.If): result = StatementIf( - self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test) + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.test) ) for stmt in node.body: @@ -336,7 +343,7 @@ class OurVisitor: raise NotImplementedError(f'{node} as stmt in FunctionDef') - def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.expr) -> Expression: + def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression: if isinstance(node, ast.BinOp): if isinstance(node.op, ast.Add): operator = '+' @@ -361,10 +368,9 @@ class OurVisitor: # e.g. you can do `"hello" * 3` with the code below (yet) return BinaryOp( - exp_type, operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right), ) if isinstance(node, ast.UnaryOp): @@ -376,9 +382,8 @@ class OurVisitor: raise NotImplementedError(f'Operator {node.op}') return UnaryOp( - exp_type, operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.operand), ) if isinstance(node, ast.Compare): @@ -398,28 +403,27 @@ class OurVisitor: # e.g. you can do `"hello" * 3` with the code below (yet) return BinaryOp( - exp_type, operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]), ) if isinstance(node, ast.Call): - return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node) + return self.visit_Module_FunctionDef_Call(module, function, our_locals, node) if isinstance(node, ast.Constant): return self.visit_Module_Constant( - module, exp_type, node, + module, node, ) if isinstance(node, ast.Attribute): return self.visit_Module_FunctionDef_Attribute( - module, function, our_locals, exp_type, node, + module, function, our_locals, node, ) if isinstance(node, ast.Subscript): return self.visit_Module_FunctionDef_Subscript( - module, function, our_locals, exp_type, node, + module, function, our_locals, node, ) if isinstance(node, ast.Name): @@ -427,45 +431,41 @@ class OurVisitor: _raise_static_error(node, 'Must be load context') if node.id in our_locals: - act_type = our_locals[node.id] - if exp_type != act_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') - - return VariableReference(act_type, node.id) + param = our_locals[node.id] + return VariableReference(param) if node.id in module.constant_defs: cdef = module.constant_defs[node.id] - if exp_type != cdef.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(cdef.type)}') - - return ModuleConstantReference(exp_type, cdef) + return VariableReference(cdef) _raise_static_error(node, f'Undefined variable {node.id}') if isinstance(node, ast.Tuple): - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') + raise NotImplementedError('TODO: Broken after new type system') - if isinstance(exp_type, TypeTuple): - if len(exp_type.members) != len(node.elts): - _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given') - - tuple_constructor = TupleConstructor(exp_type) - - func = module.functions[tuple_constructor.name] - - result = FunctionCall(func) - result.arguments = [ - self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node) - for arg_node, mem in zip(node.elts, exp_type.members) - ] - return result - - _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple') + # if not isinstance(node.ctx, ast.Load): + # _raise_static_error(node, 'Must be load context') + # + # if isinstance(exp_type, TypeTuple): + # if len(exp_type.members) != len(node.elts): + # _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given') + # + # tuple_constructor = TupleConstructor(exp_type) + # + # func = module.functions[tuple_constructor.name] + # + # result = FunctionCall(func) + # result.arguments = [ + # self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node) + # for arg_node, mem in zip(node.elts, exp_type.members) + # ] + # return result + # + # _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple') raise NotImplementedError(f'{node} as expr in FunctionDef') - def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]: + def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]: if node.keywords: _raise_static_error(node, 'Keyword calling not supported') # Yet? @@ -474,48 +474,37 @@ class OurVisitor: if not isinstance(node.func.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.func.id in module.structs: - struct = module.structs[node.func.id] - struct_constructor = StructConstructor(struct) - - func = module.functions[struct_constructor.name] - elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: - if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )): - _raise_static_error(node, f'Cannot make {node.func.id} result in {codestyle.type_(exp_type)}') - + # if node.func.id in module.structs: + # raise NotImplementedError('TODO: Broken after new type system') + # struct = module.structs[node.func.id] + # struct_constructor = StructConstructor(struct) + # + # func = module.functions[struct_constructor.name] + if node.func.id in WEBASSEMBLY_BUILTIN_FLOAT_OPS: if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') return UnaryOp( - exp_type, 'sqrt', - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]), ) elif node.func.id == 'u32': - if not isinstance(exp_type, TypeUInt32): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') # FIXME: This is a stub, proper casting is todo return UnaryOp( - exp_type, 'cast', - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]), ) elif node.func.id == 'len': - if not isinstance(exp_type, TypeInt32): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') return UnaryOp( - exp_type, 'len', - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]), ) elif node.func.id == 'foldl': # TODO: This should a much more generic function! @@ -538,21 +527,13 @@ class OurVisitor: if 2 != len(func.posonlyargs): _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') - if exp_type.__class__ != func.returns.__class__: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') - - if func.returns.__class__ != func.posonlyargs[0][1].__class__: - _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}') - - if module.types['u8'].__class__ != func.posonlyargs[1][1].__class__: - _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') + raise NotImplementedError('TODO: Broken after new type system') return Fold( - exp_type, Fold.Direction.LEFT, func, - self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]), - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[2]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]), ) else: if node.func.id not in module.functions: @@ -560,49 +541,46 @@ class OurVisitor: func = module.functions[node.func.id] - if func.returns != exp_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') - if len(func.posonlyargs) != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') result = FunctionCall(func) result.arguments.extend( - self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr) - for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs) + self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) + for arg_expr, param in zip(node.args, func.posonlyargs) ) return result - def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression: - del module - del function + def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: + raise NotImplementedError('Broken after new type system') + # del module + # del function + # + # if not isinstance(node.value, ast.Name): + # _raise_static_error(node, 'Must reference a name') + # + # if not isinstance(node.ctx, ast.Load): + # _raise_static_error(node, 'Must be load context') + # + # if not node.value.id in our_locals: + # _raise_static_error(node, f'Undefined variable {node.value.id}') + # + # param = our_locals[node.value.id] + # + # node_typ = param.type + # if not isinstance(node_typ, TypeStruct): + # _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') + # + # member = node_typ.get_member(node.attr) + # if member is None: + # _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') + # + # return AccessStructMember( + # VariableReference(param), + # member, + # ) - if not isinstance(node.value, ast.Name): - _raise_static_error(node, 'Must reference a name') - - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if not node.value.id in our_locals: - _raise_static_error(node, f'Undefined variable {node.value.id}') - - node_typ = our_locals[node.value.id] - if not isinstance(node_typ, TypeStruct): - _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') - - member = node_typ.get_member(node.attr) - if member is None: - _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') - - if exp_type != member.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}') - - return AccessStructMember( - VariableReference(node_typ, node.value.id), - member, - ) - - def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression: + def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: if not isinstance(node.value, ast.Name): _raise_static_error(node, 'Must reference a name') @@ -612,154 +590,93 @@ class OurVisitor: if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - varref: Union[ModuleConstantReference, VariableReference] + varref: VariableReference if node.value.id in our_locals: - node_typ = our_locals[node.value.id] - varref = VariableReference(node_typ, node.value.id) + param = our_locals[node.value.id] + varref = VariableReference(param) elif node.value.id in module.constant_defs: constant_def = module.constant_defs[node.value.id] - node_typ = constant_def.type - varref = ModuleConstantReference(node_typ, constant_def) + varref = VariableReference(constant_def) else: _raise_static_error(node, f'Undefined variable {node.value.id}') slice_expr = self.visit_Module_FunctionDef_expr( - module, function, our_locals, module.types['u32'], node.slice.value, + module, function, our_locals, node.slice.value, ) - if isinstance(node_typ, TypeBytes): - t_u8 = module.types['u8'] - if exp_type != t_u8: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{codestyle.expression(slice_expr)}] is actually {codestyle.type_(t_u8)}') + return Subscript(varref, slice_expr) - if isinstance(varref, ModuleConstantReference): - raise NotImplementedError(f'{node} from module constant') + # if isinstance(node_typ, TypeBytes): + # if isinstance(varref, ModuleConstantReference): + # raise NotImplementedError(f'{node} from module constant') + # + # return AccessBytesIndex( + # varref, + # slice_expr, + # ) + # + # if isinstance(node_typ, TypeTuple): + # if not isinstance(slice_expr, ConstantPrimitive): + # _raise_static_error(node, 'Must subscript using a constant index') + # + # idx = slice_expr.value + # + # if not isinstance(idx, int): + # _raise_static_error(node, 'Must subscript using a constant integer index') + # + # if not (0 <= idx < len(node_typ.members)): + # _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') + # + # tuple_member = node_typ.members[idx] + # + # if isinstance(varref, ModuleConstantReference): + # raise NotImplementedError(f'{node} from module constant') + # + # return AccessTupleMember( + # varref, + # tuple_member, + # ) + # + # if isinstance(node_typ, TypeStaticArray): + # if not isinstance(slice_expr, ConstantPrimitive): + # return AccessStaticArrayMember( + # varref, + # node_typ, + # slice_expr, + # ) + # + # idx = slice_expr.value + # + # if not isinstance(idx, int): + # _raise_static_error(node, 'Must subscript using an integer index') + # + # if not (0 <= idx < len(node_typ.members)): + # _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') + # + # static_array_member = node_typ.members[idx] + # + # return AccessStaticArrayMember( + # varref, + # node_typ, + # static_array_member, + # ) + # + # _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') - return AccessBytesIndex( - t_u8, - varref, - slice_expr, - ) - - if isinstance(node_typ, TypeTuple): - if not isinstance(slice_expr, ConstantUInt32): - _raise_static_error(node, 'Must subscript using a constant index') - - idx = slice_expr.value - - if len(node_typ.members) <= idx: - _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') - - tuple_member = node_typ.members[idx] - if exp_type != tuple_member.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(tuple_member.type)}') - - if isinstance(varref, ModuleConstantReference): - raise NotImplementedError(f'{node} from module constant') - - return AccessTupleMember( - varref, - tuple_member, - ) - - if isinstance(node_typ, TypeStaticArray): - if exp_type != node_typ.member_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}') - - if not isinstance(slice_expr, ConstantInt32): - return AccessStaticArrayMember( - varref, - node_typ, - slice_expr, - ) - - idx = slice_expr.value - - if len(node_typ.members) <= idx: - _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') - - static_array_member = node_typ.members[idx] - - return AccessStaticArrayMember( - varref, - node_typ, - static_array_member, - ) - - _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') - - def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant: + def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive: del module _not_implemented(node.kind is None, 'Constant.kind') - if isinstance(exp_type, TypeUInt8): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') + if isinstance(node.value, (int, float, )): + return ConstantPrimitive(node.value) - if node.value < 0 or node.value > 255: - _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}') + raise NotImplementedError(f'{node.value} as constant') - return ConstantUInt8(exp_type, node.value) - - if isinstance(exp_type, TypeUInt32): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 4294967295: - _raise_static_error(node, 'Integer value out of range') - - return ConstantUInt32(exp_type, node.value) - - if isinstance(exp_type, TypeUInt64): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 18446744073709551615: - _raise_static_error(node, 'Integer value out of range') - - return ConstantUInt64(exp_type, node.value) - - if isinstance(exp_type, TypeInt32): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < -2147483648 or node.value > 2147483647: - _raise_static_error(node, 'Integer value out of range') - - return ConstantInt32(exp_type, node.value) - - if isinstance(exp_type, TypeInt64): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < -9223372036854775808 or node.value > 9223372036854775807: - _raise_static_error(node, 'Integer value out of range') - - return ConstantInt64(exp_type, node.value) - - if isinstance(exp_type, TypeFloat32): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat32(exp_type, node.value) - - if isinstance(exp_type, TypeFloat64): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat64(exp_type, node.value) - - raise NotImplementedError(f'{node} as const for type {exp_type}') - - def visit_type(self, module: Module, node: ast.expr) -> TypeBase: + def visit_type(self, module: Module, node: ast.expr) -> str: if isinstance(node, ast.Constant): if node.value is None: - return module.types['None'] + return 'None' _raise_static_error(node, f'Unrecognized type {node.value}') @@ -767,8 +684,10 @@ class OurVisitor: if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.id in module.types: - return module.types[node.id] + if node.id in BUILTIN_TYPES: + return node.id + + raise NotImplementedError('TODO: Broken after type system') if node.id in module.structs: return module.structs[node.id] @@ -787,50 +706,35 @@ class OurVisitor: if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.value.id in module.types: - member_type = module.types[node.value.id] - else: + if node.value.id not in BUILTIN_TYPES: # FIXME: Tuple of tuples? _raise_static_error(node, f'Unrecognized type {node.value.id}') - type_static_array = TypeStaticArray(member_type) - - offset = 0 - - for idx in range(node.slice.value.value): - static_array_member = TypeStaticArrayMember(idx, offset) - - type_static_array.members.append(static_array_member) - offset += member_type.alloc_size() - - key = f'{node.value.id}[{node.slice.value.value}]' - - if key not in module.types: - module.types[key] = type_static_array - - return module.types[key] + return f'{node.value.id}[{node.slice.value.value}]' if isinstance(node, ast.Tuple): - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') + raise NotImplementedError('TODO: Broken after new type system') - type_tuple = TypeTuple() - - offset = 0 - - for idx, elt in enumerate(node.elts): - tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset) - - type_tuple.members.append(tuple_member) - offset += tuple_member.type.alloc_size() - - key = type_tuple.render_internal_name() - - if key not in module.types: - module.types[key] = type_tuple - constructor = TupleConstructor(type_tuple) - module.functions[constructor.name] = constructor - - return module.types[key] + # if not isinstance(node.ctx, ast.Load): + # _raise_static_error(node, 'Must be load context') + # + # type_tuple = TypeTuple() + # + # offset = 0 + # + # for idx, elt in enumerate(node.elts): + # tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset) + # + # type_tuple.members.append(tuple_member) + # offset += tuple_member.type.alloc_size() + # + # key = type_tuple.render_internal_name() + # + # if key not in module.types: + # module.types[key] = type_tuple + # constructor = TupleConstructor(type_tuple) + # module.functions[constructor.name] = constructor + # + # return module.types[key] raise NotImplementedError(f'{node} as type') diff --git a/phasm/stdlib/alloc.py b/phasm/stdlib/alloc.py index 2761bfb..8c5742d 100644 --- a/phasm/stdlib/alloc.py +++ b/phasm/stdlib/alloc.py @@ -26,7 +26,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32: g.i32.const(0) g.return_() - del alloc_size # TODO + del alloc_size # TODO: Actual implement using a previously freed block g.unreachable() return i32('return') # To satisfy mypy diff --git a/phasm/typer.py b/phasm/typer.py new file mode 100644 index 0000000..56b1f05 --- /dev/null +++ b/phasm/typer.py @@ -0,0 +1,192 @@ +""" +Type checks and enriches the given ast +""" +from . import ourlang + +from .exceptions import TypingError +from .typing import ( + Context, + TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeConstraintSubscript, + TypeVar, + from_str, +) + +def phasm_type(inp: ourlang.Module) -> None: + module(inp) + +def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar: + if isinstance(inp, ourlang.ConstantPrimitive): + result = ctx.new_var() + + if isinstance(inp.value, int): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + + # Need at least this many bits to store this constant value + result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) + # Don't dictate anything about signedness - you can use a signed + # constant in an unsigned variable if the bits fit + result.add_constraint(TypeConstraintSigned(None)) + + result.add_location(str(inp.value)) + + inp.type_var = result + + return result + + if isinstance(inp.value, float): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + + # We don't have fancy logic here to detect if the float constant + # fits in the given type. There a number of edge cases to consider, + # before implementing this. + + # 1) It may fit anyhow + # e.g., if the user has 3.14 as a float constant, neither a + # f32 nor a f64 can really fit this value. But does that mean + # we should throw an error? + + # If we'd implement it, we'd want to convert it to hex using + # inp.value.hex(), which would give us the mantissa and exponent. + # We can use those to determine what bit size the value should be in. + + # If that doesn't work out, we'd need another way to calculate the + # difference between what was written and what actually gets stored + # in memory, and warn if the difference is beyond a treshold. + + result.add_location(str(inp.value)) + + inp.type_var = result + + return result + + raise NotImplementedError(constant, inp, inp.value) + + if isinstance(inp, ourlang.ConstantTuple): + result = ctx.new_var() + + result.add_constraint(TypeConstraintSubscript(members=( + constant(ctx, x) + for x in inp.value + ))) + result.add_location(str(inp.value)) + + inp.type_var = result + + return result + + raise NotImplementedError(constant, inp) + +def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': + if isinstance(inp, ourlang.Constant): + return constant(ctx, inp) + + if isinstance(inp, ourlang.VariableReference): + assert inp.variable.type_var is not None + + inp.type_var = inp.variable.type_var + return inp.variable.type_var + + if isinstance(inp, ourlang.UnaryOp): + # TODO: Simplified version + if inp.operator not in ('sqrt', ): + raise NotImplementedError(expression, inp, inp.operator) + + right = expression(ctx, inp.right) + + inp.type_var = right + + return right + + if isinstance(inp, ourlang.BinaryOp): + if inp.operator in ('+', '-', '*', '|', '&', '^'): + left = expression(ctx, inp.left) + right = expression(ctx, inp.right) + ctx.unify(left, right) + + inp.type_var = left + return left + + if inp.operator in ('<<', '>>', ): + inp.type_var = ctx.new_var() + inp.type_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, ))) + inp.type_var.add_constraint(TypeConstraintSigned(False)) + + left = expression(ctx, inp.left) + right = expression(ctx, inp.right) + ctx.unify(left, right) + + ctx.unify(inp.type_var, left) + + return left + + raise NotImplementedError(expression, inp, inp.operator) + + if isinstance(inp, ourlang.FunctionCall): + assert inp.function.returns_type_var is not None + + for param, expr in zip(inp.function.posonlyargs, inp.arguments): + assert param.type_var is not None + + arg = expression(ctx, expr) + ctx.unify(param.type_var, arg) + + return inp.function.returns_type_var + + if isinstance(inp, ourlang.Subscript): + if not isinstance(inp.index, ourlang.ConstantPrimitive): + raise NotImplementedError(expression, inp, inp.index) + if not isinstance(inp.index.value, int): + raise NotImplementedError(expression, inp, inp.index.value) + + expression(ctx, inp.varref) + assert inp.varref.type_var is not None + + # TODO: I'd much rather resolve this using the narrow functions + tc_subs = inp.varref.type_var.get_constraint(TypeConstraintSubscript) + if tc_subs is None: + raise TypingError(f'Type cannot be subscripted: {inp.varref.type_var}') from None + + try: + # TODO: I'd much rather resolve this using the narrow functions + member = tc_subs.members[inp.index.value] + except IndexError: + raise TypingError(f'Type cannot be subscripted with index {inp.index.value}: {inp.varref.type_var}') from None + + inp.type_var = member + return member + + raise NotImplementedError(expression, inp) + +def function(ctx: Context, inp: ourlang.Function) -> None: + if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn): + raise NotImplementedError('Functions with not just a return statement') + typ = expression(ctx, inp.statements[0].value) + + assert inp.returns_type_var is not None + ctx.unify(inp.returns_type_var, typ) + +def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> None: + constant(ctx, inp.constant) + + if inp.type_str is None: + inp.type_var = ctx.new_var() + else: + inp.type_var = from_str(ctx, inp.type_str) + + assert inp.constant.type_var is not None + ctx.unify(inp.type_var, inp.constant.type_var) + +def module(inp: ourlang.Module) -> None: + ctx = Context() + + for func in inp.functions.values(): + func.returns_type_var = from_str(ctx, func.returns_str, f'{func.name}.(returns)') + for param in func.posonlyargs: + param.type_var = from_str(ctx, param.type_str, f'{func.name}.{param.name}') + + for cdef in inp.constant_defs.values(): + module_constant_def(ctx, cdef) + + for func in inp.functions.values(): + function(ctx, func) diff --git a/phasm/typing.py b/phasm/typing.py index e56f7a9..468e537 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -1,7 +1,13 @@ """ The phasm type system """ -from typing import Optional, List +from typing import Callable, Dict, Iterable, Optional, List, Set, Type +from typing import TypeVar as MyPyTypeVar + +import enum +import re + +from .exceptions import TypingError class TypeBase: """ @@ -15,88 +21,6 @@ class TypeBase: """ raise NotImplementedError(self, 'alloc_size') -class TypeNone(TypeBase): - """ - The None (or Void) type - """ - __slots__ = () - -class TypeBool(TypeBase): - """ - The boolean type - """ - __slots__ = () - -class TypeUInt8(TypeBase): - """ - The Integer type, unsigned and 8 bits wide - - Note that under the hood we need to use i32 to represent - these values in expressions. So we need to add some operations - to make sure the math checks out. - - So while this does save bytes in memory, it may not actually - speed up or improve your code. - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 # Int32 under the hood - -class TypeUInt32(TypeBase): - """ - The Integer type, unsigned and 32 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 - -class TypeUInt64(TypeBase): - """ - The Integer type, unsigned and 64 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 8 - -class TypeInt32(TypeBase): - """ - The Integer type, signed and 32 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 - -class TypeInt64(TypeBase): - """ - The Integer type, signed and 64 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 8 - -class TypeFloat32(TypeBase): - """ - The Float type, 32 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 - -class TypeFloat64(TypeBase): - """ - The Float type, 64 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 8 - class TypeBytes(TypeBase): """ The bytes type @@ -200,3 +124,487 @@ class TypeStruct(TypeBase): x.type.alloc_size() for x in self.members ) + +## NEW STUFF BELOW + +# This error can also mean that the typer somewhere forgot to write a type +# back to the AST. If so, we need to fix the typer. +ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method' + + +class TypingNarrowProtoError(TypingError): + """ + A proto error when trying to narrow two types + + This gets turned into a TypingNarrowError by the unify method + """ + # FIXME: Use consistent naming for unify / narrow / entangle + +class TypingNarrowError(TypingError): + """ + An error when trying to unify two Type Variables + """ + def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None: + super().__init__( + f'Cannot narrow types {l} and {r}: {msg}' + ) + +class TypeConstraintBase: + """ + Base class for classes implementing a contraint on a type + """ + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase': + raise NotImplementedError('narrow', self, other) + +class TypeConstraintPrimitive(TypeConstraintBase): + """ + This contraint on a type defines its primitive shape + """ + __slots__ = ('primitive', ) + + class Primitive(enum.Enum): + """ + The primitive ID + """ + INT = 0 + FLOAT = 1 + + STATIC_ARRAY = 10 + + primitive: Primitive + + def __init__(self, primitive: Primitive) -> None: + self.primitive = primitive + + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive': + if not isinstance(other, TypeConstraintPrimitive): + raise Exception('Invalid comparison') + + if self.primitive != other.primitive: + raise TypingNarrowProtoError('Primitive does not match') + + return TypeConstraintPrimitive(self.primitive) + + def __repr__(self) -> str: + return f'Primitive={self.primitive.name}' + +class TypeConstraintSigned(TypeConstraintBase): + """ + Contraint on whether a signed value can be used or not, or whether + a value can be used in a signed expression + """ + __slots__ = ('signed', ) + + signed: Optional[bool] + + def __init__(self, signed: Optional[bool]) -> None: + self.signed = signed + + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSigned': + if not isinstance(other, TypeConstraintSigned): + raise Exception('Invalid comparison') + + if other.signed is None: + return TypeConstraintSigned(self.signed) + if self.signed is None: + return TypeConstraintSigned(other.signed) + + if self.signed is not other.signed: + raise TypingNarrowProtoError('Signed does not match') + + return TypeConstraintSigned(self.signed) + + def __repr__(self) -> str: + return f'Signed={self.signed}' + +class TypeConstraintBitWidth(TypeConstraintBase): + """ + Contraint on how many bits an expression has or can possibly have + """ + __slots__ = ('oneof', ) + + oneof: Set[int] + + def __init__(self, *, oneof: Optional[Iterable[int]] = None, minb: Optional[int] = None, maxb: Optional[int] = None) -> None: + # For now, support up to 64 bits values + self.oneof = set(oneof) if oneof is not None else set(range(1, 65)) + + if minb is not None: + self.oneof = { + x + for x in self.oneof + if minb <= x + } + + if maxb is not None: + self.oneof = { + x + for x in self.oneof + if x <= maxb + } + + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth': + if not isinstance(other, TypeConstraintBitWidth): + raise Exception('Invalid comparison') + + new_oneof = self.oneof & other.oneof + + if not new_oneof: + raise TypingNarrowProtoError('Memory width cannot be resolved') + + return TypeConstraintBitWidth(oneof=new_oneof) + + def __repr__(self) -> str: + result = 'BitWidth=' + + items = list(sorted(self.oneof)) + if not items: + return result + + while items: + itm = items.pop(0) + result += str(itm) + + cnt = 0 + while cnt < len(items) and items[cnt] == itm + cnt + 1: + cnt += 1 + + if cnt == 1: + result += ',' + str(items[0]) + elif cnt > 1: + result += '..' + str(items[cnt - 1]) + + items = items[cnt:] + if items: + result += ',' + + return result + +class TypeConstraintSubscript(TypeConstraintBase): + """ + Contraint on allowing a type to be subscripted + """ + __slots__ = ('members', ) + + members: List['TypeVar'] + + def __init__(self, *, members: Iterable['TypeVar']) -> None: + self.members = list(members) + + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSubscript': + if not isinstance(other, TypeConstraintSubscript): + raise Exception('Invalid comparison') + + if len(self.members) != len(other.members): + raise TypingNarrowProtoError('Member count does not match') + + newmembers = [] + for smb, omb in zip(self.members, other.members): + nmb = ctx.new_var() + ctx.unify(nmb, smb) + ctx.unify(nmb, omb) + newmembers.append(nmb) + + return TypeConstraintSubscript(members=newmembers) + + def __repr__(self) -> str: + return 'Subscript=(' + ','.join(map(repr, self.members)) + ')' + +TTypeConstraintClass = MyPyTypeVar('TTypeConstraintClass', bound=TypeConstraintBase) + +class TypeVar: + """ + A type variable + """ + # FIXME: Explain the type system + __slots__ = ('ctx', 'ctx_id', ) + + ctx: 'Context' + ctx_id: int + + def __init__(self, ctx: 'Context', ctx_id: int) -> None: + self.ctx = ctx + self.ctx_id = ctx_id + + def add_constraint(self, newconst: TypeConstraintBase) -> None: + csts = self.ctx.var_constraints[self.ctx_id] + + if newconst.__class__ in csts: + csts[newconst.__class__] = csts[newconst.__class__].narrow(self.ctx, newconst) + else: + csts[newconst.__class__] = newconst + + def get_constraint(self, const_type: Type[TTypeConstraintClass]) -> Optional[TTypeConstraintClass]: + csts = self.ctx.var_constraints[self.ctx_id] + + res = csts.get(const_type, None) + assert res is None or isinstance(res, const_type) # type hint + return res + + def add_location(self, ref: str) -> None: + self.ctx.var_locations[self.ctx_id].add(ref) + + def __repr__(self) -> str: + return ( + 'TypeVar<' + + '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values())) + + '; locations: ' + + ', '.join(sorted(self.ctx.var_locations[self.ctx_id])) + + '>' + ) + +class Context: + """ + The context for a collection of type variables + """ + def __init__(self) -> None: + # Variables are unified (or entangled, if you will) + # that means that each TypeVar within a context has an ID, + # and all TypeVars with the same ID are the same TypeVar, + # even if they are a different instance + self.next_ctx_id = 1 + self.vars_by_id: Dict[int, List[TypeVar]] = {} + + # Store the TypeVar properties as a lookup + # so we can update these when unifying + self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {} + self.var_locations: Dict[int, Set[str]] = {} + + def new_var(self) -> TypeVar: + ctx_id = self.next_ctx_id + self.next_ctx_id += 1 + + result = TypeVar(self, ctx_id) + + self.vars_by_id[ctx_id] = [result] + self.var_constraints[ctx_id] = {} + self.var_locations[ctx_id] = set() + + return result + + def unify(self, l: Optional[TypeVar], r: Optional[TypeVar]) -> None: + # FIXME: Write method doc, find out why pylint doesn't error + + assert l is not None, ASSERTION_ERROR + assert r is not None, ASSERTION_ERROR + + assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return + + # Backup some values that we'll overwrite + l_ctx_id = l.ctx_id + r_ctx_id = r.ctx_id + l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id] + + # Create a new TypeVar, with the combined contraints + # and locations of the old ones + n = self.new_var() + + try: + for const in self.var_constraints[l_ctx_id].values(): + n.add_constraint(const) + for const in self.var_constraints[r_ctx_id].values(): + n.add_constraint(const) + except TypingNarrowProtoError as exc: + raise TypingNarrowError(l, r, str(exc)) from None + + self.var_locations[n.ctx_id] = self.var_locations[l_ctx_id] | self.var_locations[r_ctx_id] + + # ## + # And unify (or entangle) the old ones + + # First update the IDs, so they all point to the new list + for type_var in l_r_var_list: + type_var.ctx_id = n.ctx_id + + # Update our registry of TypeVars by ID, so we can find them + # on the next unify + self.vars_by_id[n.ctx_id].extend(l_r_var_list) + + # Then delete the old values for the now gone variables + # Do this last, so exceptions thrown in the code above + # still have a valid context + del self.var_constraints[l_ctx_id] + del self.var_constraints[r_ctx_id] + del self.var_locations[l_ctx_id] + del self.var_locations[r_ctx_id] + +def simplify(inp: TypeVar) -> Optional[str]: + """ + Simplifies a TypeVar into a string that wasm can work with + and users can recognize + + Should round trip with from_str + """ + tc_prim = inp.get_constraint(TypeConstraintPrimitive) + tc_bits = inp.get_constraint(TypeConstraintBitWidth) + tc_sign = inp.get_constraint(TypeConstraintSigned) + + if tc_prim is None: + return None + + primitive = tc_prim.primitive + if primitive is TypeConstraintPrimitive.Primitive.INT: + if tc_bits is None or tc_sign is None: + return None + + if tc_sign.signed is None or len(tc_bits.oneof) != 1: + return None + + bitwidth = next(iter(tc_bits.oneof)) + if bitwidth not in (8, 32, 64): + return None + + base = 'i' if tc_sign.signed else 'u' + return f'{base}{bitwidth}' + + if primitive is TypeConstraintPrimitive.Primitive.FLOAT: + if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint + return None + + if len(tc_bits.oneof) != 1: + return None + + bitwidth = next(iter(tc_bits.oneof)) + if bitwidth not in (32, 64): + return None + + return f'f{bitwidth}' + + if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + tc_subs = inp.get_constraint(TypeConstraintSubscript) + assert tc_subs is not None + assert tc_subs.members + + sab = simplify(tc_subs.members[0]) + if sab is None: + return None + + return f'{sab}[{len(tc_subs.members)}]' + + return None + +def make_u8(ctx: Context) -> TypeVar: + """ + Makes a u8 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location('u8') + return result + +def make_u32(ctx: Context) -> TypeVar: + """ + Makes a u32 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location('u32') + return result + +def make_u64(ctx: Context) -> TypeVar: + """ + Makes a u64 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location('u64') + return result + +def make_i32(ctx: Context) -> TypeVar: + """ + Makes a i32 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location('i32') + return result + +def make_i64(ctx: Context) -> TypeVar: + """ + Makes a i64 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location('i64') + return result + +def make_f32(ctx: Context) -> TypeVar: + """ + Makes a f32 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_location('f32') + return result + +def make_f64(ctx: Context) -> TypeVar: + """ + Makes a f64 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_location('f64') + return result + +BUILTIN_TYPES: Dict[str, Callable[[Context], TypeVar]] = { + 'u8': make_u8, + 'u32': make_u32, + 'u64': make_u64, + 'i32': make_i32, + 'i64': make_i64, + 'f32': make_f32, + 'f64': make_f64, +} + +TYPE_MATCH_STATIC_ARRAY = re.compile(r'^([uif][0-9]+)\[([0-9]+)\]') + +def from_str(ctx: Context, inp: str, location: Optional[str] = None) -> TypeVar: + """ + Creates a new TypeVar from the string + + Should round trip with simplify + + The location is a reference to where you found the string + in the source code. + + This could be conidered part of parsing. Though that would give trouble + with the context creation. + """ + if inp in BUILTIN_TYPES: + result = BUILTIN_TYPES[inp](ctx) + if location is not None: + result.add_location(location) + return result + + match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp) + if match: + result = ctx.new_var() + + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY)) + result.add_constraint(TypeConstraintSubscript(members=( + # Make copies so they don't get entangled + # with each other. + from_str(ctx, match[1], match[1]) + for _ in range(int(match[2])) + ))) + + result.add_location(inp) + + if location is not None: + result.add_location(location) + + return result + + raise NotImplementedError(from_str, inp) diff --git a/phasm/wasmeasy.py b/phasm/wasmeasy.py index d0cf358..01c5d0c 100644 --- a/phasm/wasmeasy.py +++ b/phasm/wasmeasy.py @@ -1,7 +1,7 @@ """ Helper functions to quickly generate WASM code """ -from typing import Any, Dict, List, Optional, Type +from typing import List, Optional import functools diff --git a/pylintrc b/pylintrc index 0591be3..f872f8d 100644 --- a/pylintrc +++ b/pylintrc @@ -1,5 +1,5 @@ [MASTER] -disable=C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 +disable=C0103,C0122,R0902,R0903,R0911,R0912,R0913,R0915,R1710,W0223 max-line-length=180 @@ -7,4 +7,4 @@ max-line-length=180 good-names=g [tests] -disable=C0116, +disable=C0116,R0201 diff --git a/tests/integration/constants.py b/tests/integration/constants.py new file mode 100644 index 0000000..07dacbe --- /dev/null +++ b/tests/integration/constants.py @@ -0,0 +1,16 @@ +""" +Constants for use in the tests +""" + +ALL_INT_TYPES = ['u8', 'u32', 'u64', 'i32', 'i64'] +COMPLETE_INT_TYPES = ['u32', 'u64', 'i32', 'i64'] + +ALL_FLOAT_TYPES = ['f32', 'f64'] +COMPLETE_FLOAT_TYPES = ALL_FLOAT_TYPES + +TYPE_MAP = { + **{x: int for x in ALL_INT_TYPES}, + **{x: float for x in ALL_FLOAT_TYPES}, +} + +COMPLETE_PRIMITIVE_TYPES = COMPLETE_INT_TYPES + COMPLETE_FLOAT_TYPES diff --git a/tests/integration/runners.py b/tests/integration/runners.py index fd3a53e..77cd3f5 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -13,6 +13,7 @@ import wasmtime from phasm.compiler import phasm_compile from phasm.parser import phasm_parse +from phasm.typer import phasm_type from phasm import ourlang from phasm import wasm @@ -40,6 +41,7 @@ class RunnerBase: Parses the Phasm code into an AST """ self.phasm_ast = phasm_parse(self.phasm_code) + phasm_type(self.phasm_ast) def compile_ast(self) -> None: """ diff --git a/tests/integration/test_code/__init__.py b/tests/integration/test_code/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_code/test_typing.py b/tests/integration/test_code/test_typing.py new file mode 100644 index 0000000..7be1140 --- /dev/null +++ b/tests/integration/test_code/test_typing.py @@ -0,0 +1,20 @@ +import pytest + +from phasm import typing as sut + +class TestTypeConstraintBitWidth: + @pytest.mark.parametrize('oneof,exp', [ + (set(), '', ), + ({1}, '1', ), + ({1,2}, '1,2', ), + ({1,2,3}, '1..3', ), + ({1,2,3,4}, '1..4', ), + + ({1,3}, '1,3', ), + ({1,4}, '1,4', ), + + ({1,2,3,4,6,7,8,9}, '1..4,6..9', ), + ]) + def test_repr(self, oneof, exp): + mut_self = sut.TypeConstraintBitWidth(oneof=oneof) + assert ('BitWidth=' + exp) == repr(mut_self) diff --git a/tests/integration/test_examples/__init__.py b/tests/integration/test_examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples/test_crc32.py similarity index 93% rename from tests/integration/test_examples.py rename to tests/integration/test_examples/test_crc32.py index b3b278d..d9f74bf 100644 --- a/tests/integration/test_examples.py +++ b/tests/integration/test_examples/test_crc32.py @@ -3,9 +3,9 @@ import struct import pytest -from .helpers import Suite +from ..helpers import Suite -@pytest.mark.integration_test +@pytest.mark.slow_integration_test def test_crc32(): # FIXME: Stub # crc = 0xFFFFFFFF diff --git a/tests/integration/test_fib.py b/tests/integration/test_examples/test_fib.py similarity index 94% rename from tests/integration/test_fib.py rename to tests/integration/test_examples/test_fib.py index 20e7e63..3f99a46 100644 --- a/tests/integration/test_fib.py +++ b/tests/integration/test_examples/test_fib.py @@ -1,6 +1,6 @@ import pytest -from .helpers import Suite +from ..helpers import Suite @pytest.mark.slow_integration_test def test_fib(): diff --git a/tests/integration/test_helper.py b/tests/integration/test_helper.py deleted file mode 100644 index cb44021..0000000 --- a/tests/integration/test_helper.py +++ /dev/null @@ -1,70 +0,0 @@ -import io - -import pytest - -from pywasm import binary -from pywasm import Runtime - -from wasmer import wat2wasm - -def run(code_wat): - code_wasm = wat2wasm(code_wat) - module = binary.Module.from_reader(io.BytesIO(code_wasm)) - - runtime = Runtime(module, {}, {}) - - out_put = runtime.exec('testEntry', []) - return (runtime, out_put) - -@pytest.mark.parametrize('size,offset,exp_out_put', [ - ('32', 0, 0x3020100), - ('32', 1, 0x4030201), - ('64', 0, 0x706050403020100), - ('64', 2, 0x908070605040302), -]) -def test_i32_64_load(size, offset, exp_out_put): - code_wat = f""" - (module - (memory 1) - (data (memory 0) (i32.const 0) "\\00\\01\\02\\03\\04\\05\\06\\07\\08\\09\\10") - - (func (export "testEntry") (result i{size}) - i32.const {offset} - i{size}.load - return )) -""" - - (_, out_put) = run(code_wat) - assert exp_out_put == out_put - -def test_load_then_store(): - code_wat = """ - (module - (memory 1) - (data (memory 0) (i32.const 0) "\\04\\00\\00\\00") - - (func (export "testEntry") (result i32) (local $my_memory_value i32) - ;; Load i32 from address 0 - i32.const 0 - i32.load - - ;; Add 8 to the loaded value - i32.const 8 - i32.add - - local.set $my_memory_value - - ;; Store back to the memory - i32.const 0 - local.get $my_memory_value - i32.store - - ;; Return something - i32.const 9 - return )) -""" - (runtime, out_put) = run(code_wat) - - assert 9 == out_put - - assert (b'\x0c'+ b'\00' * 23) == runtime.store.mems[0].data[:24] diff --git a/tests/integration/test_lang/__init__.py b/tests/integration/test_lang/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_builtins.py b/tests/integration/test_lang/test_builtins.py similarity index 95% rename from tests/integration/test_builtins.py rename to tests/integration/test_lang/test_builtins.py index 4b84197..2b06afd 100644 --- a/tests/integration/test_builtins.py +++ b/tests/integration/test_lang/test_builtins.py @@ -2,8 +2,8 @@ import sys import pytest -from .helpers import Suite, write_header -from .runners import RunnerPywasm +from ..helpers import Suite, write_header +from ..runners import RunnerPywasm def setup_interpreter(phash_code: str) -> RunnerPywasm: runner = RunnerPywasm(phash_code) diff --git a/tests/integration/test_lang/test_bytes.py b/tests/integration/test_lang/test_bytes.py new file mode 100644 index 0000000..45fc2a1 --- /dev/null +++ b/tests/integration/test_lang/test_bytes.py @@ -0,0 +1,53 @@ +import pytest + +from ..helpers import Suite + +@pytest.mark.integration_test +def test_bytes_address(): + code_py = """ +@exported +def testEntry(f: bytes) -> bytes: + return f +""" + + result = Suite(code_py).run_code(b'This is a test') + + # THIS DEPENDS ON THE ALLOCATOR + # A different allocator will return a different value + assert 20 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_length(): + code_py = """ +@exported +def testEntry(f: bytes) -> i32: + return len(f) +""" + + result = Suite(code_py).run_code(b'This is another test') + + assert 20 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_index(): + code_py = """ +@exported +def testEntry(f: bytes) -> u8: + return f[8] +""" + + result = Suite(code_py).run_code(b'This is another test') + + assert 0x61 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_index_out_of_bounds(): + code_py = """ +@exported +def testEntry(f: bytes) -> u8: + return f[50] +""" + + result = Suite(code_py).run_code(b'Short', b'Long' * 100) + + assert 0 == result.returned_value diff --git a/tests/integration/test_lang/test_if.py b/tests/integration/test_lang/test_if.py new file mode 100644 index 0000000..5d77eb9 --- /dev/null +++ b/tests/integration/test_lang/test_if.py @@ -0,0 +1,71 @@ +import pytest + +from ..helpers import Suite + +@pytest.mark.integration_test +@pytest.mark.parametrize('inp', [9, 10, 11, 12]) +def test_if_simple(inp): + code_py = """ +@exported +def testEntry(a: i32) -> i32: + if a > 10: + return 15 + + return 3 +""" + exp_result = 15 if inp > 10 else 3 + + suite = Suite(code_py) + + result = suite.run_code(inp) + assert exp_result == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.skip('Such a return is not how things should be') +def test_if_complex(): + code_py = """ +@exported +def testEntry(a: i32) -> i32: + if a > 10: + return 10 + elif a > 0: + return a + else: + return 0 + + return -1 # Required due to function type +""" + + suite = Suite(code_py) + + assert 10 == suite.run_code(20).returned_value + assert 10 == suite.run_code(10).returned_value + + assert 8 == suite.run_code(8).returned_value + + assert 0 == suite.run_code(0).returned_value + assert 0 == suite.run_code(-1).returned_value + +@pytest.mark.integration_test +def test_if_nested(): + code_py = """ +@exported +def testEntry(a: i32, b: i32) -> i32: + if a > 11: + if b > 11: + return 3 + + return 2 + + if b > 11: + return 1 + + return 0 +""" + + suite = Suite(code_py) + + assert 3 == suite.run_code(20, 20).returned_value + assert 2 == suite.run_code(20, 10).returned_value + assert 1 == suite.run_code(10, 20).returned_value + assert 0 == suite.run_code(10, 10).returned_value diff --git a/tests/integration/test_lang/test_interface.py b/tests/integration/test_lang/test_interface.py new file mode 100644 index 0000000..cbacd73 --- /dev/null +++ b/tests/integration/test_lang/test_interface.py @@ -0,0 +1,27 @@ +import pytest + +from ..helpers import Suite + +@pytest.mark.integration_test +def test_imported(): + code_py = """ +@imported +def helper(mul: i32) -> i32: + pass + +@exported +def testEntry() -> i32: + return helper(2) +""" + + def helper(mul: int) -> int: + return 4238 * mul + + result = Suite(code_py).run_code( + runtime='wasmer', + imports={ + 'helper': helper, + } + ) + + assert 8476 == result.returned_value diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py new file mode 100644 index 0000000..01df9cf --- /dev/null +++ b/tests/integration/test_lang/test_primitives.py @@ -0,0 +1,376 @@ +import pytest + +from phasm.exceptions import TypingError + +from ..helpers import Suite +from ..constants import ALL_INT_TYPES, ALL_FLOAT_TYPES, COMPLETE_INT_TYPES, TYPE_MAP + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_expr_constant_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 13 +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_expr_constant_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 32.125 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_expr_constant_entanglement(): + code_py = """ +@exported +def testEntry() -> u8: + return 1000 +""" + + with pytest.raises(TypingError, match='u8.*1000'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_module_constant_int(type_): + code_py = f""" +CONSTANT: {type_} = 13 + +@exported +def testEntry() -> {type_}: + return CONSTANT +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_module_constant_float(type_): + code_py = f""" +CONSTANT: {type_} = 32.125 + +@exported +def testEntry() -> {type_}: + return CONSTANT +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + +@pytest.mark.integration_test +def test_module_constant_entanglement(): + code_py = """ +CONSTANT: u8 = 1000 + +@exported +def testEntry() -> u32: + return 14 +""" + + with pytest.raises(TypingError, match='u8.*1000'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation +def test_logical_left_shift(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 << 3 +""" + + result = Suite(code_py).run_code() + + assert 80 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u32', 'u64']) +def test_logical_right_shift_left_bit_zero(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 >> 3 +""" + + result = Suite(code_py).run_code() + + assert 1 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_logical_right_shift_left_bit_one(): + code_py = """ +@exported +def testEntry() -> u32: + return 4294967295 >> 16 +""" + + result = Suite(code_py).run_code() + + assert 0xFFFF == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_bitwise_or(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 | 3 +""" + + result = Suite(code_py).run_code() + + assert 11 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_bitwise_xor(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 ^ 3 +""" + + result = Suite(code_py).run_code() + + assert 9 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_bitwise_and(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 & 3 +""" + + result = Suite(code_py).run_code() + + assert 2 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_addition_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 + 3 +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_addition_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 32.0 + 0.125 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_subtraction_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 - 3 +""" + + result = Suite(code_py).run_code() + + assert 7 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_subtraction_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 100.0 - 67.875 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +# TODO: Multiplication +# TODO: Division + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['f32', 'f64']) +def test_builtins_sqrt(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return sqrt(25.0) +""" + + result = Suite(code_py).run_code() + + assert 5 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', TYPE_MAP.keys()) +def test_function_argument(type_): + code_py = f""" +@exported +def testEntry(a: {type_}) -> {type_}: + return a +""" + + result = Suite(code_py).run_code(125) + + assert 125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.skip('TODO') +def test_explicit_positive_number(): + code_py = """ +@exported +def testEntry() -> i32: + return +523 +""" + + result = Suite(code_py).run_code() + + assert 523 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.skip('TODO') +def test_explicit_negative_number(): + code_py = """ +@exported +def testEntry() -> i32: + return -19 +""" + + result = Suite(code_py).run_code() + + assert -19 == result.returned_value + +@pytest.mark.integration_test +def test_call_no_args(): + code_py = """ +def helper() -> i32: + return 19 + +@exported +def testEntry() -> i32: + return helper() +""" + + result = Suite(code_py).run_code() + + assert 19 == result.returned_value + +@pytest.mark.integration_test +def test_call_pre_defined(): + code_py = """ +def helper(left: i32, right: i32) -> i32: + return left + right + +@exported +def testEntry() -> i32: + return helper(10, 3) +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + +@pytest.mark.integration_test +def test_call_post_defined(): + code_py = """ +@exported +def testEntry() -> i32: + return helper(10, 3) + +def helper(left: i32, right: i32) -> i32: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 7 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_call_with_expression_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return helper(10 + 20, 3 + 5) + +def helper(left: {type_}, right: {type_}) -> {type_}: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 22 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_call_with_expression_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return helper(10.078125 + 90.046875, 63.0 + 5.0) + +def helper(left: {type_}, right: {type_}) -> {type_}: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_call_invalid_type(): + code_py = """ +def helper() -> i64: + return 19 + +@exported +def testEntry() -> i32: + return helper() +""" + + with pytest.raises(TypingError, match=r'i32.*i64'): + Suite(code_py).run_code() diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py new file mode 100644 index 0000000..9f549da --- /dev/null +++ b/tests/integration/test_lang/test_static_array.py @@ -0,0 +1,173 @@ +import pytest + +from phasm.exceptions import StaticError, TypingError + +from ..constants import ( + ALL_FLOAT_TYPES, ALL_INT_TYPES, COMPLETE_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP +) +from ..helpers import Suite + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_module_constant(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24, 57, 80, ) + +@exported +def testEntry() -> {type_}: + return CONSTANT[0] +""" + + result = Suite(code_py).run_code() + + assert 24 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.skip('To decide: What to do on out of index?') +@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) +def test_static_array_indexed(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24, 57, 80, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT, 0, 1, 2) + +def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}: + return array[i0] + array[i1] + array[i2] +""" + + result = Suite(code_py).run_code() + + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_function_call_int(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24, 57, 80, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT) + +def helper(array: {type_}[3]) -> {type_}: + return array[0] + array[1] + array[2] +""" + + result = Suite(code_py).run_code() + + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_function_call_float(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24.0, 57.5, 80.75, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT) + +def helper(array: {type_}[3]) -> {type_}: + return array[0] + array[1] + array[2] +""" + + result = Suite(code_py).run_code() + + assert 162.25 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_bitwidth(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 280, ) +""" + + with pytest.raises(TypingError, match='u8.*280'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_return_as_int(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 80, ) + +def testEntry() -> u32: + return CONSTANT +""" + + with pytest.raises(TypingError, match=r'u32.*u8\[3\]'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_not_subscriptable(): + code_py = """ +CONSTANT: u8 = 24 + +@exported +def testEntry() -> u8: + return CONSTANT[0] +""" + + with pytest.raises(TypingError, match='Type cannot be subscripted:'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_index_out_of_range(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 80, ) + +@exported +def testEntry() -> u8: + return CONSTANT[3] +""" + + with pytest.raises(TypingError, match='Type cannot be subscripted with index 3:'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_static_array_constant_too_few_values(): + code_py = """ +CONSTANT: u8[3] = (24, 57, ) +""" + + with pytest.raises(TypingError, match='Member count does not match'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_static_array_constant_too_many_values(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 1, 1, ) +""" + + with pytest.raises(TypingError, match='Member count does not match'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_static_array_constant_type_mismatch(): + code_py = """ +CONSTANT: u8[3] = (24, 4000, 1, ) +""" + + with pytest.raises(TypingError, match='u8.*4000'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.skip('To decide: What to do on out of index?') +def test_static_array_index_out_of_bounds(): + code_py = """ +CONSTANT0: u32[3] = (24, 57, 80, ) + +CONSTANT1: u32[16] = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ) + +@exported +def testEntry() -> u32: + return CONSTANT0[16] +""" + + result = Suite(code_py).run_code() + + assert 0 == result.returned_value diff --git a/tests/integration/test_static_checking.py b/tests/integration/test_lang/test_struct.py similarity index 54% rename from tests/integration/test_static_checking.py rename to tests/integration/test_lang/test_struct.py index 1544537..04846e7 100644 --- a/tests/integration/test_static_checking.py +++ b/tests/integration/test_lang/test_struct.py @@ -1,18 +1,65 @@ import pytest -from phasm.parser import phasm_parse -from phasm.exceptions import StaticError +from ..helpers import Suite @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) -def test_type_mismatch_function_argument(type_): +@pytest.mark.parametrize('type_', ('i32', 'f64', )) +def test_struct_0(type_): code_py = f""" -def helper(a: {type_}) -> (i32, i32, ): - return a +class CheckedValue: + value: {type_} + +@exported +def testEntry() -> {type_}: + return helper(CheckedValue(23)) + +def helper(cv: CheckedValue) -> {type_}: + return cv.value """ - with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), a is actually {type_}'): - phasm_parse(code_py) + result = Suite(code_py).run_code() + + assert 23 == result.returned_value + +@pytest.mark.integration_test +def test_struct_1(): + code_py = """ +class Rectangle: + height: i32 + width: i32 + border: i32 + +@exported +def testEntry() -> i32: + return helper(Rectangle(100, 150, 2)) + +def helper(shape: Rectangle) -> i32: + return shape.height + shape.width + shape.border +""" + + result = Suite(code_py).run_code() + + assert 252 == result.returned_value + +@pytest.mark.integration_test +def test_struct_2(): + code_py = """ +class Rectangle: + height: i32 + width: i32 + border: i32 + +@exported +def testEntry() -> i32: + return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3)) + +def helper(shape1: Rectangle, shape2: Rectangle) -> i32: + return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border +""" + + result = Suite(code_py).run_code() + + assert 545 == result.returned_value @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) @@ -39,21 +86,6 @@ def testEntry(arg: ({type_}, )) -> (i32, i32, ): with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), arg\\[0\\] is actually {type_}'): phasm_parse(code_py) -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) -def test_type_mismatch_function_result(type_): - code_py = f""" -def helper() -> {type_}: - return 1 - -@exported -def testEntry() -> (i32, i32, ): - return helper() -""" - - with pytest.raises(StaticError, match=f'Static error on line 7: Expected \\(i32, i32, \\), helper actually returns {type_}'): - phasm_parse(code_py) - @pytest.mark.integration_test def test_tuple_constant_too_few_values(): code_py = """ @@ -80,30 +112,3 @@ CONSTANT: (u32, u8, u8, ) = (24, 4000, 1, ) with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): phasm_parse(code_py) - -@pytest.mark.integration_test -def test_static_array_constant_too_few_values(): - code_py = """ -CONSTANT: u8[3] = (24, 57, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_static_array_constant_too_many_values(): - code_py = """ -CONSTANT: u8[3] = (24, 57, 1, 1, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_static_array_constant_type_mismatch(): - code_py = """ -CONSTANT: u8[3] = (24, 4000, 1, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): - phasm_parse(code_py) diff --git a/tests/integration/test_constants.py b/tests/integration/test_lang/test_tuple.py similarity index 57% rename from tests/integration/test_constants.py rename to tests/integration/test_lang/test_tuple.py index 19f0203..5c5321e 100644 --- a/tests/integration/test_constants.py +++ b/tests/integration/test_lang/test_tuple.py @@ -1,20 +1,7 @@ import pytest -from .helpers import Suite - -@pytest.mark.integration_test -def test_i32(): - code_py = """ -CONSTANT: i32 = 13 - -@exported -def testEntry() -> i32: - return CONSTANT * 5 -""" - - result = Suite(code_py).run_code() - - assert 65 == result.returned_value +from ..constants import COMPLETE_PRIMITIVE_TYPES, TYPE_MAP +from ..helpers import Suite @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ]) @@ -52,36 +39,46 @@ def helper(vector: (u8, u8, u32, u32, u64, u64, )) -> u32: assert 3333 == result.returned_value @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ]) -def test_static_array_1(type_): +@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) +def test_tuple_simple_constructor(type_): code_py = f""" -CONSTANT: {type_}[1] = (65, ) - @exported def testEntry() -> {type_}: - return helper(CONSTANT) + return helper((24, 57, 80, )) -def helper(vector: {type_}[1]) -> {type_}: - return vector[0] +def helper(vector: ({type_}, {type_}, {type_}, )) -> {type_}: + return vector[0] + vector[1] + vector[2] """ result = Suite(code_py).run_code() - assert 65 == result.returned_value + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -def test_static_array_6(): +def test_tuple_float(): code_py = """ -CONSTANT: u32[6] = (11, 22, 3333, 4444, 555555, 666666, ) - @exported -def testEntry() -> u32: - return helper(CONSTANT) +def testEntry() -> f32: + return helper((1.0, 2.0, 3.0, )) -def helper(vector: u32[6]) -> u32: - return vector[2] +def helper(v: (f32, f32, f32, )) -> f32: + return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]) """ result = Suite(code_py).run_code() - assert 3333 == result.returned_value + assert 3.74 < result.returned_value < 3.75 + +@pytest.mark.integration_test +@pytest.mark.skip('SIMD support is but a dream') +def test_tuple_i32x4(): + code_py = """ +@exported +def testEntry() -> i32x4: + return (51, 153, 204, 0, ) +""" + + result = Suite(code_py).run_code() + + assert (1, 2, 3, 0) == result.returned_value diff --git a/tests/integration/test_runtime_checks.py b/tests/integration/test_runtime_checks.py deleted file mode 100644 index 97d6542..0000000 --- a/tests/integration/test_runtime_checks.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest - -from .helpers import Suite - -@pytest.mark.integration_test -def test_bytes_index_out_of_bounds(): - code_py = """ -@exported -def testEntry(f: bytes) -> u8: - return f[50] -""" - - result = Suite(code_py).run_code(b'Short', b'Long' * 100) - - assert 0 == result.returned_value - -@pytest.mark.integration_test -def test_static_array_index_out_of_bounds(): - code_py = """ -CONSTANT0: u32[3] = (24, 57, 80, ) - -CONSTANT1: u32[16] = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ) - -@exported -def testEntry() -> u32: - return CONSTANT0[16] -""" - - result = Suite(code_py).run_code() - - assert 0 == result.returned_value diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py deleted file mode 100644 index f0c2993..0000000 --- a/tests/integration/test_simple.py +++ /dev/null @@ -1,571 +0,0 @@ -import pytest - -from .helpers import Suite - -TYPE_MAP = { - 'u8': int, - 'u32': int, - 'u64': int, - 'i32': int, - 'i64': int, - 'f32': float, - 'f64': float, -} - -COMPLETE_SIMPLE_TYPES = [ - 'u32', 'u64', - 'i32', 'i64', - 'f32', 'f64', -] - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', TYPE_MAP.keys()) -def test_return(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 13 -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_addition(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 + 3 -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_subtraction(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 - 3 -""" - - result = Suite(code_py).run_code() - - assert 7 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation -def test_logical_left_shift(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 << 3 -""" - - result = Suite(code_py).run_code() - - assert 80 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_logical_right_shift(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 >> 3 -""" - - result = Suite(code_py).run_code() - - assert 1 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_or(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 | 3 -""" - - result = Suite(code_py).run_code() - - assert 11 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_xor(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 ^ 3 -""" - - result = Suite(code_py).run_code() - - assert 9 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_and(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 & 3 -""" - - result = Suite(code_py).run_code() - - assert 2 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['f32', 'f64']) -def test_buildins_sqrt(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return sqrt(25) -""" - - result = Suite(code_py).run_code() - - assert 5 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', TYPE_MAP.keys()) -def test_arg(type_): - code_py = f""" -@exported -def testEntry(a: {type_}) -> {type_}: - return a -""" - - result = Suite(code_py).run_code(125) - - assert 125 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_i32_to_i64(): - code_py = """ -@exported -def testEntry(a: i32) -> i64: - return a -""" - - result = Suite(code_py).run_code(125) - - assert 125 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_i32_plus_i64(): - code_py = """ -@exported -def testEntry(a: i32, b: i64) -> i64: - return a + b -""" - - result = Suite(code_py).run_code(125, 100) - - assert 225 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_f32_to_f64(): - code_py = """ -@exported -def testEntry(a: f32) -> f64: - return a -""" - - result = Suite(code_py).run_code(125.5) - - assert 125.5 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_f32_plus_f64(): - code_py = """ -@exported -def testEntry(a: f32, b: f64) -> f64: - return a + b -""" - - result = Suite(code_py).run_code(125.5, 100.25) - - assert 225.75 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('TODO') -def test_uadd(): - code_py = """ -@exported -def testEntry() -> i32: - return +523 -""" - - result = Suite(code_py).run_code() - - assert 523 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('TODO') -def test_usub(): - code_py = """ -@exported -def testEntry() -> i32: - return -19 -""" - - result = Suite(code_py).run_code() - - assert -19 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('inp', [9, 10, 11, 12]) -def test_if_simple(inp): - code_py = """ -@exported -def testEntry(a: i32) -> i32: - if a > 10: - return 15 - - return 3 -""" - exp_result = 15 if inp > 10 else 3 - - suite = Suite(code_py) - - result = suite.run_code(inp) - assert exp_result == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Such a return is not how things should be') -def test_if_complex(): - code_py = """ -@exported -def testEntry(a: i32) -> i32: - if a > 10: - return 10 - elif a > 0: - return a - else: - return 0 - - return -1 # Required due to function type -""" - - suite = Suite(code_py) - - assert 10 == suite.run_code(20).returned_value - assert 10 == suite.run_code(10).returned_value - - assert 8 == suite.run_code(8).returned_value - - assert 0 == suite.run_code(0).returned_value - assert 0 == suite.run_code(-1).returned_value - -@pytest.mark.integration_test -def test_if_nested(): - code_py = """ -@exported -def testEntry(a: i32, b: i32) -> i32: - if a > 11: - if b > 11: - return 3 - - return 2 - - if b > 11: - return 1 - - return 0 -""" - - suite = Suite(code_py) - - assert 3 == suite.run_code(20, 20).returned_value - assert 2 == suite.run_code(20, 10).returned_value - assert 1 == suite.run_code(10, 20).returned_value - assert 0 == suite.run_code(10, 10).returned_value - -@pytest.mark.integration_test -def test_call_pre_defined(): - code_py = """ -def helper(left: i32, right: i32) -> i32: - return left + right - -@exported -def testEntry() -> i32: - return helper(10, 3) -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - -@pytest.mark.integration_test -def test_call_post_defined(): - code_py = """ -@exported -def testEntry() -> i32: - return helper(10, 3) - -def helper(left: i32, right: i32) -> i32: - return left - right -""" - - result = Suite(code_py).run_code() - - assert 7 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_call_with_expression(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return helper(10 + 20, 3 + 5) - -def helper(left: {type_}, right: {type_}) -> {type_}: - return left - right -""" - - result = Suite(code_py).run_code() - - assert 22 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.skip('Not yet implemented') -def test_assign(): - code_py = """ - -@exported -def testEntry() -> i32: - a: i32 = 8947 - return a -""" - - result = Suite(code_py).run_code() - - assert 8947 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', TYPE_MAP.keys()) -def test_struct_0(type_): - code_py = f""" -class CheckedValue: - value: {type_} - -@exported -def testEntry() -> {type_}: - return helper(CheckedValue(23)) - -def helper(cv: CheckedValue) -> {type_}: - return cv.value -""" - - result = Suite(code_py).run_code() - - assert 23 == result.returned_value - -@pytest.mark.integration_test -def test_struct_1(): - code_py = """ -class Rectangle: - height: i32 - width: i32 - border: i32 - -@exported -def testEntry() -> i32: - return helper(Rectangle(100, 150, 2)) - -def helper(shape: Rectangle) -> i32: - return shape.height + shape.width + shape.border -""" - - result = Suite(code_py).run_code() - - assert 252 == result.returned_value - -@pytest.mark.integration_test -def test_struct_2(): - code_py = """ -class Rectangle: - height: i32 - width: i32 - border: i32 - -@exported -def testEntry() -> i32: - return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3)) - -def helper(shape1: Rectangle, shape2: Rectangle) -> i32: - return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border -""" - - result = Suite(code_py).run_code() - - assert 545 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_tuple_simple_constructor(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return helper((24, 57, 80, )) - -def helper(vector: ({type_}, {type_}, {type_}, )) -> {type_}: - return vector[0] + vector[1] + vector[2] -""" - - result = Suite(code_py).run_code() - - assert 161 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -def test_tuple_float(): - code_py = """ -@exported -def testEntry() -> f32: - return helper((1.0, 2.0, 3.0, )) - -def helper(v: (f32, f32, f32, )) -> f32: - return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]) -""" - - result = Suite(code_py).run_code() - - assert 3.74 < result.returned_value < 3.75 - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_static_array_module_constant(type_): - code_py = f""" -CONSTANT: {type_}[3] = (24, 57, 80, ) - -@exported -def testEntry() -> {type_}: - return helper(CONSTANT) - -def helper(array: {type_}[3]) -> {type_}: - return array[0] + array[1] + array[2] -""" - - result = Suite(code_py).run_code() - - assert 161 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_static_array_indexed(type_): - code_py = f""" -CONSTANT: {type_}[3] = (24, 57, 80, ) - -@exported -def testEntry() -> {type_}: - return helper(CONSTANT, 0, 1, 2) - -def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}: - return array[i0] + array[i1] + array[i2] -""" - - result = Suite(code_py).run_code() - - assert 161 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -def test_bytes_address(): - code_py = """ -@exported -def testEntry(f: bytes) -> bytes: - return f -""" - - result = Suite(code_py).run_code(b'This is a test') - - # THIS DEPENDS ON THE ALLOCATOR - # A different allocator will return a different value - assert 20 == result.returned_value - -@pytest.mark.integration_test -def test_bytes_length(): - code_py = """ -@exported -def testEntry(f: bytes) -> i32: - return len(f) -""" - - result = Suite(code_py).run_code(b'This is another test') - - assert 20 == result.returned_value - -@pytest.mark.integration_test -def test_bytes_index(): - code_py = """ -@exported -def testEntry(f: bytes) -> u8: - return f[8] -""" - - result = Suite(code_py).run_code(b'This is another test') - - assert 0x61 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('SIMD support is but a dream') -def test_tuple_i32x4(): - code_py = """ -@exported -def testEntry() -> i32x4: - return (51, 153, 204, 0, ) -""" - - result = Suite(code_py).run_code() - - assert (1, 2, 3, 0) == result.returned_value - -@pytest.mark.integration_test -def test_imported(): - code_py = """ -@imported -def helper(mul: i32) -> i32: - pass - -@exported -def testEntry() -> i32: - return helper(2) -""" - - def helper(mul: int) -> int: - return 4238 * mul - - result = Suite(code_py).run_code( - runtime='wasmer', - imports={ - 'helper': helper, - } - ) - - assert 8476 == result.returned_value diff --git a/tests/integration/test_stdlib/__init__.py b/tests/integration/test_stdlib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_stdlib_alloc.py b/tests/integration/test_stdlib/test_alloc.py similarity index 95% rename from tests/integration/test_stdlib_alloc.py rename to tests/integration/test_stdlib/test_alloc.py index da8ccea..96d1fd6 100644 --- a/tests/integration/test_stdlib_alloc.py +++ b/tests/integration/test_stdlib/test_alloc.py @@ -2,8 +2,8 @@ import sys import pytest -from .helpers import write_header -from .runners import RunnerPywasm3 as Runner +from ..helpers import write_header +from ..runners import RunnerPywasm3 as Runner def setup_interpreter(phash_code: str) -> Runner: runner = Runner(phash_code)