From 299551db1bd843a3d1b44f421b7117a47f642bd0 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 11:49:10 +0200 Subject: [PATCH] All primitive tests work again --- phasm/compiler.py | 2 - phasm/typer.py | 28 ++++++--- phasm/typing.py | 58 ++++++++++++------- .../integration/test_lang/test_primitives.py | 16 ++++- 4 files changed, 70 insertions(+), 34 deletions(-) diff --git a/phasm/compiler.py b/phasm/compiler.py index 7f6d53f..18a548a 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -72,8 +72,6 @@ U8_OPERATOR_MAP = { # Under the hood, this is an i32 # Implementing Right Shift XOR, OR, AND is fine since the 3 remaining # bytes stay zero after this operation - # Since it's unsigned an unsigned value, Logical or Arithmetic shift right - # are the same operation '>>': 'shr_u', '^': 'xor', '|': 'or', diff --git a/phasm/typer.py b/phasm/typer.py index 3041ec3..2b17f63 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -77,17 +77,29 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': return right if isinstance(inp, ourlang.BinaryOp): - # TODO: Simplified version - if inp.operator not in ('+', '-', '*', '|', '&', '^'): - raise NotImplementedError(expression, inp, inp.operator) + if inp.operator in ('+', '-', '*', '|', '&', '^'): + left = expression(ctx, inp.left) + right = expression(ctx, inp.right) + ctx.unify(left, right) - left = expression(ctx, inp.left) - right = expression(ctx, inp.right) - ctx.unify(left, right) + inp.type_var = left + return left - inp.type_var = left + if inp.operator in ('<<', '>>', ): + inp.type_var = ctx.new_var() + inp.type_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, ))) + inp.type_var.add_constraint(TypeConstraintSigned(False)) - return left + left = expression(ctx, inp.left) + right = expression(ctx, inp.right) + ctx.unify(left, right) + + ctx.unify(inp.type_var, left) + + return left + + raise NotImplementedError(expression, inp, inp.operator) if isinstance(inp, ourlang.FunctionCall): assert inp.function.returns_type_var is not None diff --git a/phasm/typing.py b/phasm/typing.py index de92a7b..0bebf8d 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -1,7 +1,7 @@ """ The phasm type system """ -from typing import Dict, Optional, List, Type +from typing import Dict, Iterable, Optional, List, Set, Type import enum @@ -217,35 +217,41 @@ class TypeConstraintBitWidth(TypeConstraintBase): """ Contraint on how many bits an expression has or can possibly have """ - __slots__ = ('minb', 'maxb', ) + __slots__ = ('oneof', ) - minb: int - maxb: int + oneof: Set[int] - def __init__(self, *, minb: int = 1, maxb: int = 64) -> None: - assert minb is not None or maxb is not None - assert maxb <= 64 # For now, support up to 64 bits values + def __init__(self, *, oneof: Optional[Iterable[int]] = None, minb: Optional[int] = None, maxb: Optional[int] = None) -> None: + # For now, support up to 64 bits values + self.oneof = set(oneof) if oneof is not None else set(range(1, 65)) - self.minb = minb - self.maxb = maxb + if minb is not None: + self.oneof = { + x + for x in self.oneof + if minb <= x + } + + if maxb is not None: + self.oneof = { + x + for x in self.oneof + if x <= maxb + } def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth': if not isinstance(other, TypeConstraintBitWidth): raise Exception('Invalid comparison') - if self.minb > other.maxb: - raise TypingNarrowProtoError('Min bitwidth exceeds other max bitwidth') + new_oneof = self.oneof & other.oneof - if other.minb > self.maxb: - raise TypingNarrowProtoError('Other min bitwidth exceeds max bitwidth') + if not new_oneof: + raise TypingNarrowProtoError('Memory width cannot be resolved') - return TypeConstraintBitWidth( - minb=max(self.minb, other.minb), - maxb=min(self.maxb, other.maxb), - ) + return TypeConstraintBitWidth(oneof=new_oneof) def __repr__(self) -> str: - return f'BitWidth={self.minb}..{self.maxb}' + return 'BitWidth=oneof(' + ','.join(map(str, sorted(self.oneof))) + ')' class TypeVar: """ @@ -380,11 +386,15 @@ def simplify(inp: TypeVar) -> Optional[str]: assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint assert isinstance(tc_sign, TypeConstraintSigned) # type hint - if tc_sign.signed is None or tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (8, 32, 64): + if tc_sign.signed is None or len(tc_bits.oneof) != 1: + return None + + bitwidth = next(iter(tc_bits.oneof)) + if bitwidth not in (8, 32, 64): return None base = 'i' if tc_sign.signed else 'u' - return f'{base}{tc_bits.minb}' + return f'{base}{bitwidth}' if primitive is TypeConstraintPrimitive.Primitive.FLOAT: if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint @@ -392,10 +402,14 @@ def simplify(inp: TypeVar) -> Optional[str]: assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint - if tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (32, 64): + if len(tc_bits.oneof) != 1: return None - return f'f{tc_bits.minb}' + bitwidth = next(iter(tc_bits.oneof)) + if bitwidth not in (32, 64): + return None + + return f'f{bitwidth}' return None diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py index d0eeb8f..d63d4ac 100644 --- a/tests/integration/test_lang/test_primitives.py +++ b/tests/integration/test_lang/test_primitives.py @@ -76,8 +76,8 @@ def testEntry() -> {type_}: assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_logical_right_shift(type_): +@pytest.mark.parametrize('type_', ['u32', 'u64']) +def test_logical_right_shift_left_bit_zero(type_): code_py = f""" @exported def testEntry() -> {type_}: @@ -89,6 +89,18 @@ def testEntry() -> {type_}: assert 1 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +def test_logical_right_shift_left_bit_one(): + code_py = """ +@exported +def testEntry() -> u32: + return 4294967295 >> 16 +""" + + result = Suite(code_py).run_code() + + assert 0xFFFF == result.returned_value + @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) def test_bitwise_or(type_):