First work on restoring StaticArray

Removed the separate ModuleConstantRef since you can tell by the variable
property of VariableReference. We'll also add local variables there later
on.
This commit is contained in:
Johan B.W. de Vries 2022-09-19 12:15:03 +02:00
parent 299551db1b
commit 0097ce782d
6 changed files with 163 additions and 189 deletions

View File

@ -95,26 +95,23 @@ def expression(inp: ourlang.Expression) -> str:
# return f'({args}, )' # return f'({args}, )'
return f'{inp.function.name}({args})' return f'{inp.function.name}({args})'
#
if isinstance(inp, ourlang.AccessBytesIndex): # if isinstance(inp, ourlang.AccessBytesIndex):
return f'{expression(inp.varref)}[{expression(inp.index)}]' # return f'{expression(inp.varref)}[{expression(inp.index)}]'
#
if isinstance(inp, ourlang.AccessStructMember): # if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member.name}' # return f'{expression(inp.varref)}.{inp.member.name}'
#
if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )): # if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )):
if isinstance(inp.member, ourlang.Expression): # if isinstance(inp.member, ourlang.Expression):
return f'{expression(inp.varref)}[{expression(inp.member)}]' # return f'{expression(inp.varref)}[{expression(inp.member)}]'
#
return f'{expression(inp.varref)}[{inp.member.idx}]' # return f'{expression(inp.varref)}[{inp.member.idx}]'
if isinstance(inp, ourlang.Fold): if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
if isinstance(inp, ourlang.ModuleConstantReference):
return inp.definition.name
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements: def statement(inp: ourlang.Statement) -> Statements:

View File

