diff --git a/py2wasm/python.py b/py2wasm/python.py index 77b53d2..6f5d36b 100644 --- a/py2wasm/python.py +++ b/py2wasm/python.py @@ -45,6 +45,11 @@ class Visitor: function_body_map[stmt] = wnode continue + if isinstance(stmt, ast.ClassDef): + wclass = self.pre_visit_ClassDef(module, stmt) + module.classes.append(wclass) + continue + # No other pre visits to do for stmt in node.body: @@ -54,6 +59,9 @@ class Visitor: # else: It's an import, no actual body to parse continue + if isinstance(stmt, ast.ClassDef): + continue + raise NotImplementedError(stmt) return module @@ -70,8 +78,6 @@ class Visitor: Nested / dynamicly created functions are not yet supported """ - del module - exported = False if node.decorator_list: @@ -108,13 +114,22 @@ class Visitor: assert not node.args.kwarg assert not node.args.defaults - params = [ - (a.arg, _parse_annotation(a.annotation), ) - for a in [ - *node.args.posonlyargs, - *node.args.args, - ] - ] + class_lookup = { + x.name: x + for x in module.classes + } + + params = [] + for arg in [*node.args.posonlyargs, *node.args.args]: + if not isinstance(arg.annotation, ast.Name): + raise NotImplementedError + + print(class_lookup) + print(arg.annotation.id) + if arg.annotation.id in class_lookup: + params.append((arg.arg, arg.annotation.id, )) + else: + params.append((arg.arg, _parse_annotation(arg.annotation), )) return wasm.Function(node.name, exported, params, result, []) @@ -137,6 +152,51 @@ class Visitor: func.statements = statements + def pre_visit_ClassDef( + self, + module: wasm.Module, + node: ast.ClassDef, + ) -> wasm.Class: + """ + TODO: Document this + """ + del module + + if node.bases or node.keywords or node.decorator_list: + raise NotImplementedError + + members: List[wasm.ClassMember] = [] + + for stmt in node.body: + if not isinstance(stmt, ast.AnnAssign): + raise NotImplementedError + + if not isinstance(stmt.target, ast.Name): + raise NotImplementedError + + if not isinstance(stmt.annotation, ast.Name): + raise NotImplementedError + + if stmt.annotation.id != 'i32': + raise NotImplementedError + + if stmt.value is None: + default = None + else: + if not isinstance(stmt.value, ast.Constant): + raise NotImplementedError + + if not isinstance(stmt.value.value, int): + raise NotImplementedError + + default = wasm.Constant(stmt.value.value) + + members.append(wasm.ClassMember( + stmt.target.id, stmt.annotation.id, default + )) + + return wasm.Class(node.name, members) + def visit_stmt( self, module: wasm.Module, @@ -235,6 +295,9 @@ class Visitor: if isinstance(node, ast.Name): return self.visit_Name(wlocals, exp_type, node) + if isinstance(node, ast.Attribute): + return [] # TODO + raise NotImplementedError(node) def visit_UnaryOp( @@ -331,10 +394,11 @@ class Visitor: called_name = node.func.id - search_list: List[Union[wasm.Function, wasm.Import]] + search_list: List[Union[wasm.Function, wasm.Import, wasm.Class]] search_list = [ *module.functions, *module.imports, + *module.classes, ] called_func_list = [ @@ -346,8 +410,15 @@ class Visitor: assert 1 == len(called_func_list), \ 'Could not find function {}'.format(node.func.id) - called_params = called_func_list[0].params - called_result = called_func_list[0].result + if isinstance(called_func_list[0], wasm.Class): + called_params = [ + (x.name, x.type, ) + for x in called_func_list[0].members + ] + called_result: Optional[str] = called_func_list[0].name + else: + called_params = called_func_list[0].params + called_result = called_func_list[0].result assert exp_type == called_result @@ -471,5 +542,5 @@ def _parse_annotation(ann: Optional[ast.expr]) -> str: assert ann is not None, 'Web Assembly requires type annotations' assert isinstance(ann, ast.Name) result = ann.id - assert result in ['i32', 'i64', 'f32', 'f64'] + assert result in ['i32', 'i64', 'f32', 'f64'], result return result diff --git a/py2wasm/wasm.py b/py2wasm/wasm.py index 24d1f46..68300fc 100644 --- a/py2wasm/wasm.py +++ b/py2wasm/wasm.py @@ -2,7 +2,7 @@ Python classes for storing the representation of Web Assembly code """ -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union Param = Tuple[str, str] @@ -83,6 +83,30 @@ class Function: '\n '.join(x.generate() for x in self.statements), ) +class Constant: + """ + TODO + """ + def __init__(self, value: Union[None, bool, int, float]) -> None: + self.value = value + +class ClassMember: + """ + Represents a Web Assembly class member + """ + def __init__(self, name: str, type_: str, default: Optional[Constant]) -> None: + self.name = name + self.type = type_ + self.default = default + +class Class: + """ + Represents a Web Assembly class + """ + def __init__(self, name: str, members: List[ClassMember]) -> None: + self.name = name + self.members = members + class Module: """ Represents a Web Assembly module @@ -90,6 +114,7 @@ class Module: def __init__(self) -> None: self.imports: List[Import] = [] self.functions: List[Function] = [] + self.classes: List[Class] = [] def generate(self) -> str: """ diff --git a/tests/integration/test_fib.py b/tests/integration/test_fib.py index b27297d..bd701c9 100644 --- a/tests/integration/test_fib.py +++ b/tests/integration/test_fib.py @@ -28,3 +28,4 @@ def testEntry() -> i32: result = Suite(code_py, 'test_fib').run_code() assert 102334155 == result.returned_value + assert False diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index 7650c6c..c399e86 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -234,3 +234,40 @@ def helper(left: i32, right: i32) -> i32: assert 7 == result.returned_value assert [] == result.log_int32_list + +@pytest.mark.integration_test +def test_assign(): + code_py = """ + +@exported +def testEntry() -> i32: + a: i32 = 8947 + return a +""" + + result = Suite(code_py, 'test_call').run_code() + + assert 8947 == result.returned_value + assert [] == result.log_int32_list + +@pytest.mark.integration_test +def test_struct(): + code_py = """ + +class Rectangle: + height: i32 + width: i32 + border: i32 # = 5 + +@exported +def testEntry() -> i32: + return helper(Rectangle(100, 150, 2)) + +def helper(shape: Rectangle) -> i32: + return shape.height + shape.width + shape.border +""" + + result = Suite(code_py, 'test_call').run_code() + + assert 100 == result.returned_value + assert [] == result.log_int32_list