diff --git a/phasm/compiler.py b/phasm/compiler.py index 3bb8d21..2e214a0 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -29,6 +29,16 @@ LOAD_STORE_TYPE_MAP = { # For now this is nice & clean, but this will get messy quick # Especially once we get functions with polymorphying applied types INSTANCES = { + type3classes.Eq.operators['==']: { + 'a=u8': stdlib_types.u8_eq_equals, + 'a=u32': stdlib_types.u32_eq_equals, + 'a=u64': stdlib_types.u64_eq_equals, + 'a=i8': stdlib_types.i8_eq_equals, + 'a=i32': stdlib_types.i32_eq_equals, + 'a=i64': stdlib_types.i64_eq_equals, + 'a=f32': stdlib_types.f32_eq_equals, + 'a=f64': stdlib_types.f64_eq_equals, + }, type3classes.Num.operators['+']: { 'a=u32': stdlib_types.u32_num_add, 'a=u64': stdlib_types.u64_num_add, @@ -74,6 +84,11 @@ def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType: if inp == type3types.none: return wasm.WasmTypeNone() + if inp == type3types.bool_: + # WebAssembly stores booleans as i32 + # See e.g. f32.eq, which is [f32 f32] -> [i32] + return wasm.WasmTypeInt32() + if inp == type3types.u8: # WebAssembly has only support for 32 and 64 bits # So we need to store more memory per byte @@ -122,11 +137,6 @@ def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType: raise NotImplementedError(type3, inp) -# Operators that work for i32, i64, f32, f64 -OPERATOR_MAP = { - '==': 'eq', -} - U8_OPERATOR_MAP = { # Under the hood, this is an i32 # Implementing Right Shift XOR, OR, AND is fine since the 3 remaining @@ -332,23 +342,18 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR if isinstance(inp.operator, type3classes.Type3ClassMethod): - if '=>' in inp.operator.signature: - raise NotImplementedError + type_var_map: Dict[type3classes.TypeVariable, type3types.Type3] = {} - type_var_set = inp.operator.type_vars - - type_var_map: Dict[str, type3types.Type3] = {} - - for arg_letter, arg_expr in zip(inp.operator.signature_parts, [inp.left, inp.right, inp]): - if arg_letter not in type_var_set: + for type_var, arg_expr in zip(inp.operator.signature, [inp.left, inp.right, inp]): + if not isinstance(type_var, type3classes.TypeVariable): # Fixed type, not part of the lookup requirements continue assert isinstance(arg_expr.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR - type_var_map[arg_letter] = arg_expr.type3 + type_var_map[type_var] = arg_expr.type3 instance_key = ','.join( - f'{k}={v.name}' + f'{k.letter}={v.name}' for k, v in type_var_map.items() ) @@ -371,53 +376,32 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: if operator_annotation == '(<) :: u64 -> u64 -> bool': wgn.add_statement('i64.lt_u') return - if operator_annotation == '(==) :: u64 -> u64 -> bool': - wgn.add_statement('i64.eq') - return if inp.type3 == type3types.u8: if operator := U8_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if inp.type3 == type3types.u32: - if operator := OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'i32.{operator}') - return if operator := U32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if inp.type3 == type3types.u64: - if operator := OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'i64.{operator}') - return if operator := U64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if inp.type3 == type3types.i32: - if operator := OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'i32.{operator}') - return if operator := I32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if inp.type3 == type3types.i64: - if operator := OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'i64.{operator}') - return if operator := I64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if inp.type3 == type3types.f32: - if operator := OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'f32.{operator}') - return if operator := F32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f32.{operator}') return if inp.type3 == type3types.f64: - if operator := OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'f64.{operator}') - return if operator := F64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f64.{operator}') return diff --git a/phasm/parser.py b/phasm/parser.py index 74819f9..55b6ed4 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -36,6 +36,7 @@ from .type3 import typeclasses as type3typeclasses from .type3 import types as type3types PRELUDE_OPERATORS = { + **type3typeclasses.Eq.operators, **type3typeclasses.Num.operators, } @@ -400,6 +401,9 @@ class OurVisitor: else: raise NotImplementedError(f'Operator {node.ops}') + if operator in PRELUDE_OPERATORS: + operator = PRELUDE_OPERATORS[operator] + return BinaryOp( operator, self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), diff --git a/phasm/stdlib/types.py b/phasm/stdlib/types.py index 7d64f54..11ed170 100644 --- a/phasm/stdlib/types.py +++ b/phasm/stdlib/types.py @@ -66,6 +66,30 @@ def __subscript_bytes__(g: Generator, adr: i32, ofs: i32) -> i32: return i32('return') # To satisfy mypy +def u8_eq_equals(g: Generator) -> None: + g.add_statement('i32.eq') + +def u32_eq_equals(g: Generator) -> None: + g.add_statement('i32.eq') + +def u64_eq_equals(g: Generator) -> None: + g.add_statement('i64.eq') + +def i8_eq_equals(g: Generator) -> None: + g.add_statement('i32.eq') + +def i32_eq_equals(g: Generator) -> None: + g.add_statement('i32.eq') + +def i64_eq_equals(g: Generator) -> None: + g.add_statement('i64.eq') + +def f32_eq_equals(g: Generator) -> None: + g.add_statement('f32.eq') + +def f64_eq_equals(g: Generator) -> None: + g.add_statement('f64.eq') + def u32_num_add(g: Generator) -> None: g.add_statement('i32.add') diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index 724c735..9aa315d 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -67,30 +67,34 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator: if isinstance(inp, ourlang.BinaryOp): if isinstance(inp.operator, type3typeclasses.Type3ClassMethod): - if '=>' in inp.operator.signature: - raise NotImplementedError - type_var_map = { x: type3types.PlaceholderForType([]) - for x in inp.operator.type_vars + for x in inp.operator.signature + if isinstance(x, type3typeclasses.TypeVariable) } yield from expression(ctx, inp.left) yield from expression(ctx, inp.right) - for arg_letter in inp.operator.type3_class.args: - assert arg_letter in type_var_map # When can this happen? + for type_var in inp.operator.type3_class.args: + assert type_var in type_var_map # When can this happen? yield MustImplementTypeClassConstraint( inp.operator.type3_class, - type_var_map[arg_letter], + type_var_map[type_var], ) - for arg_letter, arg_expr in zip(inp.operator.signature_parts, [inp.left, inp.right, inp]): - if arg_letter not in type_var_map: - raise NotImplementedError + for sig_part, arg_expr in zip(inp.operator.signature, [inp.left, inp.right, inp]): + if isinstance(sig_part, type3typeclasses.TypeVariable): + yield SameTypeConstraint(type_var_map[sig_part], arg_expr.type3) + continue - yield SameTypeConstraint(type_var_map[arg_letter], arg_expr.type3) + if isinstance(sig_part, type3typeclasses.TypeReference): + # On key error: We probably have to a lot of work to do refactoring + # the type lookups + exp_type = type3types.LOOKUP_TABLE[sig_part.name] + yield SameTypeConstraint(exp_type, arg_expr.type3) + continue return diff --git a/phasm/type3/typeclasses.py b/phasm/type3/typeclasses.py index e966521..4eb6adc 100644 --- a/phasm/type3/typeclasses.py +++ b/phasm/type3/typeclasses.py @@ -1,30 +1,62 @@ -from typing import Dict, Iterable, List, Mapping, Set +from typing import Any, Dict, Iterable, List, Mapping, Union +class TypeVariable: + __slots__ = ('letter', ) + + letter: str + + def __init__(self, letter: str) -> None: + assert len(letter) == 1, f'{letter} is not a valid type variable' + self.letter = letter + + def __hash__(self) -> int: + return hash(self.letter) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, TypeVariable): + raise NotImplementedError + + return self.letter == other.letter + + def __repr__(self) -> str: + return f'TypeVariable({repr(self.letter)})' + +class TypeReference: + __slots__ = ('name', ) + + name: str + + def __init__(self, name: str) -> None: + assert len(name) > 1, f'{name} is not a valid type reference' + self.name = name + + def __hash__(self) -> int: + return hash(self.name) + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, TypeReference): + raise NotImplementedError + + return self.name == other.name + + def __repr__(self) -> str: + return f'TypeReference({repr(self.name)})' + class Type3ClassMethod: __slots__ = ('type3_class', 'name', 'signature', ) type3_class: 'Type3Class' name: str - signature: str + signature: List[Union[TypeReference, TypeVariable]] def __init__(self, type3_class: 'Type3Class', name: str, signature: str) -> None: self.type3_class = type3_class self.name = name - self.signature = signature - - @property - def signature_parts(self) -> List[str]: - return self.signature.split(' -> ') - - @property - def type_vars(self) -> Set[str]: - return { - x - for x in self.signature_parts - if 1 == len(x) - and x == x.lower() - } + self.signature = [ + TypeVariable(x) if len(x) == 1 else TypeReference(x) + for x in signature.split(' -> ') + ] def __repr__(self) -> str: return f'Type3ClassMethod({repr(self.type3_class)}, {repr(self.name)}, {repr(self.signature)})' @@ -33,13 +65,13 @@ class Type3Class: __slots__ = ('name', 'args', 'methods', 'operators', ) name: str - args: List[str] + args: List[TypeVariable] methods: Dict[str, Type3ClassMethod] operators: Dict[str, Type3ClassMethod] def __init__(self, name: str, args: Iterable[str], methods: Mapping[str, str], operators: Mapping[str, str]) -> None: self.name = name - self.args = [*args] + self.args = [TypeVariable(x) for x in args] self.methods = { k: Type3ClassMethod(self, k, v) for k, v in methods.items() @@ -52,6 +84,14 @@ class Type3Class: def __repr__(self) -> str: return self.name +Eq = Type3Class('Eq', ['a'], methods={}, operators={ + '==': 'a -> a -> bool', +}) + +Integral = Type3Class('Eq', ['a'], methods={ + 'div': 'a -> a -> a', +}, operators={}) + Num = Type3Class('Num', ['a'], methods={}, operators={ '+': 'a -> a -> a', '-': 'a -> a -> a', diff --git a/phasm/type3/types.py b/phasm/type3/types.py index c6573d4..225e8d8 100644 --- a/phasm/type3/types.py +++ b/phasm/type3/types.py @@ -6,7 +6,7 @@ constraint generator works with. """ from typing import Any, Dict, Iterable, List, Optional, Protocol, Union -from .typeclasses import Num, Type3Class +from .typeclasses import Eq, Num, Type3Class TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before you can call any other method' @@ -238,30 +238,32 @@ The none type, for when functions simply don't return anything. e.g., IO(). bool_ = PrimitiveType3('bool', []) """ The bool type, either True or False + +Suffixes with an underscores, as it's a Python builtin """ -u8 = PrimitiveType3('u8', []) +u8 = PrimitiveType3('u8', [Eq]) """ The unsigned 8-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^8. """ -u32 = PrimitiveType3('u32', [Num]) +u32 = PrimitiveType3('u32', [Eq, Num]) """ The unsigned 32-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^32. """ -u64 = PrimitiveType3('u64', [Num]) +u64 = PrimitiveType3('u64', [Eq, Num]) """ The unsigned 64-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^64. """ -i8 = PrimitiveType3('i8', []) +i8 = PrimitiveType3('i8', [Eq]) """ The signed 8-bit integer type. @@ -269,7 +271,7 @@ Operations on variables employ modular arithmetic, with modulus 2^8, but with the middel point being 0. """ -i32 = PrimitiveType3('i32', [Num]) +i32 = PrimitiveType3('i32', [Eq, Num]) """ The unsigned 32-bit integer type. @@ -277,7 +279,7 @@ Operations on variables employ modular arithmetic, with modulus 2^32, but with the middel point being 0. """ -i64 = PrimitiveType3('i64', [Num]) +i64 = PrimitiveType3('i64', [Eq, Num]) """ The unsigned 64-bit integer type. @@ -285,12 +287,12 @@ Operations on variables employ modular arithmetic, with modulus 2^64, but with the middel point being 0. """ -f32 = PrimitiveType3('f32', [Num]) +f32 = PrimitiveType3('f32', [Eq, Num]) """ A 32-bits IEEE 754 float, of 32 bits width. """ -f64 = PrimitiveType3('f64', [Num]) +f64 = PrimitiveType3('f64', [Eq, Num]) """ A 32-bits IEEE 754 float, of 64 bits width. """ diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index db4574f..92d9401 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -252,6 +252,10 @@ def _load_memory_stored_returned_value( if ret_type3 is type3types.none: return None + if ret_type3 is type3types.bool_: + assert isinstance(wasm_value, int), wasm_value + return 0 != wasm_value + if ret_type3 in (type3types.i32, type3types.i64): assert isinstance(wasm_value, int), wasm_value return wasm_value diff --git a/tests/integration/test_lang/test_eq.py b/tests/integration/test_lang/test_eq.py new file mode 100644 index 0000000..ff1196d --- /dev/null +++ b/tests/integration/test_lang/test_eq.py @@ -0,0 +1,101 @@ +import pytest + +from phasm.type3.entry import Type3Exception + +from ..helpers import Suite + +INT_TYPES = ['u8', 'u32', 'u64', 'i8', 'i32', 'i64'] +FLOAT_TYPES = ['f32', 'f64'] + +TYPE_MAP = { + 'u8': int, + 'u32': int, + 'u64': int, + 'i8': int, + 'i32': int, + 'i64': int, + 'f32': float, + 'f64': float, +} + +@pytest.mark.integration_test +def test_equals_not_implemented(): + code_py = """ +class Foo: + val: i32 + +@exported +def testEntry(x: Foo, y: Foo) -> Foo: + return x + y +""" + + with pytest.raises(Type3Exception, match='Foo does not implement the Num type class'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', INT_TYPES) +def test_equals_int_same(type_): + code_py = f""" +CONSTANT0: {type_} = 10 + +CONSTANT1: {type_} = 10 + +@exported +def testEntry() -> bool: + return CONSTANT0 == CONSTANT1 +""" + + result = Suite(code_py).run_code() + + assert True is result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', INT_TYPES) +def test_equals_int_different(type_): + code_py = f""" +CONSTANT0: {type_} = 10 + +CONSTANT1: {type_} = 3 + +@exported +def testEntry() -> bool: + return CONSTANT0 == CONSTANT1 +""" + + result = Suite(code_py).run_code() + + assert False is result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', FLOAT_TYPES) +def test_equals_float_same(type_): + code_py = f""" +CONSTANT0: {type_} = 10.125 + +CONSTANT1: {type_} = 10.125 + +@exported +def testEntry() -> bool: + return CONSTANT0 == CONSTANT1 +""" + + result = Suite(code_py).run_code() + + assert True is result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', FLOAT_TYPES) +def test_equals_float_different(type_): + code_py = f""" +CONSTANT0: {type_} = 10.32 + +CONSTANT1: {type_} = 10.33 + +@exported +def testEntry() -> bool: + return CONSTANT0 == CONSTANT1 +""" + + result = Suite(code_py).run_code() + + assert False is result.returned_value diff --git a/tests/integration/test_lang/test_integral.py b/tests/integration/test_lang/test_integral.py index 63486da..7bbb373 100644 --- a/tests/integration/test_lang/test_integral.py +++ b/tests/integration/test_lang/test_integral.py @@ -10,7 +10,7 @@ def test_division_int(type_): code_py = f""" @exported def testEntry() -> {type_}: - return 10 / 3 + return div(10, 3) """ result = Suite(code_py).run_code() @@ -24,7 +24,7 @@ def test_division_zero_let_it_crash_int(type_): code_py = f""" @exported def testEntry() -> {type_}: - return 10 / 0 + return div(10, 0) """ # WebAssembly dictates that integer division is a partial operator (e.g. unreachable for 0)