Adds a separte typing system #3

Closed
jbwdevries wants to merge 18 commits from milner_type_checking into master
5 changed files with 148 additions and 53 deletions
Showing only changes of commit 4d3c0c6c3c - Show all commits

View File

@ -95,18 +95,16 @@ def expression(inp: ourlang.Expression) -> str:
# return f'({args}, )' # return f'({args}, )'
return f'{inp.function.name}({args})' return f'{inp.function.name}({args})'
#
# if isinstance(inp, ourlang.AccessBytesIndex): if isinstance(inp, ourlang.Subscript):
# return f'{expression(inp.varref)}[{expression(inp.index)}]' varref = expression(inp.varref)
# index = expression(inp.index)
return f'{varref}[{index}]'
# TODO: Broken after new type system
# if isinstance(inp, ourlang.AccessStructMember): # if isinstance(inp, ourlang.AccessStructMember):
# return f'{expression(inp.varref)}.{inp.member.name}' # return f'{expression(inp.varref)}.{inp.member.name}'
#
# if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )):
# if isinstance(inp.member, ourlang.Expression):
# return f'{expression(inp.varref)}[{expression(inp.member)}]'
#
# return f'{expression(inp.varref)}[{inp.member.idx}]'
if isinstance(inp, ourlang.Fold): if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'

View File

@ -24,6 +24,9 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module:
def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType: def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
""" """
Compile: type Compile: type
Types are used for example in WebAssembly function parameters
and return types.
""" """
assert inp is not None, typing.ASSERTION_ERROR assert inp is not None, typing.ASSERTION_ERROR
@ -52,6 +55,16 @@ def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
if mtyp == 'f64': if mtyp == 'f64':
return wasm.WasmTypeFloat64() return wasm.WasmTypeFloat64()
assert inp is not None, typing.ASSERTION_ERROR
tc_prim = inp.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
raise NotImplementedError(type_var, inp)
if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# StaticArray, Tuples and Structs are passed as pointer
# And pointers are i32
return wasm.WasmTypeInt32()
# TODO: Broken after new type system # TODO: Broken after new type system
# if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): # if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)):
# # Structs and tuples are passed as pointer # # Structs and tuples are passed as pointer
@ -161,7 +174,12 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
return return
if isinstance(inp.variable, ourlang.ModuleConstantDef): if isinstance(inp.variable, ourlang.ModuleConstantDef):
# FIXME: Tuple / Static Array broken after new type system assert inp.variable.type_var is not None, typing.ASSERTION_ERROR
tc_prim = inp.variable.type_var.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
raise NotImplementedError(expression, inp, inp.variable.type_var)
# TODO: Broken after new type system
# if isinstance(inp.type, typing.TypeTuple): # if isinstance(inp.type, typing.TypeTuple):
# assert isinstance(inp.definition.constant, ourlang.ConstantTuple) # assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
# assert inp.definition.data_block is not None, 'Combined values are memory stored' # assert inp.definition.data_block is not None, 'Combined values are memory stored'
@ -169,12 +187,12 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
# wgn.i32.const(inp.definition.data_block.address) # wgn.i32.const(inp.definition.data_block.address)
# return # return
# #
# if isinstance(inp.type, typing.TypeStaticArray):
# assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# assert inp.definition.data_block is not None, 'Combined values are memory stored' assert inp.variable.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated' assert inp.variable.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address) wgn.i32.const(inp.variable.data_block.address)
# return return
assert inp.variable.data_block is None, 'Primitives are not memory stored' assert inp.variable.data_block is None, 'Primitives are not memory stored'
@ -276,6 +294,53 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
wgn.add_statement('call', '${}'.format(inp.function.name)) wgn.add_statement('call', '${}'.format(inp.function.name))
return return
if isinstance(inp, ourlang.Subscript):
assert inp.varref.type_var is not None, typing.ASSERTION_ERROR
tc_prim = inp.varref.type_var.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
raise NotImplementedError(expression, inp, inp.varref.type_var)
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
if not isinstance(inp.index, ourlang.ConstantPrimitive):
raise NotImplementedError(expression, inp, inp.index)
if not isinstance(inp.index.value, int):
raise NotImplementedError(expression, inp, inp.index.value)
assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp is None:
raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp)
if mtyp == 'u8':
# u8 operations are done using i32, since WASM does not have u8 operations
mtyp = 'i32'
elif mtyp == 'u32':
# u32 operations are done using i32, using _u operations
mtyp = 'i32'
elif mtyp == 'u64':
# u64 operations are done using i64, using _u operations
mtyp = 'i64'
tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript)
if tc_subs is None:
raise NotImplementedError(expression, inp, inp.varref.type_var)
assert 0 < len(tc_subs.members)
tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth)
if tc_bits is None or len(tc_bits.oneof) > 1:
raise NotImplementedError(expression, inp, inp.varref.type_var)
bitwidth = next(iter(tc_bits.oneof))
if bitwidth % 8 != 0:
raise NotImplementedError(expression, inp, inp.varref.type_var)
expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value))
return
raise NotImplementedError(expression, inp, inp.varref.type_var)
# TODO: Broken after new type system # TODO: Broken after new type system
# if isinstance(inp, ourlang.AccessBytesIndex): # if isinstance(inp, ourlang.AccessBytesIndex):
# if not isinstance(inp.type, typing.TypeUInt8): # if not isinstance(inp.type, typing.TypeUInt8):
@ -315,11 +380,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
# # as members of static arrays # # as members of static arrays
# raise NotImplementedError(expression, inp, inp.member) # raise NotImplementedError(expression, inp, inp.member)
# #
# if isinstance(inp.member, typing.TypeStaticArrayMember):
# expression(wgn, inp.varref)
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
# return
#
# expression(wgn, inp.varref) # expression(wgn, inp.varref)
# expression(wgn, inp.member) # expression(wgn, inp.member)
# wgn.i32.const(inp.static_array.member_type.alloc_size()) # wgn.i32.const(inp.static_array.member_type.alloc_size())

