From 75d7e0551918a3199563fe2d94715c77dd18fd8f Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Tue, 16 Aug 2022 20:38:41 +0200 Subject: [PATCH] First uint cast, more options for folding --- phasm/codestyle.py | 3 +++ phasm/compiler.py | 5 +++++ phasm/parser.py | 32 +++++++++++++++++++++++++----- tests/integration/test_builtins.py | 15 ++++++++++++++ 4 files changed, 50 insertions(+), 5 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index f700d9b..e2e64f0 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -97,6 +97,9 @@ def expression(inp: ourlang.Expression) -> str: or inp.operator in ourlang.WEBASSEMBLY_BUILDIN_BYTES_OPS): return f'{inp.operator}({expression(inp.right)})' + if inp.operator == 'cast': + return f'{type_(inp.type)}({expression(inp.right)})' + return f'{inp.operator}{expression(inp.right)}' if isinstance(inp, ourlang.BinaryOp): diff --git a/phasm/compiler.py b/phasm/compiler.py index 46a9dfe..4f810b6 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -219,6 +219,11 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.i32.load() return + if inp.operator == 'cast': + if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): + # Nothing to do, you can use an u8 value as a u32 no problem + return + raise NotImplementedError(expression, inp.type, inp.operator) if isinstance(inp, ourlang.FunctionCall): diff --git a/phasm/parser.py b/phasm/parser.py index fdf63c0..14e1ba2 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -380,6 +380,20 @@ class OurVisitor: 'sqrt', self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), ) + elif node.func.id == 'u32': + if not isinstance(exp_type, TypeUInt32): + _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') + + if 1 != len(node.args): + _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') + + # FIXME: This is a stub, proper casting is todo + + return UnaryOp( + exp_type, + 'cast', + self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]), + ) elif node.func.id == 'len': if not isinstance(exp_type, TypeInt32): _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') @@ -398,14 +412,9 @@ class OurVisitor: # In the future, we should probably infer the type of the second argument, # and use it as expected types for the other u8s and the Iterable[u8] (i.e. bytes) - if not isinstance(exp_type, TypeUInt8): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type} - not implemented yet') - if 3 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given') - t_u8 = module.types['u8'] - # TODO: This is not generic subnode = node.args[0] if not isinstance(subnode, ast.Name): @@ -415,6 +424,19 @@ class OurVisitor: if subnode.id not in module.functions: _raise_static_error(subnode, 'Reference to undefined function') func = module.functions[subnode.id] + if 2 != len(func.posonlyargs): + _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') + + if exp_type.__class__ != func.returns.__class__: + _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') + + if func.returns.__class__ != func.posonlyargs[0][1].__class__: + _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}') + + t_u8 = module.types['u8'] + + if t_u8.__class__ != func.posonlyargs[1][1].__class__: + _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') return Fold( exp_type, diff --git a/tests/integration/test_builtins.py b/tests/integration/test_builtins.py index 4d0035c..4b84197 100644 --- a/tests/integration/test_builtins.py +++ b/tests/integration/test_builtins.py @@ -63,3 +63,18 @@ def testEntry(a: bytes, b: bytes) -> u8: result = suite.run_code(b'\x55\x0F', b'\x33\x80') assert 233 == result.returned_value + +@pytest.mark.integration_test +def test_foldl_3(): + code_py = """ +def xor(l: u32, r: u8) -> u32: + return l ^ u32(r) + +@exported +def testEntry(a: bytes) -> u32: + return foldl(xor, 0, a) +""" + suite = Suite(code_py) + + result = suite.run_code(b'\x55\x0F\x33\x80') + assert 233 == result.returned_value