Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
4 changed files with 70 additions and 34 deletions
Showing only changes of commit 299551db1b - Show all commits

View File

@ -72,8 +72,6 @@ U8_OPERATOR_MAP = {
# Under the hood, this is an i32
# Implementing Right Shift XOR, OR, AND is fine since the 3 remaining
# bytes stay zero after this operation
# Since it's unsigned an unsigned value, Logical or Arithmetic shift right
# are the same operation
'>>': 'shr_u',
'^': 'xor',
'|': 'or',

View File

@ -77,17 +77,29 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
return right
if isinstance(inp, ourlang.BinaryOp):
# TODO: Simplified version
if inp.operator not in ('+', '-', '*', '|', '&', '^'):
raise NotImplementedError(expression, inp, inp.operator)
if inp.operator in ('+', '-', '*', '|', '&', '^'):
left = expression(ctx, inp.left)
right = expression(ctx, inp.right)
ctx.unify(left, right)
left = expression(ctx, inp.left)
right = expression(ctx, inp.right)
ctx.unify(left, right)
inp.type_var = left
return left
inp.type_var = left
if inp.operator in ('<<', '>>', ):
inp.type_var = ctx.new_var()
inp.type_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, )))
inp.type_var.add_constraint(TypeConstraintSigned(False))
return left
left = expression(ctx, inp.left)
right = expression(ctx, inp.right)
ctx.unify(left, right)
ctx.unify(inp.type_var, left)
return left
raise NotImplementedError(expression, inp, inp.operator)
if isinstance(inp, ourlang.FunctionCall):
assert inp.function.returns_type_var is not None

View File

@ -1,7 +1,7 @@
"""
The phasm type system
"""
from typing import Dict, Optional, List, Type
from typing import Dict, Iterable, Optional, List, Set, Type
import enum
@ -217,35 +217,41 @@ class TypeConstraintBitWidth(TypeConstraintBase):
"""
Contraint on how many bits an expression has or can possibly have
"""
__slots__ = ('minb', 'maxb', )
__slots__ = ('oneof', )
minb: int
maxb: int
oneof: Set[int]
def __init__(self, *, minb: int = 1, maxb: int = 64) -> None:
assert minb is not None or maxb is not None
assert maxb <= 64 # For now, support up to 64 bits values
def __init__(self, *, oneof: Optional[Iterable[int]] = None, minb: Optional[int] = None, maxb: Optional[int] = None) -> None:
# For now, support up to 64 bits values
self.oneof = set(oneof) if oneof is not None else set(range(1, 65))
self.minb = minb
self.maxb = maxb
if minb is not None:
self.oneof = {
x
for x in self.oneof
if minb <= x
}
if maxb is not None:
self.oneof = {
x
for x in self.oneof
if x <= maxb
}
def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth':
if not isinstance(other, TypeConstraintBitWidth):
raise Exception('Invalid comparison')
if self.minb > other.maxb:
raise TypingNarrowProtoError('Min bitwidth exceeds other max bitwidth')
new_oneof = self.oneof & other.oneof
if other.minb > self.maxb:
raise TypingNarrowProtoError('Other min bitwidth exceeds max bitwidth')
if not new_oneof:
raise TypingNarrowProtoError('Memory width cannot be resolved')
return TypeConstraintBitWidth(
minb=max(self.minb, other.minb),
maxb=min(self.maxb, other.maxb),
)
return TypeConstraintBitWidth(oneof=new_oneof)
def __repr__(self) -> str:
return f'BitWidth={self.minb}..{self.maxb}'
return 'BitWidth=oneof(' + ','.join(map(str, sorted(self.oneof))) + ')'
class TypeVar:
"""
@ -380,11 +386,15 @@ def simplify(inp: TypeVar) -> Optional[str]:
assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint
assert isinstance(tc_sign, TypeConstraintSigned) # type hint
if tc_sign.signed is None or tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (8, 32, 64):
if tc_sign.signed is None or len(tc_bits.oneof) != 1:
return None
bitwidth = next(iter(tc_bits.oneof))
if bitwidth not in (8, 32, 64):
return None
base = 'i' if tc_sign.signed else 'u'
return f'{base}{tc_bits.minb}'
return f'{base}{bitwidth}'
if primitive is TypeConstraintPrimitive.Primitive.FLOAT:
if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint
@ -392,10 +402,14 @@ def simplify(inp: TypeVar) -> Optional[str]:
assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint
if tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (32, 64):
if len(tc_bits.oneof) != 1:
return None
return f'f{tc_bits.minb}'
bitwidth = next(iter(tc_bits.oneof))
if bitwidth not in (32, 64):
return None
return f'f{bitwidth}'
return None

View File

@ -76,8 +76,8 @@ def testEntry() -> {type_}:
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_logical_right_shift(type_):
@pytest.mark.parametrize('type_', ['u32', 'u64'])
def test_logical_right_shift_left_bit_zero(type_):
code_py = f"""
@exported
def testEntry() -> {type_}:
@ -89,6 +89,18 @@ def testEntry() -> {type_}:
assert 1 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_logical_right_shift_left_bit_one():
code_py = """
@exported
def testEntry() -> u32:
return 4294967295 >> 16
"""
result = Suite(code_py).run_code()
assert 0xFFFF == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64'])
def test_bitwise_or(type_):