View File

@ -142,14 +142,11 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
expression(ctx, inp.varref) expression(ctx, inp.varref)
assert inp.varref.type_var is not None assert inp.varref.type_var is not None
try: # TODO: I'd much rather resolve this using the narrow functions
# TODO: I'd much rather resolve this using the narrow functions tc_subs = inp.varref.type_var.get_constraint(TypeConstraintSubscript)
tc_subs = ctx.var_constraints[inp.varref.type_var.ctx_id][TypeConstraintSubscript] if tc_subs is None:
except KeyError:
raise TypingError(f'Type cannot be subscripted: {inp.varref.type_var}') from None raise TypingError(f'Type cannot be subscripted: {inp.varref.type_var}') from None
assert isinstance(tc_subs, TypeConstraintSubscript) # type hint
try: try:
# TODO: I'd much rather resolve this using the narrow functions # TODO: I'd much rather resolve this using the narrow functions
member = tc_subs.members[inp.index.value] member = tc_subs.members[inp.index.value]

View File

@ -2,6 +2,7 @@
The phasm type system The phasm type system
""" """
from typing import Callable, Dict, Iterable, Optional, List, Set, Type from typing import Callable, Dict, Iterable, Optional, List, Set, Type
from typing import TypeVar as MyPyTypeVar
import enum import enum
import re import re
@ -168,6 +169,8 @@ class TypeConstraintPrimitive(TypeConstraintBase):
INT = 0 INT = 0
FLOAT = 1 FLOAT = 1
STATIC_ARRAY = 10
primitive: Primitive primitive: Primitive
def __init__(self, primitive: Primitive) -> None: def __init__(self, primitive: Primitive) -> None:
@ -307,6 +310,8 @@ class TypeConstraintSubscript(TypeConstraintBase):
def __repr__(self) -> str: def __repr__(self) -> str:
return 'Subscript=(' + ','.join(map(repr, self.members)) + ')' return 'Subscript=(' + ','.join(map(repr, self.members)) + ')'
TTypeConstraintClass = MyPyTypeVar('TTypeConstraintClass', bound=TypeConstraintBase)
class TypeVar: class TypeVar:
""" """
A type variable A type variable
@ -329,15 +334,22 @@ class TypeVar:
else: else:
csts[newconst.__class__] = newconst csts[newconst.__class__] = newconst
def get_constraint(self, const_type: Type[TTypeConstraintClass]) -> Optional[TTypeConstraintClass]:
csts = self.ctx.var_constraints[self.ctx_id]
res = csts.get(const_type, None)
assert res is None or isinstance(res, const_type) # type hint
return res
def add_location(self, ref: str) -> None: def add_location(self, ref: str) -> None:
self.ctx.var_locations[self.ctx_id].append(ref) self.ctx.var_locations[self.ctx_id].add(ref)
def __repr__(self) -> str: def __repr__(self) -> str:
return ( return (
'TypeVar<' 'TypeVar<'
+ '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values())) + '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
+ '; locations: ' + '; locations: '
+ ', '.join(self.ctx.var_locations[self.ctx_id]) + ', '.join(sorted(self.ctx.var_locations[self.ctx_id]))
+ '>' + '>'
) )
@ -356,7 +368,7 @@ class Context:
# Store the TypeVar properties as a lookup # Store the TypeVar properties as a lookup
# so we can update these when unifying # so we can update these when unifying
self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {} self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
self.var_locations: Dict[int, List[str]] = {} self.var_locations: Dict[int, Set[str]] = {}
def new_var(self) -> TypeVar: def new_var(self) -> TypeVar:
ctx_id = self.next_ctx_id ctx_id = self.next_ctx_id
@ -366,7 +378,7 @@ class Context:
self.vars_by_id[ctx_id] = [result] self.vars_by_id[ctx_id] = [result]
self.var_constraints[ctx_id] = {} self.var_constraints[ctx_id] = {}
self.var_locations[ctx_id] = [] self.var_locations[ctx_id] = set()
return result return result
@ -395,8 +407,7 @@ class Context:
except TypingNarrowProtoError as exc: except TypingNarrowProtoError as exc:
raise TypingNarrowError(l, r, str(exc)) from None raise TypingNarrowError(l, r, str(exc)) from None
self.var_locations[n.ctx_id].extend(self.var_locations[l_ctx_id]) self.var_locations[n.ctx_id] = self.var_locations[l_ctx_id] | self.var_locations[r_ctx_id]
self.var_locations[n.ctx_id].extend(self.var_locations[r_ctx_id])
# ## # ##
# And unify (or entangle) the old ones # And unify (or entangle) the old ones
@ -424,22 +435,18 @@ def simplify(inp: TypeVar) -> Optional[str]:
Should round trip with from_str Should round trip with from_str
""" """
tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive) tc_prim = inp.get_constraint(TypeConstraintPrimitive)
tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth) tc_bits = inp.get_constraint(TypeConstraintBitWidth)
tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned) tc_sign = inp.get_constraint(TypeConstraintSigned)
if tc_prim is None: if tc_prim is None:
return None return None
assert isinstance(tc_prim, TypeConstraintPrimitive) # type hint
primitive = tc_prim.primitive primitive = tc_prim.primitive
if primitive is TypeConstraintPrimitive.Primitive.INT: if primitive is TypeConstraintPrimitive.Primitive.INT:
if tc_bits is None or tc_sign is None: if tc_bits is None or tc_sign is None:
return None return None
assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint
assert isinstance(tc_sign, TypeConstraintSigned) # type hint
if tc_sign.signed is None or len(tc_bits.oneof) != 1: if tc_sign.signed is None or len(tc_bits.oneof) != 1:
return None return None
@ -454,8 +461,6 @@ def simplify(inp: TypeVar) -> Optional[str]:
if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint
return None return None
assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint
if len(tc_bits.oneof) != 1: if len(tc_bits.oneof) != 1:
return None return None
@ -465,6 +470,17 @@ def simplify(inp: TypeVar) -> Optional[str]:
return f'f{bitwidth}' return f'f{bitwidth}'
if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
tc_subs = inp.get_constraint(TypeConstraintSubscript)
assert tc_subs is not None
assert tc_subs.members
sab = simplify(tc_subs.members[0])
if sab is None:
return None
return f'{sab}[{len(tc_subs.members)}]'
return None return None
def make_u8(ctx: Context, location: str) -> TypeVar: def make_u8(ctx: Context, location: str) -> TypeVar:
@ -538,7 +554,7 @@ def make_f64(ctx: Context, location: str) -> TypeVar:
""" """
result = ctx.new_var() result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location(location) result.add_location(location)
return result return result
@ -573,6 +589,7 @@ def from_str(ctx: Context, inp: str, location: str) -> TypeVar:
if match: if match:
result = ctx.new_var() result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))
result.add_constraint(TypeConstraintSubscript(members=( result.add_constraint(TypeConstraintSubscript(members=(
# Make copies so they don't get entangled # Make copies so they don't get entangled
# with each other. # with each other.

View File

@ -2,7 +2,9 @@ import pytest
from phasm.exceptions import StaticError, TypingError from phasm.exceptions import StaticError, TypingError
from ..constants import ALL_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP from ..constants import (
ALL_FLOAT_TYPES, ALL_INT_TYPES, COMPLETE_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP
)
from ..helpers import Suite from ..helpers import Suite
@pytest.mark.integration_test @pytest.mark.integration_test
@ -22,6 +24,7 @@ def testEntry() -> {type_}:
assert TYPE_MAP[type_] == type(result.returned_value) assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.skip('To decide: What to do on out of index?')
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) @pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
def test_static_array_indexed(type_): def test_static_array_indexed(type_):
code_py = f""" code_py = f"""
@ -41,8 +44,8 @@ def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}:
assert TYPE_MAP[type_] == type(result.returned_value) assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) @pytest.mark.parametrize('type_', COMPLETE_INT_TYPES)
def test_function_call(type_): def test_function_call_int(type_):
code_py = f""" code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, ) CONSTANT: {type_}[3] = (24, 57, 80, )
@ -59,6 +62,25 @@ def helper(array: {type_}[3]) -> {type_}:
assert 161 == result.returned_value assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value) assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES)
def test_function_call_float(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24.0, 57.5, 80.75, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT)
def helper(array: {type_}[3]) -> {type_}:
return array[0] + array[1] + array[2]
"""
result = Suite(code_py).run_code()
assert 162.25 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
def test_module_constant_type_mismatch_bitwidth(): def test_module_constant_type_mismatch_bitwidth():
code_py = """ code_py = """
@ -100,8 +122,8 @@ def test_static_array_constant_too_few_values():
CONSTANT: u8[3] = (24, 57, ) CONSTANT: u8[3] = (24, 57, )
""" """
with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): with pytest.raises(TypingError, match='Member count does not match'):
phasm_parse(code_py) Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
def test_static_array_constant_too_many_values(): def test_static_array_constant_too_many_values():
@ -109,8 +131,8 @@ def test_static_array_constant_too_many_values():
CONSTANT: u8[3] = (24, 57, 1, 1, ) CONSTANT: u8[3] = (24, 57, 1, 1, )
""" """
with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): with pytest.raises(TypingError, match='Member count does not match'):
phasm_parse(code_py) Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
def test_static_array_constant_type_mismatch(): def test_static_array_constant_type_mismatch():
@ -118,10 +140,11 @@ def test_static_array_constant_type_mismatch():
CONSTANT: u8[3] = (24, 4000, 1, ) CONSTANT: u8[3] = (24, 4000, 1, )
""" """
with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): with pytest.raises(TypingError, match='u8.*4000'):
phasm_parse(code_py) Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.skip('To decide: What to do on out of index?')
def test_static_array_index_out_of_bounds(): def test_static_array_index_out_of_bounds():
code_py = """ code_py = """
CONSTANT0: u32[3] = (24, 57, 80, ) CONSTANT0: u32[3] = (24, 57, 80, )