From ffd11c4f7238ca898f8ec99e248a4764a525d25e Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Thu, 16 Nov 2023 15:10:20 +0100 Subject: [PATCH] Started on a type class system --- phasm/codestyle.py | 8 +++- phasm/compiler.py | 46 +++++++++++++++++++- phasm/ourlang.py | 5 ++- phasm/parser.py | 9 ++++ phasm/stdlib/types.py | 18 ++++++++ phasm/type3/constraints.py | 16 ++++--- phasm/type3/constraintsgenerator.py | 32 +++++++++++++- phasm/type3/typeclasses.py | 57 +++++++++++++++++++++++++ phasm/type3/types.py | 47 +++++++++++--------- tests/integration/test_lang/test_num.py | 38 ++++++++++++++--- 10 files changed, 239 insertions(+), 37 deletions(-) create mode 100644 phasm/type3/typeclasses.py diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 62f60b1..2ecc2bd 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -6,6 +6,7 @@ It's intented to be a "any color, as long as it's black" kind of renderer from typing import Generator from . import ourlang +from .type3 import typeclasses as type3classes from .type3 import types as type3types from .type3.types import TYPE3_ASSERTION_ERROR, Type3, Type3OrPlaceholder @@ -103,7 +104,12 @@ def expression(inp: ourlang.Expression) -> str: return f'{inp.operator}{expression(inp.right)}' if isinstance(inp, ourlang.BinaryOp): - return f'{expression(inp.left)} {inp.operator} {expression(inp.right)}' + if isinstance(inp.operator, type3classes.Type3ClassMethod): + operator = inp.operator.name + else: + operator = inp.operator + + return f'{expression(inp.left)} {operator} {expression(inp.right)}' if isinstance(inp, ourlang.FunctionCall): args = ', '.join( diff --git a/phasm/compiler.py b/phasm/compiler.py index 988fbe9..4b0a9fa 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -2,12 +2,13 @@ This module contains the code to convert parsed Ourlang into WebAssembly code """ import struct -from typing import List, Optional +from typing import Dict, List, Optional from . import codestyle, ourlang, wasm from .runtime import calculate_alloc_size, calculate_member_offset from .stdlib import alloc as stdlib_alloc from .stdlib import types as stdlib_types +from .type3 import typeclasses as type3classes from .type3 import types as type3types from .wasmgenerator import Generator as WasmGenerator @@ -25,6 +26,19 @@ LOAD_STORE_TYPE_MAP = { 'bytes': 'i32', # Bytes are passed around as pointers } +# For now this is nice & clean, but this will get messy quick +# Especially once we get functions with polymorphying applied types +INSTANCES = { + type3classes.Num.operators['+']: { + 'a=u32': stdlib_types.u32_num_add, + 'a=u64': stdlib_types.u64_num_add, + 'a=i32': stdlib_types.i32_num_add, + 'a=i64': stdlib_types.i64_num_add, + 'a=f32': stdlib_types.f32_num_add, + 'a=f64': stdlib_types.f64_num_add, + } +} + def phasm_compile(inp: ourlang.Module) -> wasm.Module: """ Public method for compiling a parsed Phasm module into @@ -94,7 +108,6 @@ def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType: # Operators that work for i32, i64, f32, f64 OPERATOR_MAP = { - '+': 'add', '-': 'sub', '*': 'mul', '==': 'eq', @@ -303,6 +316,35 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: expression(wgn, inp.right) assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + + if isinstance(inp.operator, type3classes.Type3ClassMethod): + if '=>' in inp.operator.signature: + raise NotImplementedError + + type_var_set = inp.operator.type_vars + + type_var_map: Dict[str, type3types.Type3] = {} + + for arg_letter, arg_expr in zip(inp.operator.signature_parts, [inp.left, inp.right, inp]): + if arg_letter not in type_var_set: + # Fixed type, not part of the lookup requirements + continue + + assert isinstance(arg_expr.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + type_var_map[arg_letter] = arg_expr.type3 + + instance_key = ','.join( + f'{k}={v.name}' + for k, v in type_var_map.items() + ) + + instance = INSTANCES.get(inp.operator, {}).get(instance_key, None) + if instance is not None: + instance(wgn) + return + + raise NotImplementedError(inp.operator, instance_key) + # FIXME: Re-implement build-in operators # Maybe operator_annotation is the way to go # Maybe the older stuff below that is the way to go diff --git a/phasm/ourlang.py b/phasm/ourlang.py index eeb1b1e..3c6dc76 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -6,6 +6,7 @@ from typing import Dict, Iterable, List, Optional, Union from typing_extensions import Final +from .type3 import typeclasses as type3classes from .type3 import types as type3types from .type3.types import PlaceholderForType, StructType3, Type3, Type3OrPlaceholder @@ -149,11 +150,11 @@ class BinaryOp(Expression): """ __slots__ = ('operator', 'left', 'right', ) - operator: str + operator: Union[str, type3classes.Type3ClassMethod] left: Expression right: Expression - def __init__(self, operator: str, left: Expression, right: Expression) -> None: + def __init__(self, operator: Union[str, type3classes.Type3ClassMethod], left: Expression, right: Expression) -> None: super().__init__() self.operator = operator diff --git a/phasm/parser.py b/phasm/parser.py index 426f606..e3a24eb 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -32,8 +32,12 @@ from .ourlang import ( UnaryOp, VariableReference, ) +from .type3 import typeclasses as type3typeclasses from .type3 import types as type3types +PRELUDE_OPERATORS = { + **type3typeclasses.Num.operators, +} def phasm_parse(source: str) -> Module: """ @@ -338,6 +342,8 @@ class OurVisitor: def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression: if isinstance(node, ast.BinOp): + operator: Union[str, type3typeclasses.Type3ClassMethod] + if isinstance(node.op, ast.Add): operator = '+' elif isinstance(node.op, ast.Sub): @@ -359,6 +365,9 @@ class OurVisitor: else: raise NotImplementedError(f'Operator {node.op}') + if operator in PRELUDE_OPERATORS: + operator = PRELUDE_OPERATORS[operator] + return BinaryOp( operator, self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), diff --git a/phasm/stdlib/types.py b/phasm/stdlib/types.py index 5166421..72e5b43 100644 --- a/phasm/stdlib/types.py +++ b/phasm/stdlib/types.py @@ -65,3 +65,21 @@ def __subscript_bytes__(g: Generator, adr: i32, ofs: i32) -> i32: g.return_() return i32('return') # To satisfy mypy + +def u32_num_add(g: Generator) -> None: + g.add_statement('i32.add') + +def u64_num_add(g: Generator) -> None: + g.add_statement('i64.add') + +def i32_num_add(g: Generator) -> None: + g.add_statement('i32.add') + +def i64_num_add(g: Generator) -> None: + g.add_statement('i64.add') + +def f32_num_add(g: Generator) -> None: + g.add_statement('f32.add') + +def f64_num_add(g: Generator) -> None: + g.add_statement('f64.add') diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py index e46b376..65061e7 100644 --- a/phasm/type3/constraints.py +++ b/phasm/type3/constraints.py @@ -6,7 +6,7 @@ These need to be resolved before the program can be compiled. from typing import Dict, List, Optional, Tuple, Union from .. import ourlang -from . import types +from . import typeclasses, types class Error: @@ -288,7 +288,7 @@ class MustImplementTypeClassConstraint(ConstraintBase): """ __slots__ = ('type_class3', 'type3', ) - type_class3: str + type_class3: Union[str, typeclasses.Type3Class] type3: types.Type3OrPlaceholder DATA = { @@ -302,7 +302,7 @@ class MustImplementTypeClassConstraint(ConstraintBase): 'f64': {'BasicMathOperation', 'FloatingPoint'}, } - def __init__(self, type_class3: str, type3: types.Type3OrPlaceholder, comment: Optional[str] = None) -> None: + def __init__(self, type_class3: Union[str, typeclasses.Type3Class], type3: types.Type3OrPlaceholder, comment: Optional[str] = None) -> None: super().__init__(comment=comment) self.type_class3 = type_class3 @@ -316,8 +316,12 @@ class MustImplementTypeClassConstraint(ConstraintBase): if isinstance(typ, types.PlaceholderForType): return RequireTypeSubstitutes() - if self.type_class3 in self.__class__.DATA.get(typ.name, set()): - return None + if isinstance(self.type_class3, typeclasses.Type3Class): + if self.type_class3 in typ.classes: + return None + else: + if self.type_class3 in self.__class__.DATA.get(typ.name, set()): + return None return Error(f'{typ.name} does not implement the {self.type_class3} type class') @@ -325,7 +329,7 @@ class MustImplementTypeClassConstraint(ConstraintBase): return ( '{type3} derives {type_class3}', { - 'type_class3': self.type_class3, + 'type_class3': str(self.type_class3), 'type3': self.type3, }, ) diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index d0db253..e21f504 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -6,6 +6,7 @@ The constraints solver can then try to resolve all constraints. from typing import Generator, List from .. import ourlang +from . import typeclasses as type3typeclasses from . import types as type3types from .constraints import ( CanBeSubscriptedConstraint, @@ -65,6 +66,35 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator: raise NotImplementedError(expression, inp, inp.operator) if isinstance(inp, ourlang.BinaryOp): + if isinstance(inp.operator, type3typeclasses.Type3ClassMethod): + if '=>' in inp.operator.signature: + raise NotImplementedError + + type_var_map = { + x: type3types.PlaceholderForType([]) + for x in inp.operator.type_vars + } + + yield from expression(ctx, inp.left) + yield from expression(ctx, inp.right) + + for arg_letter in inp.operator.type3_class.args: + assert arg_letter in type_var_map # When can this happen? + + yield MustImplementTypeClassConstraint( + inp.operator.type3_class, + type_var_map[arg_letter], + ) + + for arg_letter, arg_expr in zip(inp.operator.signature_parts, [inp.left, inp.right, inp]): + if arg_letter not in type_var_map: + raise NotImplementedError + + yield SameTypeConstraint(type_var_map[arg_letter], arg_expr.type3) + + return + + if inp.operator in ('|', '&', '^', ): yield from expression(ctx, inp.left) yield from expression(ctx, inp.right) @@ -83,7 +113,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator: comment=f'({inp.operator}) :: a -> a -> a') return - if inp.operator in ('+', '-', '*', '/', ): + if inp.operator in ('-', '*', '/', ): yield from expression(ctx, inp.left) yield from expression(ctx, inp.right) diff --git a/phasm/type3/typeclasses.py b/phasm/type3/typeclasses.py new file mode 100644 index 0000000..56f1d3d --- /dev/null +++ b/phasm/type3/typeclasses.py @@ -0,0 +1,57 @@ +from typing import Dict, Iterable, List, Mapping, Set + + +class Type3ClassMethod: + __slots__ = ('type3_class', 'name', 'signature', ) + + type3_class: 'Type3Class' + name: str + signature: str + + def __init__(self, type3_class: 'Type3Class', name: str, signature: str) -> None: + self.type3_class = type3_class + self.name = name + self.signature = signature + + @property + def signature_parts(self) -> List[str]: + return self.signature.split(' -> ') + + @property + def type_vars(self) -> Set[str]: + return { + x + for x in self.signature_parts + if 1 == len(x) + and x == x.lower() + } + + def __repr__(self) -> str: + return f'Type3ClassMethod({repr(self.type3_class)}, {repr(self.name)}, {repr(self.signature)})' + +class Type3Class: + __slots__ = ('name', 'args', 'methods', 'operators', ) + + name: str + args: List[str] + methods: Dict[str, Type3ClassMethod] + operators: Dict[str, Type3ClassMethod] + + def __init__(self, name: str, args: Iterable[str], methods: Mapping[str, str], operators: Mapping[str, str]) -> None: + self.name = name + self.args = [*args] + self.methods = { + k: Type3ClassMethod(self, k, v) + for k, v in methods.items() + } + self.operators = { + k: Type3ClassMethod(self, k, v) + for k, v in operators.items() + } + + def __repr__(self) -> str: + return self.name + +Num = Type3Class('Num', ['a'], methods={}, operators={ + '+': 'a -> a -> a', +}) diff --git a/phasm/type3/types.py b/phasm/type3/types.py index bd10abf..c6573d4 100644 --- a/phasm/type3/types.py +++ b/phasm/type3/types.py @@ -6,6 +6,8 @@ constraint generator works with. """ from typing import Any, Dict, Iterable, List, Optional, Protocol, Union +from .typeclasses import Num, Type3Class + TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before you can call any other method' class ExpressionProtocol(Protocol): @@ -22,18 +24,24 @@ class Type3: """ Base class for the type3 types """ - __slots__ = ('name', ) + __slots__ = ('name', 'classes', ) name: str """ The name of the string, as parsed and outputted by codestyle. """ - def __init__(self, name: str) -> None: + classes: List[Type3Class] + """ + The type classes that this type implements + """ + + def __init__(self, name: str, classes: Iterable[Type3Class]) -> None: self.name = name + self.classes = [*classes] def __repr__(self) -> str: - return f'Type3("{self.name}")' + return f'Type3({repr(self.name)}, {repr(self.classes)})' def __str__(self) -> str: return self.name @@ -79,7 +87,7 @@ class IntType3(Type3): value: int def __init__(self, value: int) -> None: - super().__init__(str(value)) + super().__init__(str(value), []) assert 0 <= value self.value = value @@ -164,7 +172,8 @@ class AppliedType3(Type3): base.name + ' (' + ') ('.join(str(x) for x in args) # FIXME: Do we need to redo the name on substitution? - + ')' + + ')', + [] ) self.base = base @@ -213,7 +222,7 @@ class StructType3(Type3): """ def __init__(self, name: str, members: Dict[str, Type3]) -> None: - super().__init__(name) + super().__init__(name, []) self.name = name self.members = dict(members) @@ -221,38 +230,38 @@ class StructType3(Type3): def __repr__(self) -> str: return f'StructType3(repr({self.name}), repr({self.members}))' -none = PrimitiveType3('none') +none = PrimitiveType3('none', []) """ The none type, for when functions simply don't return anything. e.g., IO(). """ -bool_ = PrimitiveType3('bool') +bool_ = PrimitiveType3('bool', []) """ The bool type, either True or False """ -u8 = PrimitiveType3('u8') +u8 = PrimitiveType3('u8', []) """ The unsigned 8-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^8. """ -u32 = PrimitiveType3('u32') +u32 = PrimitiveType3('u32', [Num]) """ The unsigned 32-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^32. """ -u64 = PrimitiveType3('u64') +u64 = PrimitiveType3('u64', [Num]) """ The unsigned 64-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^64. """ -i8 = PrimitiveType3('i8') +i8 = PrimitiveType3('i8', []) """ The signed 8-bit integer type. @@ -260,7 +269,7 @@ Operations on variables employ modular arithmetic, with modulus 2^8, but with the middel point being 0. """ -i32 = PrimitiveType3('i32') +i32 = PrimitiveType3('i32', [Num]) """ The unsigned 32-bit integer type. @@ -268,7 +277,7 @@ Operations on variables employ modular arithmetic, with modulus 2^32, but with the middel point being 0. """ -i64 = PrimitiveType3('i64') +i64 = PrimitiveType3('i64', [Num]) """ The unsigned 64-bit integer type. @@ -276,22 +285,22 @@ Operations on variables employ modular arithmetic, with modulus 2^64, but with the middel point being 0. """ -f32 = PrimitiveType3('f32') +f32 = PrimitiveType3('f32', [Num]) """ A 32-bits IEEE 754 float, of 32 bits width. """ -f64 = PrimitiveType3('f64') +f64 = PrimitiveType3('f64', [Num]) """ A 32-bits IEEE 754 float, of 64 bits width. """ -bytes = PrimitiveType3('bytes') +bytes = PrimitiveType3('bytes', []) """ This is a runtime-determined length piece of memory that can be indexed at runtime. """ -static_array = PrimitiveType3('static_array') +static_array = PrimitiveType3('static_array', []) """ This is a fixed length piece of memory that can be indexed at runtime. @@ -299,7 +308,7 @@ It should be applied with one argument. It has a runtime-dynamic length of the same type repeated. """ -tuple = PrimitiveType3('tuple') # pylint: disable=W0622 +tuple = PrimitiveType3('tuple', []) # pylint: disable=W0622 """ This is a fixed length piece of memory. diff --git a/tests/integration/test_lang/test_num.py b/tests/integration/test_lang/test_num.py index 7cd0547..70ec473 100644 --- a/tests/integration/test_lang/test_num.py +++ b/tests/integration/test_lang/test_num.py @@ -1,5 +1,7 @@ import pytest +from phasm.type3.entry import Type3Exception + from ..helpers import Suite INT_TYPES = ['u32', 'u64', 'i32', 'i64'] @@ -14,6 +16,20 @@ TYPE_MAP = { 'f64': float, } +@pytest.mark.integration_test +def test_addition_not_implemented(): + code_py = """ +class Foo: + val: i32 + +@exported +def testEntry(x: Foo, y: Foo) -> Foo: + return x + y +""" + + with pytest.raises(Type3Exception, match='Foo does not implement the Num type class'): + Suite(code_py).run_code() + @pytest.mark.integration_test @pytest.mark.parametrize('type_', INT_TYPES) def test_addition_int(type_): @@ -71,18 +87,28 @@ def testEntry() -> {type_}: assert TYPE_MAP[type_] is type(result.returned_value) @pytest.mark.integration_test -@pytest.mark.skip('TODO: Runtimes return a signed value, which is difficult to test') -@pytest.mark.parametrize('type_', ('u32', 'u64')) # FIXME: u8 -def test_subtraction_underflow(type_): - code_py = f""" +def test_subtraction_negative_result(): + code_py = """ @exported -def testEntry() -> {type_}: +def testEntry() -> i32: return 10 - 11 """ result = Suite(code_py).run_code() - assert 0 < result.returned_value + assert -1 == result.returned_value + +@pytest.mark.integration_test +def test_subtraction_underflow(): + code_py = """ +@exported +def testEntry() -> u32: + return 10 - 11 +""" + + result = Suite(code_py).run_code() + + assert 4294967295 == result.returned_value # TODO: Multiplication