StaticArray with constant index works again

Also, fix issue with f64 being parsed as f32
This commit is contained in:
Johan B.W. de Vries 2022-09-19 14:43:15 +02:00
parent 5da45e78c2
commit 4d3c0c6c3c
5 changed files with 148 additions and 53 deletions

View File

@ -95,18 +95,16 @@ def expression(inp: ourlang.Expression) -> str:
# return f'({args}, )'
return f'{inp.function.name}({args})'
#
# if isinstance(inp, ourlang.AccessBytesIndex):
# return f'{expression(inp.varref)}[{expression(inp.index)}]'
#
if isinstance(inp, ourlang.Subscript):
varref = expression(inp.varref)
index = expression(inp.index)
return f'{varref}[{index}]'
# TODO: Broken after new type system
# if isinstance(inp, ourlang.AccessStructMember):
# return f'{expression(inp.varref)}.{inp.member.name}'
#
# if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )):
# if isinstance(inp.member, ourlang.Expression):
# return f'{expression(inp.varref)}[{expression(inp.member)}]'
#
# return f'{expression(inp.varref)}[{inp.member.idx}]'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'

View File

@ -24,6 +24,9 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module:
def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
"""
Compile: type
Types are used for example in WebAssembly function parameters
and return types.
"""
assert inp is not None, typing.ASSERTION_ERROR
@ -52,6 +55,16 @@ def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
if mtyp == 'f64':
return wasm.WasmTypeFloat64()
assert inp is not None, typing.ASSERTION_ERROR
tc_prim = inp.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
raise NotImplementedError(type_var, inp)
if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# StaticArray, Tuples and Structs are passed as pointer
# And pointers are i32
return wasm.WasmTypeInt32()
# TODO: Broken after new type system
# if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)):
# # Structs and tuples are passed as pointer
@ -161,7 +174,12 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
return
if isinstance(inp.variable, ourlang.ModuleConstantDef):
# FIXME: Tuple / Static Array broken after new type system
assert inp.variable.type_var is not None, typing.ASSERTION_ERROR
tc_prim = inp.variable.type_var.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
raise NotImplementedError(expression, inp, inp.variable.type_var)
# TODO: Broken after new type system
# if isinstance(inp.type, typing.TypeTuple):
# assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
@ -169,12 +187,12 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
# wgn.i32.const(inp.definition.data_block.address)
# return
#
# if isinstance(inp.type, typing.TypeStaticArray):
# assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
assert inp.variable.data_block is not None, 'Combined values are memory stored'
assert inp.variable.data_block.address is not None, 'Value not allocated'
wgn.i32.const(inp.variable.data_block.address)
return
assert inp.variable.data_block is None, 'Primitives are not memory stored'
@ -276,6 +294,53 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
wgn.add_statement('call', '${}'.format(inp.function.name))
return
if isinstance(inp, ourlang.Subscript):
assert inp.varref.type_var is not None, typing.ASSERTION_ERROR
tc_prim = inp.varref.type_var.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
raise NotImplementedError(expression, inp, inp.varref.type_var)
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
if not isinstance(inp.index, ourlang.ConstantPrimitive):
raise NotImplementedError(expression, inp, inp.index)
if not isinstance(inp.index.value, int):
raise NotImplementedError(expression, inp, inp.index.value)
assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp is None:
raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp)
if mtyp == 'u8':
# u8 operations are done using i32, since WASM does not have u8 operations
mtyp = 'i32'
elif mtyp == 'u32':
# u32 operations are done using i32, using _u operations
mtyp = 'i32'
elif mtyp == 'u64':
# u64 operations are done using i64, using _u operations
mtyp = 'i64'
tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript)
if tc_subs is None:
raise NotImplementedError(expression, inp, inp.varref.type_var)
assert 0 < len(tc_subs.members)
tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth)
if tc_bits is None or len(tc_bits.oneof) > 1:
raise NotImplementedError(expression, inp, inp.varref.type_var)
bitwidth = next(iter(tc_bits.oneof))
if bitwidth % 8 != 0:
raise NotImplementedError(expression, inp, inp.varref.type_var)
expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value))
return
raise NotImplementedError(expression, inp, inp.varref.type_var)
# TODO: Broken after new type system
# if isinstance(inp, ourlang.AccessBytesIndex):
# if not isinstance(inp.type, typing.TypeUInt8):
@ -315,11 +380,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
# # as members of static arrays
# raise NotImplementedError(expression, inp, inp.member)
#
# if isinstance(inp.member, typing.TypeStaticArrayMember):
# expression(wgn, inp.varref)
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset))
# return
#
# expression(wgn, inp.varref)
# expression(wgn, inp.member)
# wgn.i32.const(inp.static_array.member_type.alloc_size())

