diff --git a/py2wasm/compiler.py b/py2wasm/compiler.py index eba158c..d04bb2c 100644 --- a/py2wasm/compiler.py +++ b/py2wasm/compiler.py @@ -87,6 +87,20 @@ def expression(inp: ourlang.Expression) -> Statements: raise NotImplementedError(expression, inp.type, inp.operator) + if isinstance(inp, ourlang.UnaryOp): + yield from expression(inp.right) + + if isinstance(inp.type, ourlang.OurTypeFloat32): + if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + yield wasm.Statement(f'f32.{inp.operator}') + return + if isinstance(inp.type, ourlang.OurTypeFloat64): + if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + yield wasm.Statement(f'f64.{inp.operator}') + return + + raise NotImplementedError(expression, inp.type, inp.operator) + if isinstance(inp, ourlang.FunctionCall): for arg in inp.arguments: yield from expression(arg) diff --git a/py2wasm/ourlang.py b/py2wasm/ourlang.py index 0d53280..1855e21 100644 --- a/py2wasm/ourlang.py +++ b/py2wasm/ourlang.py @@ -5,6 +5,10 @@ from typing import Any, Dict, List, Optional, NoReturn, Union, Tuple import ast +from typing_extensions import Final + +WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) + class OurType: """ Type base class @@ -224,6 +228,9 @@ class UnaryOp(Expression): self.right = right def render(self) -> str: + if self.operator in WEBASSEMBLY_BUILDIN_FLOAT_OPS: + return f'{self.operator}({self.right.render()})' + return f'{self.operator}{self.right.render()}' class FunctionCall(Expression): @@ -523,15 +530,6 @@ class Module: self.functions = {} self.structs = {} - # sqrt is guaranteed by wasm, so we should provide it - # ATM it's a 32 bit variant, but that might change. - # TODO: Could be a UnaryOp? - sqrt = Function('sqrt', -2) - sqrt.buildin = True - sqrt.returns = self.types['f32'] - sqrt.posonlyargs = [('@', self.types['f32'], )] - # self.functions[sqrt.name] = sqrt - def render(self) -> str: """ Renders the module back to source code format @@ -834,7 +832,7 @@ class OurVisitor: raise NotImplementedError(f'{node} as expr in FunctionDef') - def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.Call) -> FunctionCall: + def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.Call) -> Union[FunctionCall, UnaryOp]: if node.keywords: _raise_static_error(node, 'Keyword calling not supported') # Yet? @@ -848,6 +846,18 @@ class OurVisitor: struct_constructor = StructConstructor(struct) func = module.functions[struct_constructor.name] + elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if not isinstance(exp_type, (OurTypeFloat32, OurTypeFloat64, )): + _raise_static_error(node, f'Cannot make square root result in {exp_type}') + + if 1 != len(node.args): + _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') + + return UnaryOp( + exp_type, + 'sqrt', + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), + ) else: if node.func.id not in module.functions: _raise_static_error(node, 'Call to undefined function')