diff --git a/py2wasm/python.py b/py2wasm/python.py index 853fd25..9710051 100644 --- a/py2wasm/python.py +++ b/py2wasm/python.py @@ -2,7 +2,7 @@ Code for parsing the source (python-alike) code """ -from typing import List, Optional, Union, Tuple +from typing import Dict, Generator, List, Optional, Union, Tuple import ast @@ -10,6 +10,10 @@ from . import wasm # pylint: disable=C0103,R0201 +StatementGenerator = Generator[wasm.Statement, None, None] + +WLocals = Dict[str, str] + class Visitor: """ Class to visit a Python syntax tree @@ -27,7 +31,7 @@ class Visitor: for stmt in node.body: if isinstance(stmt, ast.FunctionDef): - wnode = self.visit_FunctionDef(stmt) + wnode = self.visit_FunctionDef(module, stmt) if isinstance(wnode, wasm.Import): module.imports.append(wnode) @@ -39,13 +43,18 @@ class Visitor: return module - def visit_FunctionDef(self, node: ast.FunctionDef) -> Union[wasm.Import, wasm.Function]: + def visit_FunctionDef( + self, + module: wasm.Module, + node: ast.FunctionDef, + ) -> Union[wasm.Import, wasm.Function]: """ A Python function definition with the @external decorator is returned as import. Other functions are returned as (exported) functions. - """ + Nested / dynamicly created functions are not yet supported + """ exported = False if node.decorator_list: @@ -68,40 +77,197 @@ class Visitor: assert not call.keywords - module, name, intname = _parse_import_decorator(node.name, call.args) + mod, name, intname = _parse_import_decorator(node.name, call.args) - return wasm.Import(module, name, intname, _parse_import_args(node.args)) + return wasm.Import(mod, name, intname, _parse_import_args(node.args)) + else: + raise NotImplementedError - result = _parse_annotation(node.returns) + result = None if node.returns is None else _parse_annotation(node.returns) - statements = [ - self.visit_stmt(node, stmt) - for stmt in node.body + assert not node.args.vararg + assert not node.args.kwonlyargs + assert not node.args.kw_defaults + assert not node.args.kwarg + assert not node.args.defaults + + params = [ + (a.arg, _parse_annotation(a.annotation), ) + for a in [ + *node.args.posonlyargs, + *node.args.args, + ] ] - return wasm.Function(node.name, exported, result, statements) + wlocals: WLocals = dict(params) - def visit_stmt(self, func: ast.FunctionDef, stmt: ast.stmt) -> wasm.Statement: + statements: List[wasm.Statement] = [] + for py_stmt in node.body: + statements.extend( + self.visit_stmt(module, node, wlocals, py_stmt) + ) + + return wasm.Function(node.name, exported, params, result, statements) + + def visit_stmt( + self, + module: wasm.Module, + func: ast.FunctionDef, + wlocals: WLocals, + stmt: ast.stmt, + ) -> StatementGenerator: """ - Visits a statement node + Visits a statement node within a function """ if isinstance(stmt, ast.Return): - return self.visit_Return(func, stmt) + return self.visit_Return(module, func, wlocals, stmt) - raise NotImplementedError + if isinstance(stmt, ast.Expr): + assert isinstance(stmt.value, ast.Call) - def visit_Return(self, func: ast.FunctionDef, stmt: ast.Return) -> wasm.Statement: + return self.visit_Call(module, wlocals, "None", stmt.value) + + raise NotImplementedError(stmt) + + def visit_Return( + self, + module: wasm.Module, + func: ast.FunctionDef, + wlocals: WLocals, + stmt: ast.Return, + ) -> StatementGenerator: """ Visits a statement node """ - assert isinstance(stmt.value, ast.Constant) - assert isinstance(stmt.value.value, int) + + assert stmt.value is not None return_type = _parse_annotation(func.returns) - return wasm.Statement( - '{}.const'.format(return_type), - str(stmt.value.value) + return self.visit_expr(module, wlocals, return_type, stmt.value) + + def visit_expr( + self, + module: wasm.Module, + wlocals: WLocals, + exp_type: str, + node: ast.expr, + ) -> StatementGenerator: + """ + Visit an expression node + """ + if isinstance(node, ast.BinOp): + return self.visit_BinOp(module, wlocals, exp_type, node) + + if isinstance(node, ast.Call): + return self.visit_Call(module, wlocals, exp_type, node) + + if isinstance(node, ast.Constant): + return self.visit_Constant(exp_type, node) + + if isinstance(node, ast.Name): + return self.visit_Name(wlocals, exp_type, node) + + raise NotImplementedError(node) + + def visit_BinOp( + self, + module: wasm.Module, + wlocals: WLocals, + exp_type: str, + node: ast.BinOp, + ) -> StatementGenerator: + """ + Visits a BinOp node as (part of) an expression + """ + yield from self.visit_expr(module, wlocals, exp_type, node.left) + yield from self.visit_expr(module, wlocals, exp_type, node.right) + + if isinstance(node.op, ast.Add): + yield wasm.Statement('{}.add'.format(exp_type)) + return + + raise NotImplementedError(node.op) + + def visit_Call( + self, + module: wasm.Module, + wlocals: WLocals, + exp_type: str, + node: ast.Call, + ) -> StatementGenerator: + """ + Visits a Call node as (part of) an expression + """ + assert isinstance(node.func, ast.Name) + assert not node.keywords + + called_name = node.func.id + + called_func_list = [ + x + for x in module.functions + if x.name == called_name + ] + if called_func_list: + assert len(called_func_list) == 1 + called_params = called_func_list[0].params + called_result = called_func_list[0].result + else: + called_import_list = [ + x + for x in module.imports + if x.name == node.func.id + ] + if called_import_list: + assert len(called_import_list) == 1 + called_params = called_import_list[0].params + called_result = called_import_list[0].result + else: + assert 1 == len(called_func_list), \ + 'Could not find function {}'.format(node.func.id) + + assert exp_type == called_result + + assert len(called_params) == len(node.args), \ + '{}:{} Function {} requires {} arguments, but {} are supplied'.format( + node.lineno, node.col_offset, + called_name, len(called_params), len(node.args), + ) + + for (_, exp_expr_type), expr in zip(called_params, node.args): + yield from self.visit_expr(module, wlocals, exp_expr_type, expr) + + yield wasm.Statement( + 'call', + '${}'.format(called_name), + ) + + def visit_Constant(self, exp_type: str, node: ast.Constant) -> StatementGenerator: + """ + Visits a Constant node as (part of) an expression + """ + if 'i32' == exp_type: + assert isinstance(node.value, int) + assert -2147483648 <= node.value <= 2147483647 + + yield wasm.Statement( + '{}.const'.format(exp_type), + str(node.value) + ) + return + + raise NotImplementedError(exp_type) + + def visit_Name(self, wlocals: WLocals, exp_type: str, node: ast.Name) -> StatementGenerator: + """ + Visits a Name node as (part of) an expression + """ + assert node.id in wlocals + assert exp_type == wlocals[node.id] + yield wasm.Statement( + 'local.get', + '${}'.format(node.id), ) def _parse_import_decorator(func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]: @@ -156,99 +322,3 @@ def _parse_annotation(ann: Optional[ast.expr]) -> str: result = ann.id assert result in ['i32'] return result - - - - - # def visit_ImportFrom(self, node): - # for alias in node.names: - # self.imports.append(Import( - # node.module, - # alias.name, - # alias.asname, - # )) - # - # def generic_visit(self, node: Any) -> None: - # print(node) - # super().generic_visit(node) - # - # def visit_FunctionDef(self, node: ast.FunctionDef) -> None: - # is_export = False - # - # if node.decorator_list: - # assert 1 == len(node.decorator_list) - # - # call = node.decorator_list[0] - # if not isinstance(call, ast.Name): - # assert isinstance(call, ast.Call) - # - # assert 'external' == call.func.id - # assert 1 == len(call.args) - # assert isinstance(call.args[0].value, str) - # - # import_ = Import( - # call.args[0].value, - # node.name, - # node.name, - # ) - # - # import_.params = [ - # arg.annotation.id - # for arg in node.args.args - # ] - # - # self.module.imports.append(import_) - # return - # - # assert call.id == 'export' - # is_export = True - # - # func = Function( - # node.name, - # is_export, - # ) - # - # for arg in node.args.args: - # func.params.append( - # (arg.arg, arg.annotation.id) - # ) - # func.result = node.returns - # - # self._stack.append(func) - # self.generic_visit(node) - # self._stack.pop() - # - # self.module.functions.append(func) - # - # def visit_Call(self, node: ast.Call) -> None: - # self.generic_visit(node) - # - # func = self._stack[-1] - # func.statements.append( - # Statement('call', '$' + node.func.id) - # ) - # - # def visit_BinOp(self, node: ast.BinOp) -> None: - # self.generic_visit(node) - # - # func = self._stack[-1] - # - # if 'Add' == node.op.__class__.__name__: - # func.statements.append( - # Statement('i32.add') - # ) - # elif 'Mult' == node.op.__class__.__name__: - # func.statements.append( - # Statement('i32.mul') - # ) - # else: - # raise NotImplementedError - # - # def visit_Constant(self, node: ast.Constant) -> None: - # func = self._stack[-1] - # if isinstance(node.value, int): - # func.statements.append( - # Statement('i32.const', str(node.value)) - # ) - # - # self.generic_visit(node) diff --git a/py2wasm/wasm.py b/py2wasm/wasm.py index fd1b43a..24d1f46 100644 --- a/py2wasm/wasm.py +++ b/py2wasm/wasm.py @@ -4,6 +4,8 @@ Python classes for storing the representation of Web Assembly code from typing import Iterable, List, Optional, Tuple +Param = Tuple[str, str] + class Import: """ Represents a Web Assembly import @@ -13,12 +15,13 @@ class Import: module: str, name: str, intname: str, - params: Iterable[Tuple[str, str]], + params: Iterable[Param], ) -> None: self.module = module self.name = name self.intname = intname self.params = [*params] + self.result: str = 'None' def generate(self) -> str: """ @@ -53,12 +56,13 @@ class Function: self, name: str, exported: bool, + params: Iterable[Param], result: Optional[str], statements: Iterable[Statement], ) -> None: self.name = name self.exported = exported - self.params: List[Tuple[str, str]] = [] + self.params = [*params] self.result = result self.statements = [*statements] diff --git a/pylintrc b/pylintrc index 72d74db..eb6abcb 100644 --- a/pylintrc +++ b/pylintrc @@ -1,2 +1,2 @@ [MASTER] -disable=C0122,R0903 +disable=C0122,R0903,R0913 diff --git a/tests/integration/test_fib.py b/tests/integration/test_fib.py index fa1eaf9..ff00b05 100644 --- a/tests/integration/test_fib.py +++ b/tests/integration/test_fib.py @@ -2,19 +2,26 @@ from .helpers import Suite def test_fib(): code_py = """ -@imported('console') -def logInt32(value: i32) -> None: - ... +def helper(n: i32, a: i32, b: i32) -> i32: + if n < 1: + return a + b -def helper(value: i32) -> i32: - return value + 1 + return helper(n - 1, a + b, a) + +def fib(n: i32) -> i32: + if n == 0: + return 0 + + if n == 1: + return 1 + + return helper(n - 1, 0, 1) @exported def testEntry(): - logInt32(13 + 13 * helper(122)) + return fib(40) """ result = Suite(code_py, 'test_fib').run_code() - assert None is result.returned_value - assert [1612] == result.log_int32_list + assert 102334155 == result.returned_value diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index 99ce3ce..0d015b4 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -9,5 +9,32 @@ def testEntry() -> i32: result = Suite(code_py, 'test_fib').run_code() - assert 13 is result.returned_value + assert 13 == result.returned_value + assert [] == result.log_int32_list + +def test_addition(): + code_py = """ +@exported +def testEntry() -> i32: + return 10 + 3 +""" + + result = Suite(code_py, 'test_fib').run_code() + + assert 13 == result.returned_value + assert [] == result.log_int32_list + +def test_call(): + code_py = """ +def helper(left: i32, right: i32) -> i32: + return left + right + +@exported +def testEntry() -> i32: + return helper(10, 3) +""" + + result = Suite(code_py, 'test_fib').run_code() + + assert 13 == result.returned_value assert [] == result.log_int32_list