From e456f55bb02e7c438a6e1ece7c9cb808452b1f36 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 24 Dec 2022 19:22:03 +0100 Subject: [PATCH] Type compare, code cleanup, extra test --- phasm/codestyle.py | 16 +- phasm/compiler.py | 186 ++++++------------ phasm/parser.py | 57 ------ phasm/type3/constraints.py | 45 +---- phasm/type3/types.py | 31 ++- .../test_lang/test_static_array.py | 20 +- 6 files changed, 116 insertions(+), 239 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index abf8e86..f1031ad 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -16,18 +16,6 @@ def phasm_render(inp: ourlang.Module) -> str: return module(inp) Statements = Generator[str, None, None] -# -# def type_var(inp: Optional[typing.TypeVar]) -> str: -# """ -# Render: type's name -# """ -# assert inp is not None, typing.ASSERTION_ERROR -# -# mtyp = typing.simplify(inp) -# if mtyp is None: -# raise NotImplementedError(f'Rendering type {inp}') -# -# return mtyp def type3(inp: Type3OrPlaceholder) -> str: """ @@ -36,13 +24,13 @@ def type3(inp: Type3OrPlaceholder) -> str: assert isinstance(inp, Type3), TYPE3_ASSERTION_ERROR if isinstance(inp, type3types.AppliedType3): - if inp.base is type3types.tuple: + if inp.base == type3types.tuple: return '(' + ', '.join( type3(x) for x in inp.args ) + ', )' - if inp.base is type3types.static_array: + if inp.base == type3types.static_array: assert 1 == len(inp.args) assert isinstance(inp.args[0], Type3), TYPE3_ASSERTION_ERROR diff --git a/phasm/compiler.py b/phasm/compiler.py index 97855fc..58e1e9c 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -40,39 +40,43 @@ def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType: """ assert isinstance(inp, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR - if inp is type3types.u8: + if inp == type3types.u8: # WebAssembly has only support for 32 and 64 bits # So we need to store more memory per byte return wasm.WasmTypeInt32() - if inp is type3types.u32: + if inp == type3types.u32: return wasm.WasmTypeInt32() - if inp is type3types.u64: + if inp == type3types.u64: return wasm.WasmTypeInt64() - if inp is type3types.i32: + if inp == type3types.i32: return wasm.WasmTypeInt32() - if inp is type3types.i64: + if inp == type3types.i64: return wasm.WasmTypeInt64() - if inp is type3types.f32: + if inp == type3types.f32: return wasm.WasmTypeFloat32() - if inp is type3types.f64: + if inp == type3types.f64: return wasm.WasmTypeFloat64() - if inp is type3types.bytes: + if inp == type3types.bytes: # bytes are passed as pointer # And pointers are i32 return wasm.WasmTypeInt32() if isinstance(inp, type3types.StructType3): - # Structs and tuples are passed as pointer - # And pointers are i32 + # Structs are passed as pointer, which are i32 return wasm.WasmTypeInt32() + if isinstance(inp, type3types.AppliedType3): + if inp.base == type3types.static_array: + # Static Arrays are passed as pointer, which are i32 + return wasm.WasmTypeInt32() + raise NotImplementedError(type3, inp) # Operators that work for i32, i64, f32, f64 @@ -150,28 +154,28 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: if isinstance(inp, ourlang.ConstantPrimitive): assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR - if inp.type3 is type3types.u8: + if inp.type3 == type3types.u8: # No native u8 type - treat as i32, with caution assert isinstance(inp.value, int) wgn.i32.const(inp.value) return - if inp.type3 is type3types.i32 or inp.type3 is type3types.u32: + if inp.type3 in (type3types.i32, type3types.u32, ): assert isinstance(inp.value, int) wgn.i32.const(inp.value) return - if inp.type3 is type3types.i64 or inp.type3 is type3types.u64: + if inp.type3 in (type3types.i64, type3types.u64, ): assert isinstance(inp.value, int) wgn.i64.const(inp.value) return - if inp.type3 is type3types.f32: + if inp.type3 == type3types.f32: assert isinstance(inp.value, float) wgn.f32.const(inp.value) return - if inp.type3 is type3types.f64: + if inp.type3 == type3types.f64: assert isinstance(inp.value, float) wgn.f64.const(inp.value) return @@ -193,7 +197,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: return if isinstance(inp.type3, type3types.AppliedType3): - if inp.type3.base is type3types.static_array: + if inp.type3.base == type3types.static_array: assert inp.variable.data_block is not None, 'Static arrays must be memory stored' assert inp.variable.data_block.address is not None, 'Value not allocated' wgn.i32.const(inp.variable.data_block.address) @@ -208,12 +212,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: # return # - # if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: - # assert inp.variable.data_block is not None, 'Combined values are memory stored' - # assert inp.variable.data_block.address is not None, 'Value not allocated' - # wgn.i32.const(inp.variable.data_block.address) - # return - assert inp.variable.data_block is None, 'Primitives are not memory stored' expression(wgn, inp.variable.constant) @@ -228,46 +226,46 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR # FIXME: Re-implement build-in operators - if inp.type3 is type3types.u8: + if inp.type3 == type3types.u8: if operator := U8_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if inp.type3 is type3types.u32: + if inp.type3 == type3types.u32: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := U32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if inp.type3 is type3types.u64: + if inp.type3 == type3types.u64: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := U64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if inp.type3 is type3types.i32: + if inp.type3 == type3types.i32: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := I32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if inp.type3 is type3types.i64: + if inp.type3 == type3types.i64: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := I64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if inp.type3 is type3types.f32: + if inp.type3 == type3types.f32: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f32.{operator}') return if operator := F32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f32.{operator}') return - if inp.type3 is type3types.f64: + if inp.type3 == type3types.f64: if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f64.{operator}') return @@ -282,18 +280,18 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR - if inp.type3 is type3types.f32: + if inp.type3 == type3types.f32: if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f32.{inp.operator}') return - if inp.type3 is type3types.f64: + if inp.type3 == type3types.f64: if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f64.{inp.operator}') return - if inp.type3 is type3types.u32: + if inp.type3 == type3types.u32: if inp.operator == 'len': - if inp.right.type3 is type3types.bytes: + if inp.right.type3 == type3types.bytes: wgn.i32.load() return @@ -316,74 +314,32 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: assert isinstance(inp.varref.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR if isinstance(inp.varref.type3, type3types.AppliedType3): - if inp.varref.type3.base is type3types.static_array: + if inp.varref.type3.base == type3types.static_array: assert 1 == len(inp.varref.type3.args) - assert isinstance(inp.varref.type3.args[0], type3types.Type3) + el_type = inp.varref.type3.args[0] + assert isinstance(el_type, type3types.Type3) + + # OPTIMIZE: If index is a constant, we can use offset instead of multiply + + # FIXME: Out of bounds check expression(wgn, inp.varref) expression(wgn, inp.index) - wgn.i32.const(_calculate_alloc_size(inp.varref.type3.args[0])) + wgn.i32.const(_calculate_alloc_size(el_type)) wgn.i32.mul() wgn.i32.add() + + mtyp = LOAD_STORE_TYPE_MAP.get(el_type.name) + if mtyp is None: + # In the future might extend this by having structs or tuples + # as members of struct or tuples + raise NotImplementedError(expression, inp, el_type) + + wgn.add_statement(f'{mtyp}.load') return - # - # assert inp.varref.type_var is not None, typing.ASSERTION_ERROR - # tc_type = inp.varref.type_var.get_type() - # if tc_type is None: - # raise NotImplementedError(expression, inp, inp.varref.type_var) - - # if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: - # if not isinstance(inp.index, ourlang.ConstantPrimitive): - # raise NotImplementedError(expression, inp, inp.index) - # if not isinstance(inp.index.value, int): - # raise NotImplementedError(expression, inp, inp.index.value) - # - # assert inp.type_var is not None, typing.ASSERTION_ERROR - # mtyp = typing.simplify(inp.type_var) - # if mtyp is None: - # raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp) - # - # if mtyp == 'u8': - # # u8 operations are done using i32, since WASM does not have u8 operations - # mtyp = 'i32' - # elif mtyp == 'u32': - # # u32 operations are done using i32, using _u operations - # mtyp = 'i32' - # elif mtyp == 'u64': - # # u64 operations are done using i64, using _u operations - # mtyp = 'i64' - # - # tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript) - # if tc_subs is None: - # raise NotImplementedError(expression, inp, inp.varref.type_var) - # - # assert 0 < len(tc_subs.members) - # tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth) - # if tc_bits is None or len(tc_bits.oneof) > 1: - # raise NotImplementedError(expression, inp, inp.varref.type_var) - # - # bitwidth = next(iter(tc_bits.oneof)) - # if bitwidth % 8 != 0: - # raise NotImplementedError(expression, inp, inp.varref.type_var) - # - # expression(wgn, inp.varref) - # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value)) - # return - raise NotImplementedError(expression, inp, inp.varref.type3) - - # TODO: Broken after new type system - # if isinstance(inp, ourlang.AccessBytesIndex): - # if not isinstance(inp.type, typing.TypeUInt8): - # raise NotImplementedError(inp, inp.type) - # - # expression(wgn, inp.varref) - # expression(wgn, inp.index) - # wgn.call(stdlib_types.__subscript_bytes__) - # return - if isinstance(inp, ourlang.AccessStructMember): mtyp = LOAD_STORE_TYPE_MAP.get(inp.struct_type3.members[inp.member].name) if mtyp is None: @@ -397,32 +353,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: ))) return - # if isinstance(inp, ourlang.AccessTupleMember): - # mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) - # if mtyp is None: - # # In the future might extend this by having structs or tuples - # # as members of struct or tuples - # raise NotImplementedError(expression, inp, inp.member) - # - # expression(wgn, inp.varref) - # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) - # return - # - # if isinstance(inp, ourlang.AccessStaticArrayMember): - # mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__) - # if mtyp is None: - # # In the future might extend this by having structs or tuples - # # as members of static arrays - # raise NotImplementedError(expression, inp, inp.member) - # - # expression(wgn, inp.varref) - # expression(wgn, inp.member) - # wgn.i32.const(inp.static_array.member_type.alloc_size()) - # wgn.i32.mul() - # wgn.i32.add() - # wgn.add_statement(f'{mtyp}.load') - # return - if isinstance(inp, ourlang.Fold): expression_fold(wgn, inp) return @@ -662,37 +592,37 @@ def module_data(inp: ourlang.ModuleData) -> bytes: for constant in block.data: assert isinstance(constant.type3, type3types.Type3), (id(constant), type3types.TYPE3_ASSERTION_ERROR) - if constant.type3 is type3types.u8: + if constant.type3 == type3types.u8: assert isinstance(constant.value, int) data_list.append(module_data_u8(constant.value)) continue - if constant.type3 is type3types.u32: + if constant.type3 == type3types.u32: assert isinstance(constant.value, int) data_list.append(module_data_u32(constant.value)) continue - if constant.type3 is type3types.u64: + if constant.type3 == type3types.u64: assert isinstance(constant.value, int) data_list.append(module_data_u64(constant.value)) continue - if constant.type3 is type3types.i32: + if constant.type3 == type3types.i32: assert isinstance(constant.value, int) data_list.append(module_data_i32(constant.value)) continue - if constant.type3 is type3types.i64: + if constant.type3 == type3types.i64: assert isinstance(constant.value, int) data_list.append(module_data_i64(constant.value)) continue - if constant.type3 is type3types.f32: + if constant.type3 == type3types.f32: assert isinstance(constant.value, float) data_list.append(module_data_f32(constant.value)) continue - if constant.type3 is type3types.f64: + if constant.type3 == type3types.f64: assert isinstance(constant.value, float) data_list.append(module_data_f64(constant.value)) continue @@ -797,13 +727,13 @@ def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstruc wgn.local.get(tmp_var) def _calculate_alloc_size(typ: Union[type3types.StructType3, type3types.Type3]) -> int: - if typ is type3types.u8: - return 1 + if typ == type3types.u8: + return 4 # FIXME: We allocate 4 bytes for every u8 since you load them into an i32 - if typ is type3types.u32 or typ is type3types.i32 or typ is type3types.f32: + if typ in (type3types.u32, type3types.i32, type3types.f32, ): return 4 - if typ is type3types.u64 or typ is type3types.i64 or typ is type3types.f64: + if typ in (type3types.u64, type3types.i64, type3types.f64, ): return 8 if isinstance(typ, type3types.StructType3): diff --git a/phasm/parser.py b/phasm/parser.py index 44d2680..ab7b327 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -633,63 +633,6 @@ class OurVisitor: return Subscript(varref, slice_expr) - # if isinstance(node_typ, TypeBytes): - # if isinstance(varref, ModuleConstantReference): - # raise NotImplementedError(f'{node} from module constant') - # - # return AccessBytesIndex( - # varref, - # slice_expr, - # ) - # - # if isinstance(node_typ, TypeTuple): - # if not isinstance(slice_expr, ConstantPrimitive): - # _raise_static_error(node, 'Must subscript using a constant index') - # - # idx = slice_expr.value - # - # if not isinstance(idx, int): - # _raise_static_error(node, 'Must subscript using a constant integer index') - # - # if not (0 <= idx < len(node_typ.members)): - # _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') - # - # tuple_member = node_typ.members[idx] - # - # if isinstance(varref, ModuleConstantReference): - # raise NotImplementedError(f'{node} from module constant') - # - # return AccessTupleMember( - # varref, - # tuple_member, - # ) - # - # if isinstance(node_typ, TypeStaticArray): - # if not isinstance(slice_expr, ConstantPrimitive): - # return AccessStaticArrayMember( - # varref, - # node_typ, - # slice_expr, - # ) - # - # idx = slice_expr.value - # - # if not isinstance(idx, int): - # _raise_static_error(node, 'Must subscript using an integer index') - # - # if not (0 <= idx < len(node_typ.members)): - # _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') - # - # static_array_member = node_typ.members[idx] - # - # return AccessStaticArrayMember( - # varref, - # node_typ, - # static_array_member, - # ) - # - # _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') - def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive: del module diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py index 65378fc..0cde42c 100644 --- a/phasm/type3/constraints.py +++ b/phasm/type3/constraints.py @@ -122,7 +122,7 @@ class SameTypeConstraint(ConstraintBase): first_type = known_types[0] for typ in known_types[1:]: - if typ is not first_type: + if typ != first_type: return Error(f'{typ:s} must be {first_type:s} instead') if not placeholders: @@ -265,7 +265,7 @@ class LiteralFitsConstraint(ConstraintBase): res: NewConstraintList if isinstance(self.type3, types.AppliedType3): - if self.type3.base is types.tuple: + if self.type3.base == types.tuple: if not isinstance(self.literal, ourlang.ConstantTuple): return Error('Must be tuple') @@ -285,7 +285,7 @@ class LiteralFitsConstraint(ConstraintBase): return res - if self.type3.base is types.static_array: + if self.type3.base == types.static_array: if not isinstance(self.literal, ourlang.ConstantTuple): return Error('Must be tuple') @@ -371,42 +371,17 @@ class CanBeSubscriptedConstraint(ConstraintBase): self.type3 = smap[self.type3] if isinstance(self.type3, types.AppliedType3): - if self.type3.base is types.static_array: + if self.type3.base == types.static_array: return [ SameTypeConstraint(types.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b') ] - raise NotImplementedError - # if isinstance(self.type3, types.PlaceholderForType): - # return RequireTypeSubstitutes() - # - # if isinstance(self.index_type3, types.PlaceholderForType): - # return RequireTypeSubstitutes() - # - # if not isinstance(self.type3, types.AppliedType3): - # return Error(f'Cannot subscript {self.type3:s}') - # - # if self.type3.base is types.tuple: - # return None - # - # raise NotImplementedError(self.type3) - # - # def get_new_placeholder_substitutes(self) -> SubstitutionMap: - # if isinstance(self.type3, types.AppliedType3) and self.type3.base is types.tuple and isinstance(self.index_type3, types.PlaceholderForType): - # return { - # self.index_type3: types.u32, - # } - # - # return {} - # - # def substitute_placeholders(self, smap: SubstitutionMap) -> None: # FIXME: Duplicate code - # if isinstance(self.type3, types.PlaceholderForType) and self.type3 in smap: # FIXME: Check recursive? - # self.type3.get_substituted(smap[self.type3]) - # self.type3 = smap[self.type3] - # - # if isinstance(self.index_type3, types.PlaceholderForType) and self.index_type3 in smap: # FIXME: Check recursive? - # self.index_type3.get_substituted(smap[self.index_type3]) - # self.index_type3 = smap[self.index_type3] + # FIXME: bytes + + if self.type3.name in types.LOOKUP_TABLE: + return Error(f'{self.type3.name} cannot be subscripted') + + raise NotImplementedError(self.type3) def human_readable(self) -> HumanReadableRet: return ( diff --git a/phasm/type3/types.py b/phasm/type3/types.py index a91cab1..d56f86a 100644 --- a/phasm/type3/types.py +++ b/phasm/type3/types.py @@ -45,10 +45,16 @@ class Type3: return str(self) def __eq__(self, other: Any) -> bool: - raise NotImplementedError + if isinstance(other, PlaceholderForType): + return False + + if not isinstance(other, Type3): + raise NotImplementedError + + return self is other def __ne__(self, other: Any) -> bool: - raise NotImplementedError + return not self.__eq__(other) def __hash__(self) -> int: raise NotImplementedError @@ -82,13 +88,16 @@ class PlaceholderForType: return str(self) def __eq__(self, other: Any) -> bool: + if isinstance(other, Type3): + return False + if not isinstance(other, PlaceholderForType): raise NotImplementedError return self is other def __ne__(self, other: Any) -> bool: - raise NotImplementedError + return not self.__eq__(other) def __hash__(self) -> int: return 0 # Valid but performs badly @@ -127,6 +136,22 @@ class AppliedType3(Type3): self.base = base self.args = args + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Type3): + raise NotImplementedError + + if not isinstance(other, AppliedType3): + return False + + return ( + self.base == other.base + and len(self.args) == len(other.args) + and all( + s == x + for s, x in zip(self.args, other.args) + ) + ) + def __repr__(self) -> str: return f'AppliedType3({repr(self.base)}, {repr(self.args)})' diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py index cb0e4a9..7b3ce41 100644 --- a/tests/integration/test_lang/test_static_array.py +++ b/tests/integration/test_lang/test_static_array.py @@ -34,7 +34,7 @@ def testEntry() -> {type_}: result = Suite(code_py).run_code() - assert 24 == result.returned_value + assert 57 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test @@ -95,6 +95,22 @@ def helper(array: {type_}[3]) -> {type_}: assert 162.25 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +def test_function_call_element(): + code_py = """ +CONSTANT: u64[3] = (250, 250000, 250000000, ) + +@exported +def testEntry() -> u8: + return helper(CONSTANT[0]) + +def helper(x: u8) -> u8: + return x +""" + + with pytest.raises(Type3Exception, match=r'u8 must be u64 instead'): + Suite(code_py).run_code() + @pytest.mark.integration_test def test_module_constant_type_mismatch_bitwidth(): code_py = """ @@ -126,7 +142,7 @@ def testEntry() -> u8: return CONSTANT[0] """ - with pytest.raises(Type3Exception, match='Type cannot be subscripted:'): + with pytest.raises(Type3Exception, match='u8 cannot be subscripted'): Suite(code_py).run_code() @pytest.mark.integration_test