View File

@ -142,14 +142,11 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
expression(ctx, inp.varref)
assert inp.varref.type_var is not None
try:
# TODO: I'd much rather resolve this using the narrow functions
tc_subs = ctx.var_constraints[inp.varref.type_var.ctx_id][TypeConstraintSubscript]
except KeyError:
# TODO: I'd much rather resolve this using the narrow functions
tc_subs = inp.varref.type_var.get_constraint(TypeConstraintSubscript)
if tc_subs is None:
raise TypingError(f'Type cannot be subscripted: {inp.varref.type_var}') from None
assert isinstance(tc_subs, TypeConstraintSubscript) # type hint
try:
# TODO: I'd much rather resolve this using the narrow functions
member = tc_subs.members[inp.index.value]

View File

@ -2,6 +2,7 @@
The phasm type system
"""
from typing import Callable, Dict, Iterable, Optional, List, Set, Type
from typing import TypeVar as MyPyTypeVar
import enum
import re
@ -168,6 +169,8 @@ class TypeConstraintPrimitive(TypeConstraintBase):
INT = 0
FLOAT = 1
STATIC_ARRAY = 10
primitive: Primitive
def __init__(self, primitive: Primitive) -> None:
@ -307,6 +310,8 @@ class TypeConstraintSubscript(TypeConstraintBase):
def __repr__(self) -> str:
return 'Subscript=(' + ','.join(map(repr, self.members)) + ')'
TTypeConstraintClass = MyPyTypeVar('TTypeConstraintClass', bound=TypeConstraintBase)
class TypeVar:
"""
A type variable
@ -329,15 +334,22 @@ class TypeVar:
else:
csts[newconst.__class__] = newconst
def get_constraint(self, const_type: Type[TTypeConstraintClass]) -> Optional[TTypeConstraintClass]:
csts = self.ctx.var_constraints[self.ctx_id]
res = csts.get(const_type, None)
assert res is None or isinstance(res, const_type) # type hint
return res
def add_location(self, ref: str) -> None:
self.ctx.var_locations[self.ctx_id].append(ref)
self.ctx.var_locations[self.ctx_id].add(ref)
def __repr__(self) -> str:
return (
'TypeVar<'
+ '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
+ '; locations: '
+ ', '.join(self.ctx.var_locations[self.ctx_id])
+ ', '.join(sorted(self.ctx.var_locations[self.ctx_id]))
+ '>'
)
@ -356,7 +368,7 @@ class Context:
# Store the TypeVar properties as a lookup
# so we can update these when unifying
self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
self.var_locations: Dict[int, List[str]] = {}
self.var_locations: Dict[int, Set[str]] = {}
def new_var(self) -> TypeVar:
ctx_id = self.next_ctx_id
@ -366,7 +378,7 @@ class Context:
self.vars_by_id[ctx_id] = [result]
self.var_constraints[ctx_id] = {}
self.var_locations[ctx_id] = []
self.var_locations[ctx_id] = set()
return result
@ -395,8 +407,7 @@ class Context:
except TypingNarrowProtoError as exc:
raise TypingNarrowError(l, r, str(exc)) from None
self.var_locations[n.ctx_id].extend(self.var_locations[l_ctx_id])
self.var_locations[n.ctx_id].extend(self.var_locations[r_ctx_id])
self.var_locations[n.ctx_id] = self.var_locations[l_ctx_id] | self.var_locations[r_ctx_id]
# ##
# And unify (or entangle) the old ones
@ -424,22 +435,18 @@ def simplify(inp: TypeVar) -> Optional[str]:
Should round trip with from_str
"""
tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive)
tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth)
tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned)
tc_prim = inp.get_constraint(TypeConstraintPrimitive)
tc_bits = inp.get_constraint(TypeConstraintBitWidth)
tc_sign = inp.get_constraint(TypeConstraintSigned)
if tc_prim is None:
return None
assert isinstance(tc_prim, TypeConstraintPrimitive) # type hint
primitive = tc_prim.primitive
if primitive is TypeConstraintPrimitive.Primitive.INT:
if tc_bits is None or tc_sign is None:
return None
assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint
assert isinstance(tc_sign, TypeConstraintSigned) # type hint
if tc_sign.signed is None or len(tc_bits.oneof) != 1:
return None
@ -454,8 +461,6 @@ def simplify(inp: TypeVar) -> Optional[str]:
if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint
return None
assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint
if len(tc_bits.oneof) != 1:
return None
@ -465,6 +470,17 @@ def simplify(inp: TypeVar) -> Optional[str]:
return f'f{bitwidth}'
if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
tc_subs = inp.get_constraint(TypeConstraintSubscript)
assert tc_subs is not None
assert tc_subs.members
sab = simplify(tc_subs.members[0])
if sab is None:
return None
return f'{sab}[{len(tc_subs.members)}]'
return None
def make_u8(ctx: Context, location: str) -> TypeVar:
@ -538,7 +554,7 @@ def make_f64(ctx: Context, location: str) -> TypeVar:
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location(location)
return result
@ -573,6 +589,7 @@ def from_str(ctx: Context, inp: str, location: str) -> TypeVar:
if match:
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))
result.add_constraint(TypeConstraintSubscript(members=(
# Make copies so they don't get entangled
# with each other.

View File

@ -2,7 +2,9 @@ import pytest
from phasm.exceptions import StaticError, TypingError
from ..constants import ALL_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP
from ..constants import (
ALL_FLOAT_TYPES, ALL_INT_TYPES, COMPLETE_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP
)
from ..helpers import Suite
@pytest.mark.integration_test
@ -22,6 +24,7 @@ def testEntry() -> {type_}:
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.skip('To decide: What to do on out of index?')
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
def test_static_array_indexed(type_):
code_py = f"""
@ -41,8 +44,8 @@ def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}:
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
def test_function_call(type_):
@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES)
def test_function_call_int(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@ -59,6 +62,25 @@ def helper(array: {type_}[3]) -> {type_}:
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES)
def test_function_call_float(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24.0, 57.5, 80.75, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT)
def helper(array: {type_}[3]) -> {type_}:
return array[0] + array[1] + array[2]
"""
result = Suite(code_py).run_code()
assert 162.25 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
def test_module_constant_type_mismatch_bitwidth():
code_py = """
@ -100,8 +122,8 @@ def test_static_array_constant_too_few_values():
CONSTANT: u8[3] = (24, 57, )
"""
with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'):
phasm_parse(code_py)
with pytest.raises(TypingError, match='Member count does not match'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_static_array_constant_too_many_values():
@ -109,8 +131,8 @@ def test_static_array_constant_too_many_values():
CONSTANT: u8[3] = (24, 57, 1, 1, )
"""
with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'):
phasm_parse(code_py)
with pytest.raises(TypingError, match='Member count does not match'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_static_array_constant_type_mismatch():
@ -118,10 +140,11 @@ def test_static_array_constant_type_mismatch():
CONSTANT: u8[3] = (24, 4000, 1, )
"""
with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'):
phasm_parse(code_py)
with pytest.raises(TypingError, match='u8.*4000'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.skip('To decide: What to do on out of index?')
def test_static_array_index_out_of_bounds():
code_py = """
CONSTANT0: u32[3] = (24, 57, 80, )