From 109423fe34291be27adbe198f5b0f69dc3affad9 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 25 Jun 2022 18:56:15 +0200 Subject: [PATCH] List ideas. How do we keep it safe? --- py2wasm/ourlang.py | 83 ++++++++++++++++++++++++++++++++ tests/integration/test_simple.py | 18 +++++++ 2 files changed, 101 insertions(+) diff --git a/py2wasm/ourlang.py b/py2wasm/ourlang.py index ba87aec..665c8e9 100644 --- a/py2wasm/ourlang.py +++ b/py2wasm/ourlang.py @@ -29,6 +29,17 @@ class OurType: """ raise NotImplementedError(self, 'alloc_size') +class OurTypeAny(OurType): + """ + The Any type + + In places where the types really does not matter, such as len([array]) + """ + __slots__ = () + + def render(self) -> str: + raise Exception('Internal only') + class OurTypeNone(OurType): """ The None (or Void) type @@ -121,6 +132,20 @@ class OurTypeTuple(OurType): for x in self.members ) +class OurTypeList(OurType): + """ + The tuple type + """ + __slots__ = ('member_type', ) + + member_type: OurType + + def __init__(self, member_type: OurType) -> None: + self.member_type = member_type + + def render(self) -> str: + return f'[{self.member_type.render()}]' + class Expression: """ An expression within a statement @@ -260,6 +285,9 @@ class UnaryOp(Expression): if self.operator in WEBASSEMBLY_BUILDIN_FLOAT_OPS: return f'{self.operator}({self.right.render()})' + if self.operator == 'len': + return f'{self.operator}({self.right.render()})' + return f'{self.operator}{self.right.render()}' class FunctionCall(Expression): @@ -327,6 +355,23 @@ class AccessTupleMember(Expression): def render(self) -> str: return f'{self.varref.render()}[{self.member.idx}]' +class CreateList(Expression): + """ + Creates a list + """ + __slots__ = ('arguments', ) + + arguments: List[Expression] + + def __init__(self, type_: OurTypeList) -> None: + super().__init__(type_) + + self.arguments = [] + + def render(self) -> str: + mems = ', '.join(x.render() for x in self.arguments) + return f'[{mems}]' + class Statement: """ A statement within a function @@ -851,6 +896,20 @@ class OurVisitor: return VariableReference(act_type, node.id) + if isinstance(node, ast.List): + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if not isinstance(exp_type, OurTypeList): + _raise_static_error(node, f'Expression is expecting a {exp_type.render()}, not a list') + + result_cl = CreateList(exp_type) + result_cl.arguments = [ + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type.member_type, arg_node) + for arg_node in node.elts + ] + return result_cl + if isinstance(node, ast.Tuple): if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') @@ -900,6 +959,19 @@ class OurVisitor: 'sqrt', self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), ) + elif node.func.id == 'len': + if not isinstance(exp_type, OurTypeInt32): + _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') + + if 1 != len(node.args): + _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') + + return UnaryOp( + exp_type, + 'len', + self.visit_Module_FunctionDef_expr(module, function, our_locals, OurTypeList(OurTypeAny()), node.args[0]), + ) + else: if node.func.id not in module.functions: _raise_static_error(node, 'Call to undefined function') @@ -1034,6 +1106,17 @@ class OurVisitor: _raise_static_error(node, f'Unrecognized type {node.id}') + if isinstance(node, ast.List): + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if len(node.elts) != 1: + _raise_static_error(node, 'Must provide one type parameter') + + return OurTypeList( + self.visit_type(module, node.elts[0]) + ) + if isinstance(node, ast.Tuple): if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index ea281e7..36f97df 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -366,6 +366,24 @@ def helper(v: (f32, f32, f32, )) -> f32: assert 3.74 < result.returned_value < 3.75 assert [] == result.log_int32_list +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['f32', 'f64']) +# @pytest.mark.xfail(reason='Working on it') +def test_list(type_): + code_py = f""" +@exported +def testEntry() -> i32: + return helper([1.0, 2.0, 3.0]) + +def helper(v: [{type_}]) -> i32: + return len(v) +""" + + result = Suite(code_py, 'test_call').run_code() + + assert 3 == result.returned_value + assert [] == result.log_int32_list + @pytest.mark.integration_test @pytest.mark.skip('SIMD support is but a dream') def test_tuple_i32x4():