@ -156,8 +156,39 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
raise NotImplementedError(f'Constants with type {stp}') raise NotImplementedError(f'Constants with type {stp}')
if isinstance(inp, ourlang.VariableReference): if isinstance(inp, ourlang.VariableReference):
wgn.add_statement('local.get', '${}'.format(inp.variable.name)) if isinstance(inp.variable, ourlang.FunctionParam):
return wgn.add_statement('local.get', '${}'.format(inp.variable.name))
return
if isinstance(inp.variable, ourlang.ModuleConstantDef):
# FIXME: Tuple / Static Array broken after new type system
# if isinstance(inp.type, typing.TypeTuple):
# assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
#
# if isinstance(inp.type, typing.TypeStaticArray):
# assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
assert inp.variable.data_block is None, 'Primitives are not memory stored'
assert inp.variable.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.variable.type_var)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.type_var)
expression(wgn, inp.variable.constant)
return
raise NotImplementedError(expression, inp.variable)
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
expression(wgn, inp.left) expression(wgn, inp.left)
@ -301,34 +332,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
expression_fold(wgn, inp) expression_fold(wgn, inp)
return return
if isinstance(inp, ourlang.ModuleConstantReference):
# FIXME: Tuple / Static Array broken after new type system
# if isinstance(inp.type, typing.TypeTuple):
# assert isinstance(inp.definition.constant, ourlang.ConstantTuple)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
#
# if isinstance(inp.type, typing.TypeStaticArray):
# assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray)
# assert inp.definition.data_block is not None, 'Combined values are memory stored'
# assert inp.definition.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.definition.data_block.address)
# return
assert inp.definition.data_block is None, 'Primitives are not memory stored'
assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp is None:
# In the future might extend this by having structs or tuples
# as members of struct or tuples
raise NotImplementedError(expression, inp, inp.type_var)
expression(wgn, inp.definition.constant)
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
@ -566,38 +569,46 @@ def module_data(inp: ourlang.ModuleData) -> bytes:
data_list: List[bytes] = [] data_list: List[bytes] = []
raise NotImplementedError('Broken after new type system')
for constant in block.data: for constant in block.data:
if isinstance(constant, ourlang.ConstantUInt8): assert constant.type_var is not None
mtyp = typing.simplify(constant.type_var)
if mtyp == 'u8':
assert isinstance(constant.value, int)
data_list.append(module_data_u8(constant.value)) data_list.append(module_data_u8(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantUInt32): if mtyp == 'u32':
assert isinstance(constant.value, int)
data_list.append(module_data_u32(constant.value)) data_list.append(module_data_u32(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantUInt64): if mtyp == 'u64':
assert isinstance(constant.value, int)
data_list.append(module_data_u64(constant.value)) data_list.append(module_data_u64(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantInt32): if mtyp == 'i32':
assert isinstance(constant.value, int)
data_list.append(module_data_i32(constant.value)) data_list.append(module_data_i32(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantInt64): if mtyp == 'i64':
assert isinstance(constant.value, int)
data_list.append(module_data_i64(constant.value)) data_list.append(module_data_i64(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantFloat32): if mtyp == 'f32':
assert isinstance(constant.value, float)
data_list.append(module_data_f32(constant.value)) data_list.append(module_data_f32(constant.value))
continue continue
if isinstance(constant, ourlang.ConstantFloat64): if mtyp == 'f64':
assert isinstance(constant.value, float)
data_list.append(module_data_f64(constant.value)) data_list.append(module_data_f64(constant.value))
continue continue
raise NotImplementedError(constant) raise NotImplementedError(constant, mtyp)
block_data = b''.join(data_list) block_data = b''.join(data_list)

View File

@ -11,10 +11,7 @@ WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc',
WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', )
from .typing import ( from .typing import (
TypeBytes, TypeStruct,
TypeTuple, TypeTupleMember,
TypeStaticArray, TypeStaticArrayMember,
TypeStruct, TypeStructMember,
TypeVar, TypeVar,
) )
@ -78,9 +75,9 @@ class VariableReference(Expression):
""" """
__slots__ = ('variable', ) __slots__ = ('variable', )
variable: 'FunctionParam' # also possibly local variable: Union['ModuleConstantDef', 'FunctionParam'] # also possibly local
def __init__(self, variable: 'FunctionParam') -> None: def __init__(self, variable: Union['ModuleConstantDef', 'FunctionParam']) -> None:
super().__init__() super().__init__()
self.variable = variable self.variable = variable
@ -131,9 +128,10 @@ class FunctionCall(Expression):
self.function = function self.function = function
self.arguments = [] self.arguments = []
class AccessBytesIndex(Expression): class Subscript(Expression):
""" """
Access a bytes index for reading A subscript, for example to refer to a static array or tuple
by index
""" """
__slots__ = ('varref', 'index', ) __slots__ = ('varref', 'index', )
@ -146,53 +144,6 @@ class AccessBytesIndex(Expression):
self.varref = varref self.varref = varref
self.index = index self.index = index
class AccessStructMember(Expression):
"""
Access a struct member for reading of writing
"""
__slots__ = ('varref', 'member', )
varref: VariableReference
member: TypeStructMember
def __init__(self, varref: VariableReference, member: TypeStructMember) -> None:
super().__init__()
self.varref = varref
self.member = member
class AccessTupleMember(Expression):
"""
Access a tuple member for reading of writing
"""
__slots__ = ('varref', 'member', )
varref: VariableReference
member: TypeTupleMember
def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None:
super().__init__()
self.varref = varref
self.member = member
class AccessStaticArrayMember(Expression):
"""
Access a tuple member for reading of writing
"""
__slots__ = ('varref', 'static_array', 'member', )
varref: Union['ModuleConstantReference', VariableReference]
static_array: TypeStaticArray
member: Union[Expression, TypeStaticArrayMember]
def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None:
super().__init__()
self.varref = varref
self.static_array = static_array
self.member = member
class Fold(Expression): class Fold(Expression):
""" """
A (left or right) fold A (left or right) fold
@ -223,18 +174,6 @@ class Fold(Expression):
self.base = base self.base = base
self.iter = iter_ self.iter = iter_
class ModuleConstantReference(Expression):
"""
An reference to a module constant expression within a statement
"""
__slots__ = ('definition', )
definition: 'ModuleConstantDef'
def __init__(self, definition: 'ModuleConstantDef') -> None:
super().__init__()
self.definition = definition
class Statement: class Statement:
""" """
A statement within a function A statement within a function

View File

@ -22,15 +22,14 @@ from .ourlang import (
Function, Function,
Expression, Expression,
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
BinaryOp, BinaryOp,
ConstantPrimitive, ConstantTuple, ConstantStaticArray, ConstantPrimitive, ConstantTuple, ConstantStaticArray,
FunctionCall, FunctionCall, Subscript,
# StructConstructor, TupleConstructor, # StructConstructor, TupleConstructor,
UnaryOp, VariableReference, UnaryOp, VariableReference,
Fold, ModuleConstantReference, Fold,
Statement, Statement,
StatementIf, StatementPass, StatementReturn, StatementIf, StatementPass, StatementReturn,
@ -206,6 +205,27 @@ class OurVisitor:
None, None,
) )
if isinstance(node.value, ast.Tuple):
tuple_data = [
self.visit_Module_Constant(module, arg_node)
for arg_node in node.value.elts
if isinstance(arg_node, ast.Constant)
]
if len(node.value.elts) != len(tuple_data):
_raise_static_error(node, 'Tuple arguments must be constants')
# Allocate the data
data_block = ModuleDataBlock(tuple_data)
module.data.blocks.append(data_block)
# Then return the constant as a pointer
return ModuleConstantDef(
node.target.id,
node.lineno,
ConstantTuple(tuple_data),
data_block,
)
raise NotImplementedError('TODO: Broken after new typing system') raise NotImplementedError('TODO: Broken after new typing system')
# if isinstance(exp_type, TypeTuple): # if isinstance(exp_type, TypeTuple):
@ -416,7 +436,7 @@ class OurVisitor:
if node.id in module.constant_defs: if node.id in module.constant_defs:
cdef = module.constant_defs[node.id] cdef = module.constant_defs[node.id]
return ModuleConstantReference(cdef) return VariableReference(cdef)
_raise_static_error(node, f'Undefined variable {node.id}') _raise_static_error(node, f'Undefined variable {node.id}')
@ -454,13 +474,13 @@ class OurVisitor:
if not isinstance(node.func.ctx, ast.Load): if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.func.id in module.structs: # if node.func.id in module.structs:
raise NotImplementedError('TODO: Broken after new type system') # raise NotImplementedError('TODO: Broken after new type system')
# struct = module.structs[node.func.id] # struct = module.structs[node.func.id]
# struct_constructor = StructConstructor(struct) # struct_constructor = StructConstructor(struct)
# #
# func = module.functions[struct_constructor.name] # func = module.functions[struct_constructor.name]
elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: if node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
if 1 != len(node.args): if 1 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')
@ -533,61 +553,59 @@ class OurVisitor:
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
raise NotImplementedError('Broken after new type system') raise NotImplementedError('Broken after new type system')
del module # del module
del function # del function
#
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if not node.value.id in our_locals:
_raise_static_error(node, f'Undefined variable {node.value.id}')
param = our_locals[node.value.id]
node_typ = param.type
if not isinstance(node_typ, TypeStruct):
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
member = node_typ.get_member(node.attr)
if member is None:
_raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
return AccessStructMember(
VariableReference(param),
member,
)
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
raise NotImplementedError('TODO: Broken after new type system')
# if not isinstance(node.value, ast.Name): # if not isinstance(node.value, ast.Name):
# _raise_static_error(node, 'Must reference a name') # _raise_static_error(node, 'Must reference a name')
# #
# if not isinstance(node.slice, ast.Index):
# _raise_static_error(node, 'Must subscript using an index')
#
# if not isinstance(node.ctx, ast.Load): # if not isinstance(node.ctx, ast.Load):
# _raise_static_error(node, 'Must be load context') # _raise_static_error(node, 'Must be load context')
# #
# varref: Union[ModuleConstantReference, VariableReference] # if not node.value.id in our_locals:
# if node.value.id in our_locals:
# param = our_locals[node.value.id]
# node_typ = param.type
# varref = VariableReference(param)
# elif node.value.id in module.constant_defs:
# constant_def = module.constant_defs[node.value.id]
# node_typ = constant_def.type
# varref = ModuleConstantReference(constant_def)
# else:
# _raise_static_error(node, f'Undefined variable {node.value.id}') # _raise_static_error(node, f'Undefined variable {node.value.id}')
# #
# slice_expr = self.visit_Module_FunctionDef_expr( # param = our_locals[node.value.id]
# module, function, our_locals, node.slice.value,
# )
# #
# node_typ = param.type
# if not isinstance(node_typ, TypeStruct):
# _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
#
# member = node_typ.get_member(node.attr)
# if member is None:
# _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')
#
# return AccessStructMember(
# VariableReference(param),
# member,
# )
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.slice, ast.Index):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
varref: VariableReference
if node.value.id in our_locals:
param = our_locals[node.value.id]
varref = VariableReference(param)
elif node.value.id in module.constant_defs:
constant_def = module.constant_defs[node.value.id]
varref = VariableReference(constant_def)
else:
_raise_static_error(node, f'Undefined variable {node.value.id}')
slice_expr = self.visit_Module_FunctionDef_expr(
module, function, our_locals, node.slice.value,
)
return Subscript(varref, slice_expr)
# if isinstance(node_typ, TypeBytes): # if isinstance(node_typ, TypeBytes):
# if isinstance(varref, ModuleConstantReference): # if isinstance(varref, ModuleConstantReference):
# raise NotImplementedError(f'{node} from module constant') # raise NotImplementedError(f'{node} from module constant')

View File

@ -62,7 +62,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
return constant(ctx, inp) return constant(ctx, inp)
if isinstance(inp, ourlang.VariableReference): if isinstance(inp, ourlang.VariableReference):
assert inp.variable.type_var is not None, inp assert inp.variable.type_var is not None
return inp.variable.type_var return inp.variable.type_var
if isinstance(inp, ourlang.UnaryOp): if isinstance(inp, ourlang.UnaryOp):
@ -112,13 +112,6 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
return inp.function.returns_type_var return inp.function.returns_type_var
if isinstance(inp, ourlang.ModuleConstantReference):
assert inp.definition.type_var is not None
inp.type_var = inp.definition.type_var
return inp.definition.type_var
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def function(ctx: Context, inp: ourlang.Function) -> None: def function(ctx: Context, inp: ourlang.Function) -> None:

View File

@ -7,21 +7,18 @@ from ..helpers import Suite
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) @pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
def test_static_array_module_constant(type_): def test_module_constant(type_):
code_py = f""" code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, ) CONSTANT: {type_}[3] = (24, 57, 80, )
@exported @exported
def testEntry() -> {type_}: def testEntry() -> {type_}:
return helper(CONSTANT) return CONSTANT[0]
def helper(array: {type_}[3]) -> {type_}:
return array[0] + array[1] + array[2]
""" """
result = Suite(code_py).run_code() result = Suite(code_py).run_code()
assert 161 == result.returned_value assert 24 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value) assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
@ -43,6 +40,25 @@ def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}:
assert 161 == result.returned_value assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value) assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test
@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES)
def test_function_call(type_):
code_py = f"""
CONSTANT: {type_}[3] = (24, 57, 80, )
@exported
def testEntry() -> {type_}:
return helper(CONSTANT)
def helper(array: {type_}[3]) -> {type_}:
return array[0] + array[1] + array[2]
"""
result = Suite(code_py).run_code()
assert 161 == result.returned_value
assert TYPE_MAP[type_] == type(result.returned_value)
@pytest.mark.integration_test @pytest.mark.integration_test
def test_static_array_constant_too_few_values(): def test_static_array_constant_too_few_values():
code_py = """ code_py = """