Cleanup, got rid of OPERATOR_MAP

This commit is contained in:
Johan B.W. de Vries 2023-11-16 15:48:53 +01:00
parent 19a29b7327
commit 4001b086db
9 changed files with 239 additions and 76 deletions

View File

@ -29,6 +29,16 @@ LOAD_STORE_TYPE_MAP = {
# For now this is nice & clean, but this will get messy quick # For now this is nice & clean, but this will get messy quick
# Especially once we get functions with polymorphying applied types # Especially once we get functions with polymorphying applied types
INSTANCES = { INSTANCES = {
type3classes.Eq.operators['==']: {
'a=u8': stdlib_types.u8_eq_equals,
'a=u32': stdlib_types.u32_eq_equals,
'a=u64': stdlib_types.u64_eq_equals,
'a=i8': stdlib_types.i8_eq_equals,
'a=i32': stdlib_types.i32_eq_equals,
'a=i64': stdlib_types.i64_eq_equals,
'a=f32': stdlib_types.f32_eq_equals,
'a=f64': stdlib_types.f64_eq_equals,
},
type3classes.Num.operators['+']: { type3classes.Num.operators['+']: {
'a=u32': stdlib_types.u32_num_add, 'a=u32': stdlib_types.u32_num_add,
'a=u64': stdlib_types.u64_num_add, 'a=u64': stdlib_types.u64_num_add,
@ -74,6 +84,11 @@ def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType:
if inp == type3types.none: if inp == type3types.none:
return wasm.WasmTypeNone() return wasm.WasmTypeNone()
if inp == type3types.bool_:
# WebAssembly stores booleans as i32
# See e.g. f32.eq, which is [f32 f32] -> [i32]
return wasm.WasmTypeInt32()
if inp == type3types.u8: if inp == type3types.u8:
# WebAssembly has only support for 32 and 64 bits # WebAssembly has only support for 32 and 64 bits
# So we need to store more memory per byte # So we need to store more memory per byte
@ -122,11 +137,6 @@ def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType:
raise NotImplementedError(type3, inp) raise NotImplementedError(type3, inp)
# Operators that work for i32, i64, f32, f64
OPERATOR_MAP = {
'==': 'eq',
}
U8_OPERATOR_MAP = { U8_OPERATOR_MAP = {
# Under the hood, this is an i32 # Under the hood, this is an i32
# Implementing Right Shift XOR, OR, AND is fine since the 3 remaining # Implementing Right Shift XOR, OR, AND is fine since the 3 remaining
@ -332,23 +342,18 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
if isinstance(inp.operator, type3classes.Type3ClassMethod): if isinstance(inp.operator, type3classes.Type3ClassMethod):
if '=>' in inp.operator.signature: type_var_map: Dict[type3classes.TypeVariable, type3types.Type3] = {}
raise NotImplementedError
type_var_set = inp.operator.type_vars for type_var, arg_expr in zip(inp.operator.signature, [inp.left, inp.right, inp]):
if not isinstance(type_var, type3classes.TypeVariable):
type_var_map: Dict[str, type3types.Type3] = {}
for arg_letter, arg_expr in zip(inp.operator.signature_parts, [inp.left, inp.right, inp]):
if arg_letter not in type_var_set:
# Fixed type, not part of the lookup requirements # Fixed type, not part of the lookup requirements
continue continue
assert isinstance(arg_expr.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR assert isinstance(arg_expr.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR
type_var_map[arg_letter] = arg_expr.type3 type_var_map[type_var] = arg_expr.type3
instance_key = ','.join( instance_key = ','.join(
f'{k}={v.name}' f'{k.letter}={v.name}'
for k, v in type_var_map.items() for k, v in type_var_map.items()
) )
@ -371,53 +376,32 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if operator_annotation == '(<) :: u64 -> u64 -> bool': if operator_annotation == '(<) :: u64 -> u64 -> bool':
wgn.add_statement('i64.lt_u') wgn.add_statement('i64.lt_u')
return return
if operator_annotation == '(==) :: u64 -> u64 -> bool':
wgn.add_statement('i64.eq')
return
if inp.type3 == type3types.u8: if inp.type3 == type3types.u8:
if operator := U8_OPERATOR_MAP.get(inp.operator, None): if operator := U8_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if inp.type3 == type3types.u32: if inp.type3 == type3types.u32:
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}')
return
if operator := U32_OPERATOR_MAP.get(inp.operator, None): if operator := U32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if inp.type3 == type3types.u64: if inp.type3 == type3types.u64:
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}')
return
if operator := U64_OPERATOR_MAP.get(inp.operator, None): if operator := U64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if inp.type3 == type3types.i32: if inp.type3 == type3types.i32:
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}')
return
if operator := I32_OPERATOR_MAP.get(inp.operator, None): if operator := I32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i32.{operator}') wgn.add_statement(f'i32.{operator}')
return return
if inp.type3 == type3types.i64: if inp.type3 == type3types.i64:
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}')
return
if operator := I64_OPERATOR_MAP.get(inp.operator, None): if operator := I64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'i64.{operator}') wgn.add_statement(f'i64.{operator}')
return return
if inp.type3 == type3types.f32: if inp.type3 == type3types.f32:
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f32.{operator}')
return
if operator := F32_OPERATOR_MAP.get(inp.operator, None): if operator := F32_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f32.{operator}') wgn.add_statement(f'f32.{operator}')
return return
if inp.type3 == type3types.f64: if inp.type3 == type3types.f64:
if operator := OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f64.{operator}')
return
if operator := F64_OPERATOR_MAP.get(inp.operator, None): if operator := F64_OPERATOR_MAP.get(inp.operator, None):
wgn.add_statement(f'f64.{operator}') wgn.add_statement(f'f64.{operator}')
return return

View File

@ -36,6 +36,7 @@ from .type3 import typeclasses as type3typeclasses
from .type3 import types as type3types from .type3 import types as type3types
PRELUDE_OPERATORS = { PRELUDE_OPERATORS = {
**type3typeclasses.Eq.operators,
**type3typeclasses.Num.operators, **type3typeclasses.Num.operators,
} }
@ -400,6 +401,9 @@ class OurVisitor:
else: else:
raise NotImplementedError(f'Operator {node.ops}') raise NotImplementedError(f'Operator {node.ops}')
if operator in PRELUDE_OPERATORS:
operator = PRELUDE_OPERATORS[operator]
return BinaryOp( return BinaryOp(
operator, operator,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),

View File

@ -66,6 +66,30 @@ def __subscript_bytes__(g: Generator, adr: i32, ofs: i32) -> i32:
return i32('return') # To satisfy mypy return i32('return') # To satisfy mypy
def u8_eq_equals(g: Generator) -> None:
g.add_statement('i32.eq')
def u32_eq_equals(g: Generator) -> None:
g.add_statement('i32.eq')
def u64_eq_equals(g: Generator) -> None:
g.add_statement('i64.eq')
def i8_eq_equals(g: Generator) -> None:
g.add_statement('i32.eq')
def i32_eq_equals(g: Generator) -> None:
g.add_statement('i32.eq')
def i64_eq_equals(g: Generator) -> None:
g.add_statement('i64.eq')
def f32_eq_equals(g: Generator) -> None:
g.add_statement('f32.eq')
def f64_eq_equals(g: Generator) -> None:
g.add_statement('f64.eq')
def u32_num_add(g: Generator) -> None: def u32_num_add(g: Generator) -> None:
g.add_statement('i32.add') g.add_statement('i32.add')

View File

@ -67,30 +67,34 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator:
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
if isinstance(inp.operator, type3typeclasses.Type3ClassMethod): if isinstance(inp.operator, type3typeclasses.Type3ClassMethod):
if '=>' in inp.operator.signature:
raise NotImplementedError
type_var_map = { type_var_map = {
x: type3types.PlaceholderForType([]) x: type3types.PlaceholderForType([])
for x in inp.operator.type_vars for x in inp.operator.signature
if isinstance(x, type3typeclasses.TypeVariable)
} }
yield from expression(ctx, inp.left) yield from expression(ctx, inp.left)
yield from expression(ctx, inp.right) yield from expression(ctx, inp.right)
for arg_letter in inp.operator.type3_class.args: for type_var in inp.operator.type3_class.args:
assert arg_letter in type_var_map # When can this happen? assert type_var in type_var_map # When can this happen?
yield MustImplementTypeClassConstraint( yield MustImplementTypeClassConstraint(
inp.operator.type3_class, inp.operator.type3_class,
type_var_map[arg_letter], type_var_map[type_var],
) )
for arg_letter, arg_expr in zip(inp.operator.signature_parts, [inp.left, inp.right, inp]): for sig_part, arg_expr in zip(inp.operator.signature, [inp.left, inp.right, inp]):
if arg_letter not in type_var_map: if isinstance(sig_part, type3typeclasses.TypeVariable):
raise NotImplementedError yield SameTypeConstraint(type_var_map[sig_part], arg_expr.type3)
continue
yield SameTypeConstraint(type_var_map[arg_letter], arg_expr.type3) if isinstance(sig_part, type3typeclasses.TypeReference):
# On key error: We probably have to a lot of work to do refactoring
# the type lookups
exp_type = type3types.LOOKUP_TABLE[sig_part.name]
yield SameTypeConstraint(exp_type, arg_expr.type3)
continue
return return

View File

@ -1,30 +1,62 @@
from typing import Dict, Iterable, List, Mapping, Set from typing import Any, Dict, Iterable, List, Mapping, Union
class TypeVariable:
__slots__ = ('letter', )
letter: str
def __init__(self, letter: str) -> None:
assert len(letter) == 1, f'{letter} is not a valid type variable'
self.letter = letter
def __hash__(self) -> int:
return hash(self.letter)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, TypeVariable):
raise NotImplementedError
return self.letter == other.letter
def __repr__(self) -> str:
return f'TypeVariable({repr(self.letter)})'
class TypeReference:
__slots__ = ('name', )
name: str
def __init__(self, name: str) -> None:
assert len(name) > 1, f'{name} is not a valid type reference'
self.name = name
def __hash__(self) -> int:
return hash(self.name)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, TypeReference):
raise NotImplementedError
return self.name == other.name
def __repr__(self) -> str:
return f'TypeReference({repr(self.name)})'
class Type3ClassMethod: class Type3ClassMethod:
__slots__ = ('type3_class', 'name', 'signature', ) __slots__ = ('type3_class', 'name', 'signature', )
type3_class: 'Type3Class' type3_class: 'Type3Class'
name: str name: str
signature: str signature: List[Union[TypeReference, TypeVariable]]
def __init__(self, type3_class: 'Type3Class', name: str, signature: str) -> None: def __init__(self, type3_class: 'Type3Class', name: str, signature: str) -> None:
self.type3_class = type3_class self.type3_class = type3_class
self.name = name self.name = name
self.signature = signature self.signature = [
TypeVariable(x) if len(x) == 1 else TypeReference(x)
@property for x in signature.split(' -> ')
def signature_parts(self) -> List[str]: ]
return self.signature.split(' -> ')
@property
def type_vars(self) -> Set[str]:
return {
x
for x in self.signature_parts
if 1 == len(x)
and x == x.lower()
}
def __repr__(self) -> str: def __repr__(self) -> str:
return f'Type3ClassMethod({repr(self.type3_class)}, {repr(self.name)}, {repr(self.signature)})' return f'Type3ClassMethod({repr(self.type3_class)}, {repr(self.name)}, {repr(self.signature)})'
@ -33,13 +65,13 @@ class Type3Class:
__slots__ = ('name', 'args', 'methods', 'operators', ) __slots__ = ('name', 'args', 'methods', 'operators', )
name: str name: str
args: List[str] args: List[TypeVariable]
methods: Dict[str, Type3ClassMethod] methods: Dict[str, Type3ClassMethod]
operators: Dict[str, Type3ClassMethod] operators: Dict[str, Type3ClassMethod]
def __init__(self, name: str, args: Iterable[str], methods: Mapping[str, str], operators: Mapping[str, str]) -> None: def __init__(self, name: str, args: Iterable[str], methods: Mapping[str, str], operators: Mapping[str, str]) -> None:
self.name = name self.name = name
self.args = [*args] self.args = [TypeVariable(x) for x in args]
self.methods = { self.methods = {
k: Type3ClassMethod(self, k, v) k: Type3ClassMethod(self, k, v)
for k, v in methods.items() for k, v in methods.items()
@ -52,6 +84,14 @@ class Type3Class:
def __repr__(self) -> str: def __repr__(self) -> str:
return self.name return self.name
Eq = Type3Class('Eq', ['a'], methods={}, operators={
'==': 'a -> a -> bool',
})
Integral = Type3Class('Eq', ['a'], methods={
'div': 'a -> a -> a',
}, operators={})
Num = Type3Class('Num', ['a'], methods={}, operators={ Num = Type3Class('Num', ['a'], methods={}, operators={
'+': 'a -> a -> a', '+': 'a -> a -> a',
'-': 'a -> a -> a', '-': 'a -> a -> a',

View File

@ -6,7 +6,7 @@ constraint generator works with.
""" """
from typing import Any, Dict, Iterable, List, Optional, Protocol, Union from typing import Any, Dict, Iterable, List, Optional, Protocol, Union
from .typeclasses import Num, Type3Class from .typeclasses import Eq, Num, Type3Class
TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before you can call any other method' TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before you can call any other method'
@ -238,30 +238,32 @@ The none type, for when functions simply don't return anything. e.g., IO().
bool_ = PrimitiveType3('bool', []) bool_ = PrimitiveType3('bool', [])
""" """
The bool type, either True or False The bool type, either True or False
Suffixes with an underscores, as it's a Python builtin
""" """
u8 = PrimitiveType3('u8', []) u8 = PrimitiveType3('u8', [Eq])
""" """
The unsigned 8-bit integer type. The unsigned 8-bit integer type.
Operations on variables employ modular arithmetic, with modulus 2^8. Operations on variables employ modular arithmetic, with modulus 2^8.
""" """
u32 = PrimitiveType3('u32', [Num]) u32 = PrimitiveType3('u32', [Eq, Num])
""" """
The unsigned 32-bit integer type. The unsigned 32-bit integer type.
Operations on variables employ modular arithmetic, with modulus 2^32. Operations on variables employ modular arithmetic, with modulus 2^32.
""" """
u64 = PrimitiveType3('u64', [Num]) u64 = PrimitiveType3('u64', [Eq, Num])
""" """
The unsigned 64-bit integer type. The unsigned 64-bit integer type.
Operations on variables employ modular arithmetic, with modulus 2^64. Operations on variables employ modular arithmetic, with modulus 2^64.
""" """
i8 = PrimitiveType3('i8', []) i8 = PrimitiveType3('i8', [Eq])
""" """
The signed 8-bit integer type. The signed 8-bit integer type.
@ -269,7 +271,7 @@ Operations on variables employ modular arithmetic, with modulus 2^8, but
with the middel point being 0. with the middel point being 0.
""" """
i32 = PrimitiveType3('i32', [Num]) i32 = PrimitiveType3('i32', [Eq, Num])
""" """
The unsigned 32-bit integer type. The unsigned 32-bit integer type.
@ -277,7 +279,7 @@ Operations on variables employ modular arithmetic, with modulus 2^32, but
with the middel point being 0. with the middel point being 0.
""" """
i64 = PrimitiveType3('i64', [Num]) i64 = PrimitiveType3('i64', [Eq, Num])
""" """
The unsigned 64-bit integer type. The unsigned 64-bit integer type.
@ -285,12 +287,12 @@ Operations on variables employ modular arithmetic, with modulus 2^64, but
with the middel point being 0. with the middel point being 0.
""" """
f32 = PrimitiveType3('f32', [Num]) f32 = PrimitiveType3('f32', [Eq, Num])
""" """
A 32-bits IEEE 754 float, of 32 bits width. A 32-bits IEEE 754 float, of 32 bits width.
""" """
f64 = PrimitiveType3('f64', [Num]) f64 = PrimitiveType3('f64', [Eq, Num])
""" """
A 32-bits IEEE 754 float, of 64 bits width. A 32-bits IEEE 754 float, of 64 bits width.
""" """

View File

@ -249,6 +249,10 @@ def _load_memory_stored_returned_value(
if ret_type3 is type3types.none: if ret_type3 is type3types.none:
return None return None
if ret_type3 is type3types.bool_:
assert isinstance(wasm_value, int), wasm_value
return 0 != wasm_value
if ret_type3 in (type3types.i32, type3types.i64): if ret_type3 in (type3types.i32, type3types.i64):
assert isinstance(wasm_value, int), wasm_value assert isinstance(wasm_value, int), wasm_value
return wasm_value return wasm_value

View File

@ -0,0 +1,101 @@
import pytest
from phasm.type3.entry import Type3Exception
from ..helpers import Suite
INT_TYPES = ['u8', 'u32', 'u64', 'i8', 'i32', 'i64']
FLOAT_TYPES = ['f32', 'f64']
TYPE_MAP = {
'u8': int,
'u32': int,
'u64': int,
'i8': int,
'i32': int,
'i64': int,
'f32': float,
'f64': float,
}
@pytest.mark.integration_test
def test_equals_not_implemented():
code_py = """
class Foo:
val: i32
@exported
def testEntry(x: Foo, y: Foo) -> Foo:
return x + y
"""
with pytest.raises(Type3Exception, match='Foo does not implement the Num type class'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', INT_TYPES)
def test_equals_int_same(type_):
code_py = f"""
CONSTANT0: {type_} = 10
CONSTANT1: {type_} = 10
@exported
def testEntry() -> bool:
return CONSTANT0 == CONSTANT1
"""
result = Suite(code_py).run_code()
assert True is result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', INT_TYPES)
def test_equals_int_different(type_):
code_py = f"""
CONSTANT0: {type_} = 10
CONSTANT1: {type_} = 3
@exported
def testEntry() -> bool:
return CONSTANT0 == CONSTANT1
"""
result = Suite(code_py).run_code()
assert False is result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', FLOAT_TYPES)
def test_equals_float_same(type_):
code_py = f"""
CONSTANT0: {type_} = 10.125
CONSTANT1: {type_} = 10.125
@exported
def testEntry() -> bool:
return CONSTANT0 == CONSTANT1
"""
result = Suite(code_py).run_code()
assert True is result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', FLOAT_TYPES)
def test_equals_float_different(type_):
code_py = f"""
CONSTANT0: {type_} = 10.32
CONSTANT1: {type_} = 10.33
@exported
def testEntry() -> bool:
return CONSTANT0 == CONSTANT1
"""
result = Suite(code_py).run_code()
assert False is result.returned_value

View File

@ -10,7 +10,7 @@ def test_division_int(type_):
code_py = f""" code_py = f"""
@exported @exported
def testEntry() -> {type_}: def testEntry() -> {type_}:
return 10 / 3 return div(10, 3)
""" """
result = Suite(code_py).run_code() result = Suite(code_py).run_code()
@ -24,7 +24,7 @@ def test_division_zero_let_it_crash_int(type_):
code_py = f""" code_py = f"""
@exported @exported
def testEntry() -> {type_}: def testEntry() -> {type_}:
return 10 / 0 return div(10, 0)
""" """
# WebAssembly dictates that integer division is a partial operator (e.g. unreachable for 0) # WebAssembly dictates that integer division is a partial operator (e.g. unreachable for 0)