diff --git a/TODO.md b/TODO.md index fd09f34..49940cf 100644 --- a/TODO.md +++ b/TODO.md @@ -29,7 +29,7 @@ - Casting is not implemented except u32 which is special cased - Parser is putting stuff in ModuleDataBlock - Compiler should probably do that -- ourlang.BinaryOp should probably always be a Type3ClassMethod - - Remove U32_OPERATOR_MAP / U64_OPERATOR_MAP - Make prelude more an actual thing - Implemented Bounded: https://hackage.haskell.org/package/base-4.21.0.0/docs/Prelude.html#t:Bounded +- Try to implement the min and max functions using select +- Filter out methods that aren't used; other the other way around (easier?) only add __ methods when needed diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 2b2c24c..cd1804a 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -6,7 +6,6 @@ It's intented to be a "any color, as long as it's black" kind of renderer from typing import Generator from . import ourlang -from .type3 import typeclasses as type3classes from .type3 import types as type3types from .type3.types import TYPE3_ASSERTION_ERROR, Type3, Type3OrPlaceholder @@ -103,12 +102,7 @@ def expression(inp: ourlang.Expression) -> str: return f'{inp.operator}{expression(inp.right)}' if isinstance(inp, ourlang.BinaryOp): - if isinstance(inp.operator, type3classes.Type3ClassMethod): - operator = inp.operator.name - else: - operator = inp.operator - - return f'{expression(inp.left)} {operator} {expression(inp.right)}' + return f'{expression(inp.left)} {inp.operator.name} {expression(inp.right)}' if isinstance(inp, ourlang.FunctionCall): args = ', '.join( diff --git a/phasm/compiler.py b/phasm/compiler.py index 2205941..8ae5ab7 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -109,6 +109,41 @@ INSTANCES = { 'a=f32': stdlib_types.f32_ord_greater_than_or_equal, 'a=f64': stdlib_types.f64_ord_greater_than_or_equal, }, + type3classes.Bits.methods['shl']: { + 'a=u8': stdlib_types.u8_bits_logical_shift_left, + 'a=u32': stdlib_types.u32_bits_logical_shift_left, + 'a=u64': stdlib_types.u64_bits_logical_shift_left, + }, + type3classes.Bits.methods['shr']: { + 'a=u8': stdlib_types.u8_bits_logical_shift_right, + 'a=u32': stdlib_types.u32_bits_logical_shift_right, + 'a=u64': stdlib_types.u64_bits_logical_shift_right, + }, + type3classes.Bits.methods['rotl']: { + 'a=u8': stdlib_types.u8_bits_rotate_left, + 'a=u32': stdlib_types.u32_bits_rotate_left, + 'a=u64': stdlib_types.u64_bits_rotate_left, + }, + type3classes.Bits.methods['rotr']: { + 'a=u8': stdlib_types.u8_bits_rotate_right, + 'a=u32': stdlib_types.u32_bits_rotate_right, + 'a=u64': stdlib_types.u64_bits_rotate_right, + }, + type3classes.Bits.operators['&']: { + 'a=u8': stdlib_types.u8_bits_bitwise_and, + 'a=u32': stdlib_types.u32_bits_bitwise_and, + 'a=u64': stdlib_types.u64_bits_bitwise_and, + }, + type3classes.Bits.operators['|']: { + 'a=u8': stdlib_types.u8_bits_bitwise_or, + 'a=u32': stdlib_types.u32_bits_bitwise_or, + 'a=u64': stdlib_types.u64_bits_bitwise_or, + }, + type3classes.Bits.operators['^']: { + 'a=u8': stdlib_types.u8_bits_bitwise_xor, + 'a=u32': stdlib_types.u32_bits_bitwise_xor, + 'a=u64': stdlib_types.u64_bits_bitwise_xor, + }, type3classes.Floating.methods['sqrt']: { 'a=f32': stdlib_types.f32_floating_sqrt, 'a=f64': stdlib_types.f64_floating_sqrt, @@ -271,30 +306,6 @@ def type3(inp: type3types.Type3OrPlaceholder) -> wasm.WasmType: raise NotImplementedError(type3, inp) -U8_OPERATOR_MAP = { - '^': 'xor', - '|': 'or', - '&': 'and', -} - -U32_OPERATOR_MAP = { - '^': 'xor', - '|': 'or', - '&': 'and', -} - -U64_OPERATOR_MAP = { - '^': 'xor', - '|': 'or', - '&': 'and', -} - -I32_OPERATOR_MAP: dict[str, str] = { -} - -I64_OPERATOR_MAP: dict[str, str] = { -} - def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> None: """ Compile: Instantiation (allocation) of a tuple @@ -439,64 +450,27 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: assert isinstance(inp.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR - if isinstance(inp.operator, type3classes.Type3ClassMethod): - type_var_map: Dict[type3classes.TypeVariable, type3types.Type3] = {} + type_var_map: Dict[type3classes.TypeVariable, type3types.Type3] = {} - for type_var, arg_expr in zip(inp.operator.signature, [inp.left, inp.right, inp]): - if not isinstance(type_var, type3classes.TypeVariable): - # Fixed type, not part of the lookup requirements - continue + for type_var, arg_expr in zip(inp.operator.signature, [inp.left, inp.right, inp]): + if not isinstance(type_var, type3classes.TypeVariable): + # Fixed type, not part of the lookup requirements + continue - assert isinstance(arg_expr.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR - type_var_map[type_var] = arg_expr.type3 + assert isinstance(arg_expr.type3, type3types.Type3), type3types.TYPE3_ASSERTION_ERROR + type_var_map[type_var] = arg_expr.type3 - instance_key = ','.join( - f'{k.letter}={v.name}' - for k, v in type_var_map.items() - ) + instance_key = ','.join( + f'{k.letter}={v.name}' + for k, v in type_var_map.items() + ) - instance = INSTANCES.get(inp.operator, {}).get(instance_key, None) - if instance is not None: - instance(wgn) - return - - raise NotImplementedError(inp.operator, instance_key) - - # FIXME: Re-implement build-in operators - # Maybe operator_annotation is the way to go - # Maybe the older stuff below that is the way to go - - operator_annotation = f'({inp.operator}) :: {inp.left.type3:s} -> {inp.right.type3:s} -> {inp.type3:s}' - if operator_annotation == '(>) :: i32 -> i32 -> bool': - wgn.add_statement('i32.gt_s') + instance = INSTANCES.get(inp.operator, {}).get(instance_key, None) + if instance is not None: + instance(wgn) return - if operator_annotation == '(<) :: u64 -> u64 -> bool': - wgn.add_statement('i64.lt_u') - return - - if inp.type3 == type3types.u8: - if operator := U8_OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'i32.{operator}') - return - if inp.type3 == type3types.u32: - if operator := U32_OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'i32.{operator}') - return - if inp.type3 == type3types.u64: - if operator := U64_OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'i64.{operator}') - return - if inp.type3 == type3types.i32: - if operator := I32_OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'i32.{operator}') - return - if inp.type3 == type3types.i64: - if operator := I64_OPERATOR_MAP.get(inp.operator, None): - wgn.add_statement(f'i64.{operator}') - return - - raise NotImplementedError(expression, inp.operator, inp.left.type3, inp.right.type3, inp.type3) + raise NotImplementedError(inp.operator, instance_key) if isinstance(inp, ourlang.UnaryOp): expression(wgn, inp.right) @@ -992,6 +966,8 @@ def module(inp: ourlang.Module) -> wasm.Module: stdlib_types.__i32_intnum_abs__, stdlib_types.__i64_intnum_abs__, stdlib_types.__u32_pow2__, + stdlib_types.__u8_rotl__, + stdlib_types.__u8_rotr__, ] + [ function(x) for x in inp.functions.values() diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 657aa1d..407c277 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -149,11 +149,11 @@ class BinaryOp(Expression): """ __slots__ = ('operator', 'left', 'right', ) - operator: Union[str, type3typeclasses.Type3ClassMethod] + operator: type3typeclasses.Type3ClassMethod left: Expression right: Expression - def __init__(self, operator: Union[str, type3typeclasses.Type3ClassMethod], left: Expression, right: Expression) -> None: + def __init__(self, operator: type3typeclasses.Type3ClassMethod, left: Expression, right: Expression) -> None: super().__init__() self.operator = operator diff --git a/phasm/parser.py b/phasm/parser.py index 2e44f92..eb8eef9 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -35,6 +35,7 @@ from .type3 import typeclasses as type3typeclasses from .type3 import types as type3types PRELUDE_OPERATORS = { + **type3typeclasses.Bits.operators, **type3typeclasses.Eq.operators, **type3typeclasses.Ord.operators, **type3typeclasses.Fractional.operators, @@ -44,6 +45,7 @@ PRELUDE_OPERATORS = { } PRELUDE_METHODS = { + **type3typeclasses.Bits.methods, **type3typeclasses.Eq.methods, **type3typeclasses.Ord.methods, **type3typeclasses.Floating.methods, @@ -383,11 +385,11 @@ class OurVisitor: else: raise NotImplementedError(f'Operator {node.op}') - if operator in PRELUDE_OPERATORS: - operator = PRELUDE_OPERATORS[operator] + if operator not in PRELUDE_OPERATORS: + raise NotImplementedError(f'Operator {operator}') return BinaryOp( - operator, + PRELUDE_OPERATORS[operator], self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right), ) @@ -424,11 +426,11 @@ class OurVisitor: else: raise NotImplementedError(f'Operator {node.ops}') - if operator in PRELUDE_OPERATORS: - operator = PRELUDE_OPERATORS[operator] + if operator not in PRELUDE_OPERATORS: + raise NotImplementedError(f'Operator {operator}') return BinaryOp( - operator, + PRELUDE_OPERATORS[operator], self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]), ) diff --git a/phasm/stdlib/types.py b/phasm/stdlib/types.py index 837fa43..70c1709 100644 --- a/phasm/stdlib/types.py +++ b/phasm/stdlib/types.py @@ -257,6 +257,135 @@ def __u32_pow2__(g: Generator, x: i32) -> i32: g.i32.shl() return i32('return') # To satisfy mypy + +@func_wrapper() +def __u8_rotl__(g: Generator, x: i32, r: i32) -> i32: + s = i32('s') # The shifted part we need to overlay + + # Handle cases where we need to shift more than 8 bits + g.local.get(r) + g.i32.const(8) + g.i32.rem_u() + g.local.set(r) + + # Now do the rotation + + g.local.get(x) + + # 0000 0000 1100 0011 + + g.local.get(r) + + # 0000 0000 1100 0011, 3 + + g.i32.shl() + + # 0000 0110 0001 1000 + + g.local.tee(s) + + # 0000 0110 0001 1000 + + g.i32.const(255) + + # 0000 0110 0001 1000, 0000 0000 1111 1111 + + g.i32.and_() + + # 0000 0000 0001 1000 + + g.local.get(s) + + # 0000 0000 0001 1000, 0000 0110 0001 1000 + + g.i32.const(65280) + + # 0000 0000 0001 1000, 0000 0110 0001 1000, 1111 1111 0000 0000 + + g.i32.and_() + + # 0000 0000 0001 1000, 0000 0110 0000 0000 + + g.i32.const(8) + + # 0000 0000 0001 1000, 0000 0110 0000 0000, 8 + + g.i32.shr_u() + + # 0000 0000 0001 1000, 0000 0000 0000 0110 + + g.i32.or_() + + # 0000 0000 0001 110 + + g.return_() + + return i32('return') # To satisfy mypy + +@func_wrapper() +def __u8_rotr__(g: Generator, x: i32, r: i32) -> i32: + s = i32('s') # The shifted part we need to overlay + + # Handle cases where we need to shift more than 8 bits + g.local.get(r) + g.i32.const(8) + g.i32.rem_u() + g.local.set(r) + + # Now do the rotation + + g.local.get(x) + + # 0000 0000 1100 0011 + + g.local.get(r) + + # 0000 0000 1100 0011, 3 + + g.i32.rotr() + + # 0110 0000 0000 0000 0000 0000 0001 1000 + + g.local.tee(s) + + # 0110 0000 0000 0000 0000 0000 0001 1000 + + g.i32.const(255) + + # 0110 0000 0000 0000 0000 0000 0001 1000, 0000 0000 1111 1111 + + g.i32.and_() + + # 0000 0000 0000 0000 0000 0000 0001 1000 + + g.local.get(s) + + # 0000 0000 0000 0000 0000 0000 0001 1000, 0110 0000 0000 0000 0000 0000 0001 1000 + + g.i32.const(4278190080) + + # 0000 0000 0000 0000 0000 0000 0001 1000, 0110 0000 0000 0000 0000 0000 0001 1000, 1111 1111 0000 0000 0000 0000 0000 0000 + + g.i32.and_() + + # 0000 0000 0000 0000 0000 0000 0001 1000, 0110 0000 0000 0000 0000 0000 0000 0000 + + g.i32.const(24) + + # 0000 0000 0000 0000 0000 0000 0001 1000, 0110 0000 0000 0000 0000 0000 0000 0000, 24 + + g.i32.shr_u() + + # 0000 0000 0000 0000 0000 0000 0001 1000, 0000 0000 0000 0000 0000 0000 0110 0000 + + g.i32.or_() + + # 0000 0000 0000 0000 0000 0000 0111 1000 + + g.return_() + + return i32('return') # To satisfy mypy + ## ### ## class Eq @@ -456,6 +585,78 @@ def f32_ord_greater_than_or_equal(g: Generator) -> None: def f64_ord_greater_than_or_equal(g: Generator) -> None: g.f64.ge() +## ### +## class Bits + +def u8_bits_logical_shift_left(g: Generator) -> None: + g.i32.shl() + g.i32.const(255) + g.i32.and_() + +def u32_bits_logical_shift_left(g: Generator) -> None: + g.i32.shl() + +def u64_bits_logical_shift_left(g: Generator) -> None: + g.i64.extend_i32_u() + g.i64.shl() + +def u8_bits_logical_shift_right(g: Generator) -> None: + g.i32.shr_u() + +def u32_bits_logical_shift_right(g: Generator) -> None: + g.i32.shr_u() + +def u64_bits_logical_shift_right(g: Generator) -> None: + g.i64.extend_i32_u() + g.i64.shr_u() + +def u8_bits_rotate_left(g: Generator) -> None: + g.add_statement('call $stdlib.types.__u8_rotl__') + +def u32_bits_rotate_left(g: Generator) -> None: + g.i32.rotl() + +def u64_bits_rotate_left(g: Generator) -> None: + g.i64.extend_i32_u() + g.i64.rotl() + +def u8_bits_rotate_right(g: Generator) -> None: + g.add_statement('call $stdlib.types.__u8_rotr__') + +def u32_bits_rotate_right(g: Generator) -> None: + g.i32.rotr() + +def u64_bits_rotate_right(g: Generator) -> None: + g.i64.extend_i32_u() + g.i64.rotr() + +def u8_bits_bitwise_and(g: Generator) -> None: + g.i32.and_() + +def u32_bits_bitwise_and(g: Generator) -> None: + g.i32.and_() + +def u64_bits_bitwise_and(g: Generator) -> None: + g.i64.and_() + +def u8_bits_bitwise_or(g: Generator) -> None: + g.i32.or_() + +def u32_bits_bitwise_or(g: Generator) -> None: + g.i32.or_() + +def u64_bits_bitwise_or(g: Generator) -> None: + g.i64.or_() + +def u8_bits_bitwise_xor(g: Generator) -> None: + g.i32.xor() + +def u32_bits_bitwise_xor(g: Generator) -> None: + g.i32.xor() + +def u64_bits_bitwise_xor(g: Generator) -> None: + g.i64.xor() + ## ### ## class Fractional diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index 76c3997..3b9e6de 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -60,78 +60,35 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator: raise NotImplementedError(expression, inp, inp.operator) if isinstance(inp, ourlang.BinaryOp): - if isinstance(inp.operator, type3typeclasses.Type3ClassMethod): - type_var_map = { - x: type3types.PlaceholderForType([]) - for x in inp.operator.signature - if isinstance(x, type3typeclasses.TypeVariable) - } + type_var_map = { + x: type3types.PlaceholderForType([]) + for x in inp.operator.signature + if isinstance(x, type3typeclasses.TypeVariable) + } - yield from expression(ctx, inp.left) - yield from expression(ctx, inp.right) + yield from expression(ctx, inp.left) + yield from expression(ctx, inp.right) - for type_var in inp.operator.type3_class.args: - assert type_var in type_var_map # When can this happen? + for type_var in inp.operator.type3_class.args: + assert type_var in type_var_map # When can this happen? - yield MustImplementTypeClassConstraint( - inp.operator.type3_class, - type_var_map[type_var], - ) + yield MustImplementTypeClassConstraint( + inp.operator.type3_class, + type_var_map[type_var], + ) - for sig_part, arg_expr in zip(inp.operator.signature, [inp.left, inp.right, inp]): - if isinstance(sig_part, type3typeclasses.TypeVariable): - yield SameTypeConstraint(type_var_map[sig_part], arg_expr.type3) - continue + for sig_part, arg_expr in zip(inp.operator.signature, [inp.left, inp.right, inp]): + if isinstance(sig_part, type3typeclasses.TypeVariable): + yield SameTypeConstraint(type_var_map[sig_part], arg_expr.type3) + continue - if isinstance(sig_part, type3typeclasses.TypeReference): - # On key error: We probably have to a lot of work to do refactoring - # the type lookups - exp_type = type3types.LOOKUP_TABLE[sig_part.name] - yield SameTypeConstraint(exp_type, arg_expr.type3) - continue - return - - if inp.operator in ('|', '&', '^', ): - yield from expression(ctx, inp.left) - yield from expression(ctx, inp.right) - - yield MustImplementTypeClassConstraint('BitWiseOperation', inp.left.type3) - yield SameTypeConstraint(inp.left.type3, inp.right.type3, inp.type3, - comment=f'({inp.operator}) :: a -> a -> a') - return - - if inp.operator in ('>>', '<<', ): - yield from expression(ctx, inp.left) - yield from expression(ctx, inp.right) - - yield MustImplementTypeClassConstraint('BitWiseOperation', inp.left.type3) - yield SameTypeConstraint(inp.left.type3, inp.right.type3, inp.type3, - comment=f'({inp.operator}) :: a -> a -> a') - return - - if inp.operator == '==': - yield from expression(ctx, inp.left) - yield from expression(ctx, inp.right) - - yield MustImplementTypeClassConstraint('EqualComparison', inp.left.type3) - yield SameTypeConstraint(inp.left.type3, inp.right.type3, - comment=f'({inp.operator}) :: a -> a -> bool') - yield SameTypeConstraint(inp.type3, type3types.bool_, - comment=f'({inp.operator}) :: a -> a -> bool') - return - - if inp.operator in ('<', '>'): - yield from expression(ctx, inp.left) - yield from expression(ctx, inp.right) - - yield MustImplementTypeClassConstraint('StrictPartialOrder', inp.left.type3) - yield SameTypeConstraint(inp.left.type3, inp.right.type3, - comment=f'({inp.operator}) :: a -> a -> bool') - yield SameTypeConstraint(inp.type3, type3types.bool_, - comment=f'({inp.operator}) :: a -> a -> bool') - return - - raise NotImplementedError(expression, inp) + if isinstance(sig_part, type3typeclasses.TypeReference): + # On key error: We probably have to a lot of work to do refactoring + # the type lookups + exp_type = type3types.LOOKUP_TABLE[sig_part.name] + yield SameTypeConstraint(exp_type, arg_expr.type3) + continue + return if isinstance(inp, ourlang.FunctionCall): if isinstance(inp.function, type3typeclasses.Type3ClassMethod): diff --git a/phasm/type3/typeclasses.py b/phasm/type3/typeclasses.py index 519d884..a412179 100644 --- a/phasm/type3/typeclasses.py +++ b/phasm/type3/typeclasses.py @@ -96,6 +96,7 @@ class Type3Class: Eq = Type3Class('Eq', ['a'], methods={}, operators={ '==': 'a -> a -> bool', '!=': 'a -> a -> bool', + # FIXME: Do we want to expose 'eqz'? Or is that a compiler optimization? }) Ord = Type3Class('Ord', ['a'], methods={ @@ -108,6 +109,18 @@ Ord = Type3Class('Ord', ['a'], methods={ '>=': 'a -> a -> bool', }, inherited_classes=[Eq]) +Bits = Type3Class('Bits', ['a'], methods={ + 'shl': 'a -> u32 -> a', # Logical shift left + 'shr': 'a -> u32 -> a', # Logical shift right + 'rotl': 'a -> u32 -> a', # Rotate bits left + 'rotr': 'a -> u32 -> a', # Rotate bits right + # FIXME: Do we want to expose clz, ctz, popcnt? +}, operators={ + '&': 'a -> a -> a', # Bit-wise and + '|': 'a -> a -> a', # Bit-wise or + '^': 'a -> a -> a', # Bit-wise xor +}) + NatNum = Type3Class('NatNum', ['a'], methods={}, operators={ '+': 'a -> a -> a', '-': 'a -> a -> a', @@ -139,3 +152,5 @@ Fractional = Type3Class('Fractional', ['a'], methods={ Floating = Type3Class('Floating', ['a'], methods={ 'sqrt': 'a -> a', }, operators={}, inherited_classes=[Fractional]) + +# FIXME: Do we want to expose copysign? diff --git a/phasm/type3/types.py b/phasm/type3/types.py index a5bf68d..6da5729 100644 --- a/phasm/type3/types.py +++ b/phasm/type3/types.py @@ -7,6 +7,7 @@ constraint generator works with. from typing import Any, Dict, Iterable, List, Optional, Protocol, Set, Union from .typeclasses import ( + Bits, Eq, Floating, Fractional, @@ -259,21 +260,21 @@ The bool type, either True or False Suffixes with an underscores, as it's a Python builtin """ -u8 = PrimitiveType3('u8', [Eq, Ord]) +u8 = PrimitiveType3('u8', [Bits, Eq, Ord]) """ The unsigned 8-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^8. """ -u32 = PrimitiveType3('u32', [Eq, Integral, NatNum, Ord]) +u32 = PrimitiveType3('u32', [Bits, Eq, Integral, NatNum, Ord]) """ The unsigned 32-bit integer type. Operations on variables employ modular arithmetic, with modulus 2^32. """ -u64 = PrimitiveType3('u64', [Eq, Integral, NatNum, Ord]) +u64 = PrimitiveType3('u64', [Bits, Eq, Integral, NatNum, Ord]) """ The unsigned 64-bit integer type. diff --git a/phasm/wasmgenerator.py b/phasm/wasmgenerator.py index 31ad488..7bb0c72 100644 --- a/phasm/wasmgenerator.py +++ b/phasm/wasmgenerator.py @@ -257,6 +257,10 @@ def func_wrapper(exported: bool = True) -> Callable[[Any], wasm.Function]: # Check what locals were used, and define them locals_: List[wasm.Param] = [] for local_name, local_type in generator.locals.items(): + if local_name in args: + # Already defined as a local by wasm itself + continue + locals_.append((local_name, local_type.wasm_type(), )) # Complete function definition diff --git a/tests/integration/test_lang/test_bits.py b/tests/integration/test_lang/test_bits.py index e873ee6..cca168d 100644 --- a/tests/integration/test_lang/test_bits.py +++ b/tests/integration/test_lang/test_bits.py @@ -1,115 +1,81 @@ import pytest -from phasm.type3.entry import Type3Exception - from ..helpers import Suite +class ExpResult: + def __init__(self, default, **kwargs): + self.default = default + self.kwargs = kwargs + + def get(self, type_): + return self.kwargs.get(type_, self.default) + + def __repr__(self): + return 'ExpResult(' + repr(self.default) + ', ' + ', '.join( + f'{k}={repr(v)}' + for k, v in self.kwargs.items() + ) + ')' + +TYPE_LIST = [ + 'u8', 'u32', 'u64', +] + @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation -def test_logical_left_shift(type_): +@pytest.mark.parametrize('type_', TYPE_LIST) +@pytest.mark.parametrize('test_lft,test_opr,test_rgt,test_out', [ + (9, '&', 1, 1), + (9, '&', 2, 0), + (9, '&', 8, 8), + + (9, '|', 2, 11), + (9, '|', 3, 11), + + (9, '^', 2, 11), + (9, '^', 3, 10), +]) +def test_bits_operators(type_, test_lft, test_opr, test_rgt, test_out): code_py = f""" @exported -def testEntry() -> {type_}: - return 10 << 3 +def testEntry(lft: {type_}, rgt: {type_}) -> {type_}: + return lft {test_opr} rgt """ + result = Suite(code_py).run_code(test_lft, test_rgt) - result = Suite(code_py).run_code() - - assert 80 == result.returned_value - assert isinstance(result.returned_value, int) + assert test_out == result.returned_value @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u32', 'u64']) -def test_logical_right_shift_left_bit_zero(type_): +@pytest.mark.parametrize('type_', TYPE_LIST) +@pytest.mark.parametrize('test_mtd,test_lft,test_rgt,test_out', [ + # 195 = 1100 0011 + # shl(195, 3) = 1560 = 0001 1000 + # shl(195, 3) = 1560 = 0000 0110 0001 1000 + # shr(195, 3) = 24 = 1 1000 + # rotl(195, 3) = 30 = 0001 1110 + # rotl(195, 3) = 1560 = 0000 0110 0001 1000 + # rotr(195, 3) = 120 = 0111 1000 + # rotr(195, 3) = = 0110 0000 0001 1000 + # rotr(195, 3) = = 0110 0000 0000 0000 0000 0000 0001 1000 + # rotr(195, 3) = = 0110 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0000 0001 1000 + + ('shl', 6, 1, ExpResult(12)), + ('shl', 195, 3, ExpResult(1560, u8=24)), + + ('shr', 6, 1, ExpResult(3)), + ('shr', 195, 3, ExpResult(24)), + + ('rotl', 6, 1, ExpResult(12)), + ('rotl', 195, 3, ExpResult(1560, u8=30)), + + ('rotr', 6, 1, ExpResult(3)), + ('rotr', 195, 3, ExpResult(None, u8=120, u16=24600, u32=1610612760, u64=6917529027641081880)), +]) +def test_bits_methods(type_, test_mtd, test_lft, test_rgt, test_out): code_py = f""" @exported -def testEntry() -> {type_}: - return 10 >> 3 +def testEntry(lft: {type_}, rgt: u32) -> {type_}: + return {test_mtd}(lft, rgt) """ + result = Suite(code_py).run_code(test_lft, test_rgt) - # Check with wasmtime, as other engines don't mind if the type - # doesn't match. They'll complain when: (>>) : u32 -> u64 -> u32 - result = Suite(code_py).run_code(runtime='wasmtime') - - assert 1 == result.returned_value - assert isinstance(result.returned_value, int) - -@pytest.mark.integration_test -def test_logical_right_shift_left_bit_one(): - code_py = """ -@exported -def testEntry() -> u32: - return 4294967295 >> 16 -""" - - result = Suite(code_py).run_code() - - assert 0xFFFF == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_or_uint(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 | 3 -""" - - result = Suite(code_py).run_code() - - assert 11 == result.returned_value - assert isinstance(result.returned_value, int) - -@pytest.mark.integration_test -def test_bitwise_or_inv_type(): - code_py = """ -@exported -def testEntry() -> f64: - return 10.0 | 3.0 -""" - - with pytest.raises(Type3Exception, match='f64 does not implement the BitWiseOperation type class'): - Suite(code_py).run_code() - -@pytest.mark.integration_test -def test_bitwise_or_type_mismatch(): - code_py = """ -CONSTANT1: u32 = 3 -CONSTANT2: u64 = 3 - -@exported -def testEntry() -> u64: - return CONSTANT1 | CONSTANT2 -""" - - with pytest.raises(Type3Exception, match='u64 must be u32 instead'): - Suite(code_py).run_code() - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_xor(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 ^ 3 -""" - - result = Suite(code_py).run_code() - - assert 9 == result.returned_value - assert isinstance(result.returned_value, int) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_and(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 & 3 -""" - - result = Suite(code_py).run_code() - - assert 2 == result.returned_value - assert isinstance(result.returned_value, int) + assert test_out.get(type_) == result.returned_value