diff --git a/py2wasm/python.py b/py2wasm/python.py index 8e5876c..41f7978 100644 --- a/py2wasm/python.py +++ b/py2wasm/python.py @@ -12,7 +12,7 @@ from . import wasm StatementGenerator = Generator[wasm.Statement, None, None] -WLocals = Dict[str, str] +WLocals = Dict[str, wasm.OurType] class Visitor: """ @@ -21,6 +21,29 @@ class Visitor: Since we need to visit the whole tree, there's no point in subclassing the buildin visitor """ + + def __init__(self) -> None: + self._type_map: WLocals = { + 'None': wasm.OurTypeNone(), + 'bool': wasm.OurTypeBool(), + 'i32': wasm.OurTypeInt32(), + 'i64': wasm.OurTypeInt64(), + 'f32': wasm.OurTypeFloat32(), + 'f64': wasm.OurTypeFloat64(), + } + + def _get_type(self, name: Union[None, str, ast.expr]) -> wasm.OurType: + if name is None: + name = 'None' + + if isinstance(name, ast.expr): + if isinstance(name, ast.Name): + name = name.id + else: + raise NotImplementedError(f'_get_type(ast.{type(name).__name__})') + + return self._type_map[name] + def visit_Module(self, node: ast.Module) -> wasm.Module: """ Visits a Python module, which results in a wasm Module @@ -33,12 +56,12 @@ class Visitor: '___new_reference___', False, [ - ('alloc_size', 'i32'), + ('alloc_size', self._get_type('i32')), ], [ - ('result', 'i32'), + ('result', self._get_type('i32')), ], - 'i32', + self._get_type('i32'), [ wasm.Statement('i32.const', '0'), wasm.Statement('i32.const', '0'), @@ -69,7 +92,7 @@ class Visitor: if isinstance(stmt, ast.ClassDef): wclass = self.pre_visit_ClassDef(module, stmt) - module.classes.append(wclass) + self._type_map[wclass.name] = wclass continue # No other pre visits to do @@ -122,13 +145,13 @@ class Visitor: assert not call.keywords - mod, name, intname = _parse_import_decorator(node.name, call.args) + mod, name, intname = self._parse_import_decorator(node.name, call.args) - return wasm.Import(mod, name, intname, _parse_import_args(node.args)) + return wasm.Import(mod, name, intname, self._parse_import_args(node.args)) else: raise NotImplementedError - result = None if node.returns is None else _parse_annotation(node.returns) + result = None if node.returns is None else self._get_type(node.returns) assert not node.args.vararg assert not node.args.kwonlyargs @@ -136,23 +159,12 @@ class Visitor: assert not node.args.kwarg assert not node.args.defaults - class_lookup = { - x.name: x - for x in module.classes - } - - params = [] + params: List[wasm.Param] = [] for arg in [*node.args.posonlyargs, *node.args.args]: - if not isinstance(arg.annotation, ast.Name): - raise NotImplementedError + params.append((arg.arg, self._get_type(arg.annotation), )) - if arg.annotation.id in class_lookup: - params.append((arg.arg, 'i32', )) - else: - params.append((arg.arg, _parse_annotation(arg.annotation), )) - - locals_ = [ - ('___new_reference___addr', 'i32'), # For the ___new_reference__ method + locals_: List[wasm.Param] = [ + ('___new_reference___addr', self._get_type('i32')), # For the ___new_reference__ method ] return wasm.Function(node.name, exported, params, locals_, result, []) @@ -180,7 +192,7 @@ class Visitor: self, module: wasm.Module, node: ast.ClassDef, - ) -> wasm.Class: + ) -> wasm.OurTypeClass: """ TODO: Document this """ @@ -200,7 +212,7 @@ class Visitor: raise NotImplementedError if not isinstance(stmt.annotation, ast.Name): - raise NotImplementedError + raise NotImplementedError('Cannot recurse classes yet') if stmt.annotation.id != 'i32': raise NotImplementedError @@ -217,13 +229,13 @@ class Visitor: default = wasm.Constant(stmt.value.value) member = wasm.ClassMember( - stmt.target.id, stmt.annotation.id, offset, default + stmt.target.id, self._get_type(stmt.annotation), offset, default ) members.append(member) - offset += member.alloc_size() + offset += member.type.alloc_size() - return wasm.Class(node.name, members) + return wasm.OurTypeClass(node.name, members) def visit_stmt( self, @@ -241,7 +253,7 @@ class Visitor: if isinstance(stmt, ast.Expr): assert isinstance(stmt.value, ast.Call) - return self.visit_Call(module, wlocals, "None", stmt.value) + return self.visit_Call(module, wlocals, self._get_type('None'), stmt.value) if isinstance(stmt, ast.If): return self.visit_If(module, func, wlocals, stmt) @@ -261,7 +273,7 @@ class Visitor: assert stmt.value is not None - return_type = _parse_annotation(func.returns) + return_type = self._get_type(func.returns) yield from self.visit_expr(module, wlocals, return_type, stmt.value) yield wasm.Statement('return') @@ -279,7 +291,7 @@ class Visitor: yield from self.visit_expr( module, wlocals, - 'bool', + self._get_type('bool'), stmt.test, ) @@ -299,7 +311,7 @@ class Visitor: self, module: wasm.Module, wlocals: WLocals, - exp_type: str, + exp_type: wasm.OurType, node: ast.expr, ) -> StatementGenerator: """ @@ -332,7 +344,7 @@ class Visitor: self, module: wasm.Module, wlocals: WLocals, - exp_type: str, + exp_type: wasm.OurType, node: ast.UnaryOp, ) -> StatementGenerator: """ @@ -356,7 +368,7 @@ class Visitor: self, module: wasm.Module, wlocals: WLocals, - exp_type: str, + exp_type: wasm.OurType, node: ast.BinOp, ) -> StatementGenerator: """ @@ -366,11 +378,11 @@ class Visitor: yield from self.visit_expr(module, wlocals, exp_type, node.right) if isinstance(node.op, ast.Add): - yield wasm.Statement('{}.add'.format(exp_type)) + yield wasm.Statement('{}.add'.format(exp_type.to_wasm())) return if isinstance(node.op, ast.Sub): - yield wasm.Statement('{}.sub'.format(exp_type)) + yield wasm.Statement('{}.sub'.format(exp_type.to_wasm())) return raise NotImplementedError(node.op) @@ -379,19 +391,19 @@ class Visitor: self, module: wasm.Module, wlocals: WLocals, - exp_type: str, + exp_type: wasm.OurType, node: ast.Compare, ) -> StatementGenerator: """ Visits a Compare node as (part of) an expression """ - assert 'bool' == exp_type + assert isinstance(exp_type, wasm.OurTypeBool) if 1 != len(node.ops) or 1 != len(node.comparators): raise NotImplementedError - yield from self.visit_expr(module, wlocals, 'i32', node.left) - yield from self.visit_expr(module, wlocals, 'i32', node.comparators[0]) + yield from self.visit_expr(module, wlocals, self._get_type('i32'), node.left) + yield from self.visit_expr(module, wlocals, self._get_type('i32'), node.comparators[0]) if isinstance(node.ops[0], ast.Lt): yield wasm.Statement('i32.lt_s') @@ -411,7 +423,7 @@ class Visitor: self, module: wasm.Module, wlocals: WLocals, - exp_type: str, + exp_type: wasm.OurType, node: ast.Call, ) -> StatementGenerator: """ @@ -422,11 +434,15 @@ class Visitor: called_name = node.func.id - search_list: List[Union[wasm.Function, wasm.Import, wasm.Class]] + if called_name in self._type_map: + klass = self._type_map[called_name] + if isinstance(klass, wasm.OurTypeClass): + return self.visit_Call_class(module, wlocals, exp_type, node, klass) + + search_list: List[Union[wasm.Function, wasm.Import]] search_list = [ *module.functions, *module.imports, - *module.classes, ] called_func_list = [ @@ -438,18 +454,15 @@ class Visitor: assert 1 == len(called_func_list), \ 'Could not find function {}'.format(node.func.id) - if isinstance(called_func_list[0], wasm.Class): - return self.visit_Call_class(module, wlocals, exp_type, node, called_func_list[0]) - return self.visit_Call_func(module, wlocals, exp_type, node, called_func_list[0]) def visit_Call_class( self, module: wasm.Module, wlocals: WLocals, - exp_type: str, + exp_type: wasm.OurType, node: ast.Call, - cls: wasm.Class, + cls: wasm.OurTypeClass, ) -> StatementGenerator: """ Visits a Call node as (part of) an expression @@ -469,8 +482,8 @@ class Visitor: raise NotImplementedError yield wasm.Statement('local.get', '$___new_reference___addr') - yield wasm.Statement(f'{member.type}.const', str(arg.value)) - yield wasm.Statement(f'{member.type}.store', 'offset=' + str(member.offset)) + yield wasm.Statement(f'{member.type.to_wasm()}.const', str(arg.value)) + yield wasm.Statement(f'{member.type.to_wasm()}.store', 'offset=' + str(member.offset)) yield wasm.Statement('local.get', '$___new_reference___addr') @@ -478,7 +491,7 @@ class Visitor: self, module: wasm.Module, wlocals: WLocals, - exp_type: str, + exp_type: wasm.OurType, node: ast.Call, func: Union[wasm.Function, wasm.Import], ) -> StatementGenerator: @@ -491,7 +504,7 @@ class Visitor: called_params = func.params called_result = func.result - assert exp_type == called_result + assert exp_type == called_result, 'Function does not match expected type' assert len(called_params) == len(node.args), \ '{}:{} Function {} requires {} arguments, but {} are supplied'.format( @@ -507,32 +520,32 @@ class Visitor: '${}'.format(called_name), ) - def visit_Constant(self, exp_type: str, node: ast.Constant) -> StatementGenerator: + def visit_Constant(self, exp_type: wasm.OurType, node: ast.Constant) -> StatementGenerator: """ Visits a Constant node as (part of) an expression """ - if 'i32' == exp_type: + if isinstance(exp_type, wasm.OurTypeInt32): assert isinstance(node.value, int) assert -2147483648 <= node.value <= 2147483647 yield wasm.Statement('i32.const', str(node.value)) return - if 'i64' == exp_type: + if isinstance(exp_type, wasm.OurTypeInt64): assert isinstance(node.value, int) assert -9223372036854775808 <= node.value <= 9223372036854775807 yield wasm.Statement('i64.const', str(node.value)) return - if 'f32' == exp_type: + if isinstance(exp_type, wasm.OurTypeFloat32): assert isinstance(node.value, float) # TODO: Size check? yield wasm.Statement('f32.const', node.value.hex()) return - if 'f64' == exp_type: + if isinstance(exp_type, wasm.OurTypeFloat64): assert isinstance(node.value, float) # TODO: Size check? @@ -545,27 +558,21 @@ class Visitor: self, module: wasm.Module, wlocals: WLocals, - exp_type: str, + exp_type: wasm.OurType, node: ast.Attribute, ) -> StatementGenerator: """ Visits an Attribute node as (part of) an expression """ - if not isinstance(node.value, ast.Name): raise NotImplementedError if not isinstance(node.ctx, ast.Load): raise NotImplementedError - cls_list = [ - x - for x in module.classes - if x.name == 'Rectangle' # TODO: STUB, since we can't acces the type properly - ] + cls = wlocals[node.value.id] - assert len(cls_list) == 1 - cls = cls_list[0] + assert isinstance(cls, wasm.OurTypeClass), f'Cannot take property of {cls}' member_list = [ x @@ -577,79 +584,71 @@ class Visitor: member = member_list[0] yield wasm.Statement('local.get', '$' + node.value.id) - yield wasm.Statement(exp_type + '.load', 'offset=' + str(member.offset)) + yield wasm.Statement(exp_type.to_wasm() + '.load', 'offset=' + str(member.offset)) - def visit_Name(self, wlocals: WLocals, exp_type: str, node: ast.Name) -> StatementGenerator: + def visit_Name(self, wlocals: WLocals, exp_type: wasm.OurType, node: ast.Name) -> StatementGenerator: """ Visits a Name node as (part of) an expression """ assert node.id in wlocals - if exp_type == 'i64' and wlocals[node.id] == 'i32': + if (isinstance(exp_type, wasm.OurTypeInt64) + and isinstance(wlocals[node.id], wasm.OurTypeInt32)): yield wasm.Statement('local.get', '${}'.format(node.id)) yield wasm.Statement('i64.extend_i32_s') return - if exp_type == 'f64' and wlocals[node.id] == 'f32': + if (isinstance(exp_type, wasm.OurTypeFloat64) + and isinstance(wlocals[node.id], wasm.OurTypeFloat32)): yield wasm.Statement('local.get', '${}'.format(node.id)) yield wasm.Statement('f64.promote_f32') return - assert exp_type == wlocals[node.id] + assert exp_type == wlocals[node.id], (exp_type, wlocals[node.id], ) yield wasm.Statement( 'local.get', '${}'.format(node.id), ) -def _parse_import_decorator(func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]: - """ - Parses an @import decorator - """ + def _parse_import_decorator(self, func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]: + """ + Parses an @import decorator + """ - assert 0 < len(args) < 3 - str_args = [ - arg.value - for arg in args - if isinstance(arg, ast.Constant) - and isinstance(arg.value, str) - ] - assert len(str_args) == len(args) + assert 0 < len(args) < 3 + str_args = [ + arg.value + for arg in args + if isinstance(arg, ast.Constant) + and isinstance(arg.value, str) + ] + assert len(str_args) == len(args) - module = str_args.pop(0) - if str_args: - name = str_args.pop(0) - else: - name = func_name + module = str_args.pop(0) + if str_args: + name = str_args.pop(0) + else: + name = func_name - return module, name, func_name + return module, name, func_name -def _parse_import_args(args: ast.arguments) -> List[Tuple[str, str]]: - """ - Parses the arguments for an @imported method - """ + def _parse_import_args(self, args: ast.arguments) -> List[wasm.Param]: + """ + Parses the arguments for an @imported method + """ - assert not args.vararg - assert not args.kwonlyargs - assert not args.kw_defaults - assert not args.kwarg - assert not args.defaults # Maybe support this later on + assert not args.vararg + assert not args.kwonlyargs + assert not args.kw_defaults + assert not args.kwarg + assert not args.defaults # Maybe support this later on - arg_list = [ - *args.posonlyargs, - *args.args, - ] + arg_list = [ + *args.posonlyargs, + *args.args, + ] - return [ - (arg.arg, _parse_annotation(arg.annotation)) - for arg in arg_list - ] - -def _parse_annotation(ann: Optional[ast.expr]) -> str: - """ - Parses a typing annotation - """ - assert ann is not None, 'Web Assembly requires type annotations' - assert isinstance(ann, ast.Name) - result = ann.id - assert result in ['i32', 'i64', 'f32', 'f64'], result - return result + return [ + (arg.arg, self._get_type(arg.annotation)) + for arg in arg_list + ] diff --git a/py2wasm/wasm.py b/py2wasm/wasm.py index b025617..1939301 100644 --- a/py2wasm/wasm.py +++ b/py2wasm/wasm.py @@ -2,9 +2,90 @@ Python classes for storing the representation of Web Assembly code """ -from typing import Iterable, List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Union -Param = Tuple[str, str] +### +### This part is more intermediate code +### + +class OurType: + def to_wasm(self) -> str: + raise NotImplementedError + + def alloc_size(self) -> int: + raise NotImplementedError + +class OurTypeNone(OurType): + pass + +class OurTypeBool(OurType): + pass + +class OurTypeInt32(OurType): + def to_wasm(self) -> str: + return 'i32' + + def alloc_size(self) -> int: + return 4 + +class OurTypeInt64(OurType): + def to_wasm(self) -> str: + return 'i64' + + def alloc_size(self) -> int: + return 8 + +class OurTypeFloat32(OurType): + def to_wasm(self) -> str: + return 'f32' + + def alloc_size(self) -> int: + return 4 + +class OurTypeFloat64(OurType): + def to_wasm(self) -> str: + return 'f64' + + def alloc_size(self) -> int: + return 8 + +class Constant: + """ + TODO + """ + def __init__(self, value: Union[None, bool, int, float]) -> None: + self.value = value + +class ClassMember: + """ + Represents a class member + """ + def __init__(self, name: str, type_: OurType, offset: int, default: Optional[Constant]) -> None: + self.name = name + self.type = type_ + self.offset = offset + self.default = default + +class OurTypeClass(OurType): + """ + Represents a class + """ + def __init__(self, name: str, members: List[ClassMember]) -> None: + self.name = name + self.members = members + + def to_wasm(self) -> str: + return 'i32' # WASM uses 32 bit pointers + + def alloc_size(self) -> int: + return sum(x.type.alloc_size() for x in self.members) + + +Param = Tuple[str, OurType] + +### +## This part is more actual web assembly +### class Import: """ @@ -62,7 +143,7 @@ class Function: exported: bool, params: Iterable[Param], locals_: Iterable[Param], - result: Optional[str], + result: Optional[OurType], statements: Iterable[Statement], ) -> None: self.name = name @@ -79,55 +160,19 @@ class Function: header = ('(export "{}")' if self.exported else '${}').format(self.name) for nam, typ in self.params: - header += f' (param ${nam} {typ})' + header += f' (param ${nam} {typ.to_wasm()})' if self.result: - header += f' (result {self.result})' + header += f' (result {self.result.to_wasm()})' for nam, typ in self.locals: - header += f' (local ${nam} {typ})' + header += f' (local ${nam} {typ.to_wasm()})' return '(func {}\n {}\n )'.format( header, '\n '.join(x.generate() for x in self.statements), ) -class Constant: - """ - TODO - """ - def __init__(self, value: Union[None, bool, int, float]) -> None: - self.value = value - -class ClassMember: - """ - Represents a Web Assembly class member - """ - def __init__(self, name: str, type_: str, offset: int, default: Optional[Constant]) -> None: - self.name = name - self.type = type_ - self.offset = offset - self.default = default - - def alloc_size(self) -> int: - SIZE_MAP = { - 'i32': 4, - 'i64': 4, - } - - return SIZE_MAP[self.type] - -class Class: - """ - Represents a Web Assembly class - """ - def __init__(self, name: str, members: List[ClassMember]) -> None: - self.name = name - self.members = members - - def alloc_size(self) -> int: - return sum(x.alloc_size() for x in self.members) - class Module: """ Represents a Web Assembly module @@ -135,7 +180,6 @@ class Module: def __init__(self) -> None: self.imports: List[Import] = [] self.functions: List[Function] = [] - self.classes: List[Class] = [] def generate(self) -> str: """ diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index ca70345..dca3291 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -251,6 +251,26 @@ def testEntry() -> i32: assert 8947 == result.returned_value assert [] == result.log_int32_list +@pytest.mark.integration_test +def test_struct_0(): + code_py = """ + +class CheckedValue: + value: i32 + +@exported +def testEntry() -> i32: + return helper(CheckedValue(2345)) + +def helper(cv: CheckedValue) -> i32: + return cv.value +""" + + result = Suite(code_py, 'test_call').run_code() + + assert 2345 == result.returned_value + assert [] == result.log_int32_list + @pytest.mark.integration_test def test_struct_1(): code_py = """