Started on a type class system

This commit is contained in:
Johan B.W. de Vries 2023-11-16 15:10:20 +01:00
parent 5d9ef0e276
commit ffd11c4f72
10 changed files with 239 additions and 37 deletions

View File

@ -6,6 +6,7 @@ It's intented to be a "any color, as long as it's black" kind of renderer
from typing import Generator from typing import Generator
from . import ourlang from . import ourlang
from .type3 import typeclasses as type3classes
from .type3 import types as type3types from .type3 import types as type3types
from .type3.types import TYPE3_ASSERTION_ERROR, Type3, Type3OrPlaceholder from .type3.types import TYPE3_ASSERTION_ERROR, Type3, Type3OrPlaceholder
@ -103,7 +104,12 @@ def expression(inp: ourlang.Expression) -> str:
return f'{inp.operator}{expression(inp.right)}' return f'{inp.operator}{expression(inp.right)}'
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
return f'{expression(inp.left)} {inp.operator} {expression(inp.right)}' if isinstance(inp.operator, type3classes.Type3ClassMethod):
operator = inp.operator.name
else:
operator = inp.operator
return f'{expression(inp.left)} {operator} {expression(inp.right)}'
if isinstance(inp, ourlang.FunctionCall): if isinstance(inp, ourlang.FunctionCall):
args = ', '.join( args = ', '.join(

View File

@ -2,12 +2,13 @@
This module contains the code to convert parsed Ourlang into WebAssembly code This module contains the code to convert parsed Ourlang into WebAssembly code
""" """
import struct import struct
from typing import List, Optional from typing import Dict, List, Optional
from . import codestyle, ourlang, wasm from . import codestyle, ourlang, wasm
from .runtime import calculate_alloc_size, calculate_member_offset from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types from .stdlib import types as stdlib_types
from .type3 import typeclasses as type3classes
from .type3 import types as type3types from .type3 import types as type3types
from .wasmgenerator import Generator as WasmGenerator from .wasmgenerator import Generator as WasmGenerator
@ -25,6 +26,19 @@ LOAD_STORE_TYPE_MAP = {
'bytes': 'i32', # Bytes are passed around as pointers 'bytes': 'i32', # Bytes are passed around as pointers
} }
# For now this is nice & clean, but this will get messy quick
# Especially once we get functions with polymorphying applied types
INSTANCES = {
type3classes.Num.operators['+']: {
'a=u32': stdlib_types.u32_num_add,
'a=u64': stdlib_types.u64_num_add,
'a=i32': stdlib_types.i32_num_add,
'a=i64': stdlib_types.i64_num_add,
'a=f32': stdlib_types.f32_num_add,
'a=f64': stdlib_types.f64_num_add,
}
}
def phasm_compile(inp: ourlang.Module) -> wasm.Module: def phasm_compile(inp: ourlang.Module) -> wasm.Module:
""" """
Public method for compiling a parsed Phasm module into Public method for compiling a parsed Phasm module into
@ -94,7 +108,6 @@ def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType:
# Operators that work for i32, i64, f32, f64 # Operators that work for i32, i64, f32, f64
OPERATOR_MAP = { OPERATOR_MAP = {
'+': 'add',
'-': 'sub', '-': 'sub',
'*': 'mul', '*': 'mul',
'==': 'eq', '==': 'eq',
@ -303,6 +316,35 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
expression(wgn, inp.right) expression(wgn, inp.right)
assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
if isinstance(inp.operator, type3classes.Type3ClassMethod):
if '=>' in inp.operator.signature:
raise NotImplementedError
type_var_set = inp.operator.type_vars
type_var_map: Dict[str, type3types.Type3] = {}
for arg_letter, arg_expr in zip(inp.operator.signature_parts, [inp.left, inp.right, inp]):
if arg_letter not in type_var_set:
# Fixed type, not part of the lookup requirements
continue
assert isinstance(arg_expr.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
type_var_map[arg_letter] = arg_expr.type3
instance_key = ','.join(
f'{k}={v.name}'
for k, v in type_var_map.items()
)
instance = INSTANCES.get(inp.operator, {}).get(instance_key, None)
if instance is not None:
instance(wgn)
return
raise NotImplementedError(inp.operator, instance_key)
# FIXME: Re-implement build-in operators # FIXME: Re-implement build-in operators
# Maybe operator_annotation is the way to go # Maybe operator_annotation is the way to go
# Maybe the older stuff below that is the way to go # Maybe the older stuff below that is the way to go

View File

@ -6,6 +6,7 @@ from typing import Dict, Iterable, List, Optional, Union
from typing_extensions import Final from typing_extensions import Final
from .type3 import typeclasses as type3classes
from .type3 import types as type3types from .type3 import types as type3types
from .type3.types import PlaceholderForType, StructType3, Type3, Type3OrPlaceholder from .type3.types import PlaceholderForType, StructType3, Type3, Type3OrPlaceholder
@ -149,11 +150,11 @@ class BinaryOp(Expression):
""" """
__slots__ = ('operator', 'left', 'right', ) __slots__ = ('operator', 'left', 'right', )
operator: str operator: Union[str, type3classes.Type3ClassMethod]
left: Expression left: Expression
right: Expression right: Expression
def __init__(self, operator: str, left: Expression, right: Expression) -> None: def __init__(self, operator: Union[str, type3classes.Type3ClassMethod], left: Expression, right: Expression) -> None:
super().__init__() super().__init__()
self.operator = operator self.operator = operator

View File

@ -32,8 +32,12 @@ from .ourlang import (
UnaryOp, UnaryOp,
VariableReference, VariableReference,
) )
from .type3 import typeclasses as type3typeclasses
from .type3 import types as type3types from .type3 import types as type3types
PRELUDE_OPERATORS = {
**type3typeclasses.Num.operators,
}
def phasm_parse(source: str) -> Module: def phasm_parse(source: str) -> Module:
""" """
@ -338,6 +342,8 @@ class OurVisitor:
def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression: def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression:
if isinstance(node, ast.BinOp): if isinstance(node, ast.BinOp):
operator: Union[str, type3typeclasses.Type3ClassMethod]
if isinstance(node.op, ast.Add): if isinstance(node.op, ast.Add):
operator = '+' operator = '+'
elif isinstance(node.op, ast.Sub): elif isinstance(node.op, ast.Sub):
@ -359,6 +365,9 @@ class OurVisitor:
else: else:
raise NotImplementedError(f'Operator {node.op}') raise NotImplementedError(f'Operator {node.op}')
if operator in PRELUDE_OPERATORS:
operator = PRELUDE_OPERATORS[operator]
return BinaryOp( return BinaryOp(
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),

View File

@ -65,3 +65,21 @@ def __subscript_bytes__(g: Generator, adr: i32, ofs: i32) -> i32:
g.return_() g.return_()
return i32('return') # To satisfy mypy return i32('return') # To satisfy mypy
def u32_num_add(g: Generator) -> None:
g.add_statement('i32.add')
def u64_num_add(g: Generator) -> None:
g.add_statement('i64.add')
def i32_num_add(g: Generator) -> None:
g.add_statement('i32.add')
def i64_num_add(g: Generator) -> None:
g.add_statement('i64.add')
def f32_num_add(g: Generator) -> None:
g.add_statement('f32.add')
def f64_num_add(g: Generator) -> None:
g.add_statement('f64.add')

View File

@ -6,7 +6,7 @@ These need to be resolved before the program can be compiled.
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
from .. import ourlang from .. import ourlang
from . import types from . import typeclasses, types
class Error: class Error:
@ -288,7 +288,7 @@ class MustImplementTypeClassConstraint(ConstraintBase):
""" """
__slots__ = ('type_class3', 'type3', ) __slots__ = ('type_class3', 'type3', )
type_class3: str type_class3: Union[str, typeclasses.Type3Class]
type3: types.Type3OrPlaceholder type3: types.Type3OrPlaceholder
DATA = { DATA = {
@ -302,7 +302,7 @@ class MustImplementTypeClassConstraint(ConstraintBase):
'f64': {'BasicMathOperation', 'FloatingPoint'}, 'f64': {'BasicMathOperation', 'FloatingPoint'},
} }
def __init__(self, type_class3: str, type3: types.Type3OrPlaceholder, comment: Optional[str] = None) -> None: def __init__(self, type_class3: Union[str, typeclasses.Type3Class], type3: types.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
super().__init__(comment=comment) super().__init__(comment=comment)
self.type_class3 = type_class3 self.type_class3 = type_class3
@ -316,8 +316,12 @@ class MustImplementTypeClassConstraint(ConstraintBase):
if isinstance(typ, types.PlaceholderForType): if isinstance(typ, types.PlaceholderForType):
return RequireTypeSubstitutes() return RequireTypeSubstitutes()
if self.type_class3 in self.__class__.DATA.get(typ.name, set()): if isinstance(self.type_class3, typeclasses.Type3Class):
return None if self.type_class3 in typ.classes:
return None
else:
if self.type_class3 in self.__class__.DATA.get(typ.name, set()):
return None
return Error(f'{typ.name} does not implement the {self.type_class3} type class') return Error(f'{typ.name} does not implement the {self.type_class3} type class')
@ -325,7 +329,7 @@ class MustImplementTypeClassConstraint(ConstraintBase):
return ( return (
'{type3} derives {type_class3}', '{type3} derives {type_class3}',
{ {
'type_class3': self.type_class3, 'type_class3': str(self.type_class3),
'type3': self.type3, 'type3': self.type3,
}, },
) )

View File

@ -6,6 +6,7 @@ The constraints solver can then try to resolve all constraints.
from typing import Generator, List from typing import Generator, List
from .. import ourlang from .. import ourlang
from . import typeclasses as type3typeclasses
from . import types as type3types from . import types as type3types
from .constraints import ( from .constraints import (
CanBeSubscriptedConstraint, CanBeSubscriptedConstraint,
@ -65,6 +66,35 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator:
raise NotImplementedError(expression, inp, inp.operator) raise NotImplementedError(expression, inp, inp.operator)
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
if isinstance(inp.operator, type3typeclasses.Type3ClassMethod):
if '=>' in inp.operator.signature:
raise NotImplementedError
type_var_map = {
x: type3types.PlaceholderForType([])
for x in inp.operator.type_vars
}
yield from expression(ctx, inp.left)
yield from expression(ctx, inp.right)
for arg_letter in inp.operator.type3_class.args:
assert arg_letter in type_var_map # When can this happen?
yield MustImplementTypeClassConstraint(
inp.operator.type3_class,
type_var_map[arg_letter],
)
for arg_letter, arg_expr in zip(inp.operator.signature_parts, [inp.left, inp.right, inp]):
if arg_letter not in type_var_map:
raise NotImplementedError
yield SameTypeConstraint(type_var_map[arg_letter], arg_expr.type3)
return
if inp.operator in ('|', '&', '^', ): if inp.operator in ('|', '&', '^', ):
yield from expression(ctx, inp.left) yield from expression(ctx, inp.left)
yield from expression(ctx, inp.right) yield from expression(ctx, inp.right)
@ -83,7 +113,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator:
comment=f'({inp.operator}) :: a -> a -> a') comment=f'({inp.operator}) :: a -> a -> a')
return return
if inp.operator in ('+', '-', '*', '/', ): if inp.operator in ('-', '*', '/', ):
yield from expression(ctx, inp.left) yield from expression(ctx, inp.left)
yield from expression(ctx, inp.right) yield from expression(ctx, inp.right)

View File

@ -0,0 +1,57 @@
from typing import Dict, Iterable, List, Mapping, Set
class Type3ClassMethod:
__slots__ = ('type3_class', 'name', 'signature', )
type3_class: 'Type3Class'
name: str
signature: str
def __init__(self, type3_class: 'Type3Class', name: str, signature: str) -> None:
self.type3_class = type3_class
self.name = name
self.signature = signature
@property
def signature_parts(self) -> List[str]:
return self.signature.split(' -> ')
@property
def type_vars(self) -> Set[str]:
return {
x
for x in self.signature_parts
if 1 == len(x)
and x == x.lower()
}
def __repr__(self) -> str:
return f'Type3ClassMethod({repr(self.type3_class)}, {repr(self.name)}, {repr(self.signature)})'
class Type3Class:
__slots__ = ('name', 'args', 'methods', 'operators', )
name: str
args: List[str]
methods: Dict[str, Type3ClassMethod]
operators: Dict[str, Type3ClassMethod]
def __init__(self, name: str, args: Iterable[str], methods: Mapping[str, str], operators: Mapping[str, str]) -> None:
self.name = name
self.args = [*args]
self.methods = {
k: Type3ClassMethod(self, k, v)
for k, v in methods.items()
}
self.operators = {
k: Type3ClassMethod(self, k, v)
for k, v in operators.items()
}
def __repr__(self) -> str:
return self.name
Num = Type3Class('Num', ['a'], methods={}, operators={
'+': 'a -> a -> a',
})

View File

@ -6,6 +6,8 @@ constraint generator works with.
""" """
from typing import Any, Dict, Iterable, List, Optional, Protocol, Union from typing import Any, Dict, Iterable, List, Optional, Protocol, Union
from .typeclasses import Num, Type3Class
TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before you can call any other method' TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before you can call any other method'
class ExpressionProtocol(Protocol): class ExpressionProtocol(Protocol):
@ -22,18 +24,24 @@ class Type3:
""" """
Base class for the type3 types Base class for the type3 types
""" """
__slots__ = ('name', ) __slots__ = ('name', 'classes', )
name: str name: str
""" """
The name of the string, as parsed and outputted by codestyle. The name of the string, as parsed and outputted by codestyle.
""" """
def __init__(self, name: str) -> None: classes: List[Type3Class]
"""
The type classes that this type implements
"""
def __init__(self, name: str, classes: Iterable[Type3Class]) -> None:
self.name = name self.name = name
self.classes = [*classes]
def __repr__(self) -> str: def __repr__(self) -> str:
return f'Type3("{self.name}")' return f'Type3({repr(self.name)}, {repr(self.classes)})'
def __str__(self) -> str: def __str__(self) -> str:
return self.name return self.name
@ -79,7 +87,7 @@ class IntType3(Type3):
value: int value: int
def __init__(self, value: int) -> None: def __init__(self, value: int) -> None:
super().__init__(str(value)) super().__init__(str(value), [])
assert 0 <= value assert 0 <= value
self.value = value self.value = value
@ -164,7 +172,8 @@ class AppliedType3(Type3):
base.name base.name
+ ' (' + ' ('
+ ') ('.join(str(x) for x in args) # FIXME: Do we need to redo the name on substitution? + ') ('.join(str(x) for x in args) # FIXME: Do we need to redo the name on substitution?
+ ')' + ')',
[]
) )
self.base = base self.base = base
@ -213,7 +222,7 @@ class StructType3(Type3):
""" """
def __init__(self, name: str, members: Dict[str, Type3]) -> None: def __init__(self, name: str, members: Dict[str, Type3]) -> None:
super().__init__(name) super().__init__(name, [])
self.name = name self.name = name
self.members = dict(members) self.members = dict(members)
@ -221,38 +230,38 @@ class StructType3(Type3):
def __repr__(self) -> str: def __repr__(self) -> str:
return f'StructType3(repr({self.name}), repr({self.members}))' return f'StructType3(repr({self.name}), repr({self.members}))'
none = PrimitiveType3('none') none = PrimitiveType3('none', [])
""" """
The none type, for when functions simply don't return anything. e.g., IO(). The none type, for when functions simply don't return anything. e.g., IO().
""" """
bool_ = PrimitiveType3('bool') bool_ = PrimitiveType3('bool', [])
""" """
The bool type, either True or False The bool type, either True or False
""" """
u8 = PrimitiveType3('u8') u8 = PrimitiveType3('u8', [])
""" """
The unsigned 8-bit integer type. The unsigned 8-bit integer type.
Operations on variables employ modular arithmetic, with modulus 2^8. Operations on variables employ modular arithmetic, with modulus 2^8.
""" """
u32 = PrimitiveType3('u32') u32 = PrimitiveType3('u32', [Num])
""" """
The unsigned 32-bit integer type. The unsigned 32-bit integer type.
Operations on variables employ modular arithmetic, with modulus 2^32. Operations on variables employ modular arithmetic, with modulus 2^32.
""" """
u64 = PrimitiveType3('u64') u64 = PrimitiveType3('u64', [Num])
""" """
The unsigned 64-bit integer type. The unsigned 64-bit integer type.
Operations on variables employ modular arithmetic, with modulus 2^64. Operations on variables employ modular arithmetic, with modulus 2^64.
""" """
i8 = PrimitiveType3('i8') i8 = PrimitiveType3('i8', [])
""" """
The signed 8-bit integer type. The signed 8-bit integer type.
@ -260,7 +269,7 @@ Operations on variables employ modular arithmetic, with modulus 2^8, but
with the middel point being 0. with the middel point being 0.
""" """
i32 = PrimitiveType3('i32') i32 = PrimitiveType3('i32', [Num])
""" """
The unsigned 32-bit integer type. The unsigned 32-bit integer type.
@ -268,7 +277,7 @@ Operations on variables employ modular arithmetic, with modulus 2^32, but
with the middel point being 0. with the middel point being 0.
""" """
i64 = PrimitiveType3('i64') i64 = PrimitiveType3('i64', [Num])
""" """
The unsigned 64-bit integer type. The unsigned 64-bit integer type.
@ -276,22 +285,22 @@ Operations on variables employ modular arithmetic, with modulus 2^64, but
with the middel point being 0. with the middel point being 0.
""" """
f32 = PrimitiveType3('f32') f32 = PrimitiveType3('f32', [Num])
""" """
A 32-bits IEEE 754 float, of 32 bits width. A 32-bits IEEE 754 float, of 32 bits width.
""" """
f64 = PrimitiveType3('f64') f64 = PrimitiveType3('f64', [Num])
""" """
A 32-bits IEEE 754 float, of 64 bits width. A 32-bits IEEE 754 float, of 64 bits width.
""" """
bytes = PrimitiveType3('bytes') bytes = PrimitiveType3('bytes', [])
""" """
This is a runtime-determined length piece of memory that can be indexed at runtime. This is a runtime-determined length piece of memory that can be indexed at runtime.
""" """
static_array = PrimitiveType3('static_array') static_array = PrimitiveType3('static_array', [])
""" """
This is a fixed length piece of memory that can be indexed at runtime. This is a fixed length piece of memory that can be indexed at runtime.
@ -299,7 +308,7 @@ It should be applied with one argument. It has a runtime-dynamic length
of the same type repeated. of the same type repeated.
""" """
tuple = PrimitiveType3('tuple') # pylint: disable=W0622 tuple = PrimitiveType3('tuple', []) # pylint: disable=W0622
""" """
This is a fixed length piece of memory. This is a fixed length piece of memory.

View File

@ -1,5 +1,7 @@
import pytest import pytest
from phasm.type3.entry import Type3Exception
from ..helpers import Suite from ..helpers import Suite
INT_TYPES = ['u32', 'u64', 'i32', 'i64'] INT_TYPES = ['u32', 'u64', 'i32', 'i64']
@ -14,6 +16,20 @@ TYPE_MAP = {
'f64': float, 'f64': float,
} }
@pytest.mark.integration_test
def test_addition_not_implemented():
code_py = """
class Foo:
val: i32
@exported
def testEntry(x: Foo, y: Foo) -> Foo:
return x + y
"""
with pytest.raises(Type3Exception, match='Foo does not implement the Num type class'):
Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', INT_TYPES) @pytest.mark.parametrize('type_', INT_TYPES)
def test_addition_int(type_): def test_addition_int(type_):
@ -71,18 +87,28 @@ def testEntry() -> {type_}:
assert TYPE_MAP[type_] is type(result.returned_value) assert TYPE_MAP[type_] is type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.skip('TODO: Runtimes return a signed value, which is difficult to test') def test_subtraction_negative_result():
@pytest.mark.parametrize('type_', ('u32', 'u64')) # FIXME: u8 code_py = """
def test_subtraction_underflow(type_):
code_py = f"""
@exported @exported
def testEntry() -> {type_}: def testEntry() -> i32:
return 10 - 11 return 10 - 11
""" """
result = Suite(code_py).run_code() result = Suite(code_py).run_code()
assert 0 < result.returned_value assert -1 == result.returned_value
@pytest.mark.integration_test
def test_subtraction_underflow():
code_py = """
@exported
def testEntry() -> u32:
return 10 - 11
"""
result = Suite(code_py).run_code()
assert 4294967295 == result.returned_value
# TODO: Multiplication # TODO: Multiplication