diff --git a/py2wasm/python.py b/py2wasm/python.py index da7c140..853fd25 100644 --- a/py2wasm/python.py +++ b/py2wasm/python.py @@ -72,7 +72,37 @@ class Visitor: return wasm.Import(module, name, intname, _parse_import_args(node.args)) - return wasm.Function(node.name, exported) + result = _parse_annotation(node.returns) + + statements = [ + self.visit_stmt(node, stmt) + for stmt in node.body + ] + + return wasm.Function(node.name, exported, result, statements) + + def visit_stmt(self, func: ast.FunctionDef, stmt: ast.stmt) -> wasm.Statement: + """ + Visits a statement node + """ + if isinstance(stmt, ast.Return): + return self.visit_Return(func, stmt) + + raise NotImplementedError + + def visit_Return(self, func: ast.FunctionDef, stmt: ast.Return) -> wasm.Statement: + """ + Visits a statement node + """ + assert isinstance(stmt.value, ast.Constant) + assert isinstance(stmt.value.value, int) + + return_type = _parse_annotation(func.returns) + + return wasm.Statement( + '{}.const'.format(return_type), + str(stmt.value.value) + ) def _parse_import_decorator(func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]: """ diff --git a/py2wasm/wasm.py b/py2wasm/wasm.py index 0843a28..fd1b43a 100644 --- a/py2wasm/wasm.py +++ b/py2wasm/wasm.py @@ -2,7 +2,7 @@ Python classes for storing the representation of Web Assembly code """ -from typing import Iterable, List, Tuple +from typing import Iterable, List, Optional, Tuple class Import: """ @@ -49,12 +49,18 @@ class Function: """ Represents a Web Assembly function """ - def __init__(self, name: str, exported: bool = True) -> None: + def __init__( + self, + name: str, + exported: bool, + result: Optional[str], + statements: Iterable[Statement], + ) -> None: self.name = name self.exported = exported self.params: List[Tuple[str, str]] = [] - self.returns = None - self.statements: List[Statement] = [] + self.result = result + self.statements = [*statements] def generate(self) -> str: """ @@ -65,6 +71,9 @@ class Function: for nam, typ in self.params: header += ' (param ${} {})'.format(nam, typ) + if self.result: + header += ' (result {})'.format(self.result) + return '(func {}\n {})'.format( header, '\n '.join(x.generate() for x in self.statements), diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py new file mode 100644 index 0000000..99ce3ce --- /dev/null +++ b/tests/integration/test_simple.py @@ -0,0 +1,13 @@ +from .helpers import Suite + +def test_return(): + code_py = """ +@exported +def testEntry() -> i32: + return 13 +""" + + result = Suite(code_py, 'test_fib').run_code() + + assert 13 is result.returned_value + assert [] == result.log_int32_list