"""
Code for parsing the source (python-alike) code
"""

from typing import Dict, Generator, List, Optional, Union, Tuple

import ast

from . import wasm

# pylint: disable=C0103,R0201

StatementGenerator = Generator[wasm.Statement, None, None]

WLocals = Dict[str, str]


class Visitor:
    """
    Class to visit a Python syntax tree

    Since we need to visit the whole tree, there's no point in subclassing
    the builtin visitor
    """
    def visit_Module(self, node: ast.Module) -> wasm.Module:
        """
        Visits a Python module, which results in a wasm Module

        The resulting module always contains the ``___new_reference___``
        helper function (used when instantiating classes), followed by one
        entry per top level function or class of the Python module.

        Raises NotImplementedError for any top level statement that is not
        a function or class definition.
        """
        assert not node.type_ignores

        module = wasm.Module()

        # Bump allocator: the i32 stored at address 0 is read and returned
        # as the address for this call, and that value plus alloc_size is
        # stored back at address 0 for the next call.
        # NOTE(review): if address 0 initially holds 0, the first allocation
        # overlaps this pointer slot - confirm memory is initialised elsewhere.
        module.functions.append(wasm.Function(
            '___new_reference___',
            False,
            [
                ('alloc_size', 'i32'),
            ],
            [
                ('result', 'i32'),
            ],
            'i32',
            [
                wasm.Statement('i32.const', '0'),
                wasm.Statement('i32.const', '0'),
                wasm.Statement('i32.load'),
                wasm.Statement('local.tee', '$result', comment='Address for this call'),
                wasm.Statement('local.get', '$alloc_size'),
                wasm.Statement('i32.add'),
                wasm.Statement('i32.store', comment='Address for the next call'),
                wasm.Statement('local.get', '$result'),
            ]
        ))

        # Classes are pre-visited before any function, so that a function
        # signature may use a class that is declared further down the file.
        for stmt in node.body:
            if isinstance(stmt, ast.ClassDef):
                module.classes.append(self.pre_visit_ClassDef(module, stmt))

        # Do a check first for all function definitions
        # to get their types. Otherwise you cannot call
        # a method that you haven't defined just yet,
        # even if it is in the same file
        function_body_map: Dict[ast.FunctionDef, wasm.Function] = {}
        for stmt in node.body:
            if not isinstance(stmt, ast.FunctionDef):
                continue

            wnode = self.pre_visit_FunctionDef(module, stmt)
            if isinstance(wnode, wasm.Import):
                module.imports.append(wnode)
            else:
                module.functions.append(wnode)
                function_body_map[stmt] = wnode

        # With all signatures known, generate the function bodies
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                if stmt in function_body_map:
                    self.parse_FunctionDef_body(module, function_body_map[stmt], stmt)
                # else: It's an import, no actual body to parse
                continue

            if isinstance(stmt, ast.ClassDef):
                continue

            raise NotImplementedError(stmt)

        return module

    def pre_visit_FunctionDef(
            self,
            module: wasm.Module,
            node: ast.FunctionDef,
        ) -> Union[wasm.Import, wasm.Function]:
        """
        Builds the signature of a Python function definition

        A function decorated with ``@imported(module[, name])`` must have a
        single ``...`` as its body, and is returned as a wasm Import.

        Other functions are returned as wasm Functions, exported when
        decorated with ``@exported``. Their statement list is left empty;
        parse_FunctionDef_body fills it in later.

        Parameters annotated with the name of a class already present in
        module.classes are passed as i32 (the instance's address).

        Nested / dynamically created functions are not yet supported
        """
        exported = False

        if node.decorator_list:
            assert 1 == len(node.decorator_list)

            decorator = node.decorator_list[0]
            if isinstance(decorator, ast.Name):
                assert 'exported' == decorator.id
                exported = True
            elif isinstance(decorator, ast.Call):
                call = decorator

                # An imported function only declares a signature;
                # its body must be a single `...` expression.
                assert not node.type_comment
                assert 1 == len(node.body)
                body_expr = node.body[0]
                assert isinstance(body_expr, ast.Expr)
                # ast.Ellipsis is deprecated since Python 3.8 and removed in
                # Python 3.14; `...` is parsed as a Constant holding Ellipsis.
                assert isinstance(body_expr.value, ast.Constant)
                assert body_expr.value.value is Ellipsis

                assert isinstance(call.func, ast.Name)
                assert 'imported' == call.func.id
                assert not call.keywords

                mod, name, intname = _parse_import_decorator(node.name, call.args)

                return wasm.Import(mod, name, intname, _parse_import_args(node.args))
            else:
                raise NotImplementedError

        result = None if node.returns is None else _parse_annotation(node.returns)

        assert not node.args.vararg
        assert not node.args.kwonlyargs
        assert not node.args.kw_defaults
        assert not node.args.kwarg
        assert not node.args.defaults

        class_lookup = {
            x.name: x
            for x in module.classes
        }

        params = []
        for arg in [*node.args.posonlyargs, *node.args.args]:
            if not isinstance(arg.annotation, ast.Name):
                raise NotImplementedError

            if arg.annotation.id in class_lookup:
                # Class instances are passed around by address
                params.append((arg.arg, 'i32', ))
            else:
                params.append((arg.arg, _parse_annotation(arg.annotation), ))

        locals_ = [
            ('___new_reference___addr', 'i32'), # For the ___new_reference__ method
        ]

        return wasm.Function(node.name, exported, params, locals_, result, [])

    def parse_FunctionDef_body(
            self,
            module: wasm.Module,
            func: wasm.Function,
            node: ast.FunctionDef,
        ) -> None:
        """
        Parses the function body

        Visits every statement of the Python function and stores the
        generated wasm statements on func.statements. Only the function's
        parameters are known as names inside the body.
        """
        wlocals: WLocals = dict(func.params)

        statements: List[wasm.Statement] = []
        for py_stmt in node.body:
            statements.extend(
                self.visit_stmt(module, node, wlocals, py_stmt)
            )

        func.statements = statements

    def pre_visit_ClassDef(
            self,
            module: wasm.Module,
            node: ast.ClassDef,
        ) -> wasm.Class:
        """
        Builds the memory layout of a Python class definition

        Only plain classes are supported: no base classes, keywords or
        decorators. Every statement in the body must be an annotated i32
        member, optionally with an int constant as default value. Members
        are laid out in declaration order, each starting right after the
        previous one (offset advances by the member's alloc_size()).

        Raises NotImplementedError for anything else.
        """
        del module # Unused; the class layout does not depend on the module

        if node.bases or node.keywords or node.decorator_list:
            raise NotImplementedError

        members: List[wasm.ClassMember] = []

        offset = 0
        for stmt in node.body:
            if not isinstance(stmt, ast.AnnAssign):
                raise NotImplementedError

            if not isinstance(stmt.target, ast.Name):
                raise NotImplementedError

            if not isinstance(stmt.annotation, ast.Name):
                raise NotImplementedError

            if stmt.annotation.id != 'i32':
                raise NotImplementedError

            if stmt.value is None:
                default = None
            else:
                if not isinstance(stmt.value, ast.Constant):
                    raise NotImplementedError

                if not isinstance(stmt.value.value, int):
                    raise NotImplementedError

                default = wasm.Constant(stmt.value.value)

            member = wasm.ClassMember(
                stmt.target.id, stmt.annotation.id, offset, default
            )

            members.append(member)
            offset += member.alloc_size()

        return wasm.Class(node.name, members)

    def visit_stmt(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.stmt,
        ) -> StatementGenerator:
        """
        Visits a statement node within a function

        Supported are: return, if / else, and expression statements that
        are calls whose result is discarded.
        """
        if isinstance(stmt, ast.Return):
            return self.visit_Return(module, func, wlocals, stmt)

        if isinstance(stmt, ast.Expr):
            assert isinstance(stmt.value, ast.Call)
            # "None": the called function must not leave a value behind
            return self.visit_Call(module, wlocals, "None", stmt.value)

        if isinstance(stmt, ast.If):
            return self.visit_If(module, func, wlocals, stmt)

        raise NotImplementedError(stmt)

    def visit_Return(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.Return,
        ) -> StatementGenerator:
        """
        Visits a Return node

        The returned expression is generated with the function's annotated
        return type; a bare `return` is not supported.
        """
        assert stmt.value is not None

        return_type = _parse_annotation(func.returns)

        yield from self.visit_expr(module, wlocals, return_type, stmt.value)
        yield wasm.Statement('return')

    def visit_If(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.If,
        ) -> StatementGenerator:
        """
        Visits an If node

        The test is generated with exp_type 'bool' (in practice this means
        it must be a comparison), followed by if / else / end around the
        two branches. The else is emitted even when the branch is empty.
        """
        yield from self.visit_expr(
            module,
            wlocals,
            'bool',
            stmt.test,
        )

        yield wasm.Statement('if')

        for py_stmt in stmt.body:
            yield from self.visit_stmt(module, func, wlocals, py_stmt)

        yield wasm.Statement('else')

        for py_stmt in stmt.orelse:
            yield from self.visit_stmt(module, func, wlocals, py_stmt)

        yield wasm.Statement('end')

    def visit_expr(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.expr,
        ) -> StatementGenerator:
        """
        Visit an expression node

        Dispatches to the visitor for the node's type; the generated
        statements leave one value of exp_type on the stack.
        """
        if isinstance(node, ast.BinOp):
            return self.visit_BinOp(module, wlocals, exp_type, node)

        if isinstance(node, ast.UnaryOp):
            return self.visit_UnaryOp(module, wlocals, exp_type, node)

        if isinstance(node, ast.Compare):
            return self.visit_Compare(module, wlocals, exp_type, node)

        if isinstance(node, ast.Call):
            return self.visit_Call(module, wlocals, exp_type, node)

        if isinstance(node, ast.Constant):
            return self.visit_Constant(exp_type, node)

        if isinstance(node, ast.Name):
            return self.visit_Name(wlocals, exp_type, node)

        if isinstance(node, ast.Attribute):
            return self.visit_Attribute(module, wlocals, exp_type, node)

        raise NotImplementedError(node)

    def visit_UnaryOp(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.UnaryOp,
        ) -> StatementGenerator:
        """
        Visits a UnaryOp node as (part of) an expression

        Unary plus is a no-op. Unary minus is only supported on int and
        float constants, which are folded into a negative constant.
        """
        if isinstance(node.op, ast.UAdd):
            return self.visit_expr(module, wlocals, exp_type, node.operand)

        if isinstance(node.op, ast.USub):
            if not isinstance(node.operand, ast.Constant):
                raise NotImplementedError(node.operand)

            if not isinstance(node.operand.value, (int, float)):
                raise NotImplementedError(node.operand.value)

            return self.visit_Constant(exp_type, ast.Constant(-node.operand.value))

        raise NotImplementedError(node.op)

    def visit_BinOp(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.BinOp,
        ) -> StatementGenerator:
        """
        Visits a BinOp node as (part of) an expression

        Both operands are generated with exp_type; +, - and * map onto
        the add, sub and mul instructions of that type.
        """
        yield from self.visit_expr(module, wlocals, exp_type, node.left)
        yield from self.visit_expr(module, wlocals, exp_type, node.right)

        if isinstance(node.op, ast.Add):
            yield wasm.Statement('{}.add'.format(exp_type))
            return

        if isinstance(node.op, ast.Sub):
            yield wasm.Statement('{}.sub'.format(exp_type))
            return

        if isinstance(node.op, ast.Mult):
            yield wasm.Statement('{}.mul'.format(exp_type))
            return

        raise NotImplementedError(node.op)

    def visit_Compare(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.Compare,
        ) -> StatementGenerator:
        """
        Visits a Compare node as (part of) an expression

        Only a single <, == or > between two i32 operands is supported;
        the ordering comparisons are signed.
        """
        assert 'bool' == exp_type

        if 1 != len(node.ops) or 1 != len(node.comparators):
            raise NotImplementedError

        yield from self.visit_expr(module, wlocals, 'i32', node.left)
        yield from self.visit_expr(module, wlocals, 'i32', node.comparators[0])

        if isinstance(node.ops[0], ast.Lt):
            yield wasm.Statement('i32.lt_s')
            return

        if isinstance(node.ops[0], ast.Eq):
            yield wasm.Statement('i32.eq')
            return

        if isinstance(node.ops[0], ast.Gt):
            yield wasm.Statement('i32.gt_s')
            return

        raise NotImplementedError(node.ops)

    def visit_Call(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.Call,
        ) -> StatementGenerator:
        """
        Visits a Call node as (part of) an expression

        The called name is looked up among the module's functions, imports
        and classes; calling a class instantiates it.
        """
        assert isinstance(node.func, ast.Name)
        assert not node.keywords

        called_name = node.func.id

        search_list: List[Union[wasm.Function, wasm.Import, wasm.Class]]
        search_list = [
            *module.functions,
            *module.imports,
            *module.classes,
        ]

        called_func_list = [
            x
            for x in search_list
            if x.name == called_name
        ]

        assert called_func_list, \
            'Could not find function {}'.format(called_name)
        assert 1 == len(called_func_list), \
            'Ambiguous name {}'.format(called_name)

        if isinstance(called_func_list[0], wasm.Class):
            return self.visit_Call_class(module, wlocals, exp_type, node, called_func_list[0])

        return self.visit_Call_func(module, wlocals, exp_type, node, called_func_list[0])

    def visit_Call_class(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.Call,
            cls: wasm.Class,
        ) -> StatementGenerator:
        """
        Visits a Call node as (part of) an expression

        This instantiates the class: cls.alloc_size() bytes are obtained
        from ___new_reference___, each positional argument (constants only)
        is stored into the member at the same position, and the address of
        the new instance is left on the stack. Members without a matching
        argument are not written.
        """
        assert len(node.args) <= len(cls.members), \
            '{}:{} Class {} has {} members, but {} arguments are supplied'.format(
                node.lineno, node.col_offset,
                cls.name, len(cls.members), len(node.args),
            )

        # TODO: malloc call
        yield wasm.Statement('i32.const', str(cls.alloc_size()))
        yield wasm.Statement('call', '$___new_reference___')

        yield wasm.Statement('local.set', '$___new_reference___addr')

        for member, arg in zip(cls.members, node.args):
            if not isinstance(arg, ast.Constant):
                raise NotImplementedError

            # Store: address, then value
            yield wasm.Statement('local.get', '$___new_reference___addr')
            yield wasm.Statement(f'{member.type}.const', str(arg.value))
            yield wasm.Statement(f'{member.type}.store', 'offset=' + str(member.offset))

        yield wasm.Statement('local.get', '$___new_reference___addr')

    def visit_Call_func(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.Call,
            func: Union[wasm.Function, wasm.Import],
        ) -> StatementGenerator:
        """
        Visits a Call node as (part of) an expression

        Each argument is generated with the type of the matching parameter,
        followed by the call itself. The function's result type must equal
        exp_type; 'None' stands for "no result".
        """
        assert isinstance(node.func, ast.Name)

        called_name = node.func.id
        called_params = func.params
        # A function without a return annotation has None as result, while
        # visit_stmt asks for 'None' when the call's value is discarded;
        # treat both spellings the same.
        called_result = 'None' if func.result is None else func.result

        assert exp_type == called_result

        assert len(called_params) == len(node.args), \
            '{}:{} Function {} requires {} arguments, but {} are supplied'.format(
                node.lineno, node.col_offset,
                called_name, len(called_params), len(node.args),
            )

        for (_, exp_expr_type), expr in zip(called_params, node.args):
            yield from self.visit_expr(module, wlocals, exp_expr_type, expr)

        yield wasm.Statement(
            'call',
            '${}'.format(called_name),
        )

    def visit_Constant(self, exp_type: str, node: ast.Constant) -> StatementGenerator:
        """
        Visits a Constant node as (part of) an expression

        Integer types require an int within the type's signed range; float
        types require a float, written in hexadecimal float notation.
        """
        if 'i32' == exp_type:
            assert isinstance(node.value, int)
            assert -2147483648 <= node.value <= 2147483647

            # int() so that a bool constant is written as 1 / 0, not True / False
            yield wasm.Statement('i32.const', str(int(node.value)))
            return

        if 'i64' == exp_type:
            assert isinstance(node.value, int)
            assert -9223372036854775808 <= node.value <= 9223372036854775807

            yield wasm.Statement('i64.const', str(int(node.value)))
            return

        if 'f32' == exp_type:
            assert isinstance(node.value, float)
            # TODO: Size check?

            yield wasm.Statement('f32.const', node.value.hex())
            return

        if 'f64' == exp_type:
            assert isinstance(node.value, float)
            # TODO: Size check?

            yield wasm.Statement('f64.const', node.value.hex())
            return

        raise NotImplementedError(exp_type)

    def visit_Attribute(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.Attribute,
        ) -> StatementGenerator:
        """
        Visits an Attribute node as (part of) an expression

        Reads `name.member`, where name is a local holding an instance
        address. The member is loaded with its own type; an i32 member is
        sign-extended when an i64 is expected.
        """
        if not isinstance(node.value, ast.Name):
            raise NotImplementedError

        if not isinstance(node.ctx, ast.Load):
            raise NotImplementedError

        cls_list = [
            x
            for x in module.classes
            if x.name == 'Rectangle' # TODO: STUB, since we can't acces the type properly
        ]

        assert len(cls_list) == 1
        cls = cls_list[0]

        member_list = [
            x
            for x in cls.members
            if x.name == node.attr
        ]

        assert len(member_list) == 1
        member = member_list[0]

        # Loading with exp_type would read the wrong width / type from memory
        needs_extend = exp_type == 'i64' and member.type == 'i32'
        assert needs_extend or exp_type == member.type, \
            'Member {} is {}, but {} is expected'.format(node.attr, member.type, exp_type)

        yield wasm.Statement('local.get', '$' + node.value.id)
        yield wasm.Statement(member.type + '.load', 'offset=' + str(member.offset))

        if needs_extend:
            yield wasm.Statement('i64.extend_i32_s')

    def visit_Name(self, wlocals: WLocals, exp_type: str, node: ast.Name) -> StatementGenerator:
        """
        Visits a Name node as (part of) an expression

        Only known locals may be read. An i32 local may be used where an
        i64 is expected (sign-extended) and an f32 where an f64 is expected
        (promoted); otherwise the types must match exactly.
        """
        assert node.id in wlocals

        if exp_type == 'i64' and wlocals[node.id] == 'i32':
            yield wasm.Statement('local.get', '${}'.format(node.id))
            yield wasm.Statement('i64.extend_i32_s')
            return

        if exp_type == 'f64' and wlocals[node.id] == 'f32':
            yield wasm.Statement('local.get', '${}'.format(node.id))
            yield wasm.Statement('f64.promote_f32')
            return

        assert exp_type == wlocals[node.id]
        yield wasm.Statement(
            'local.get',
            '${}'.format(node.id),
        )


def _parse_import_decorator(func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]:
    """
    Parses the arguments of an @imported decorator

    Accepts one or two string constants: the import module, and optionally
    the imported name (defaults to func_name).

    Returns (module, name, func_name).
    """
    assert 0 < len(args) < 3
    str_args = [
        arg.value
        for arg in args
        if isinstance(arg, ast.Constant) and isinstance(arg.value, str)
    ]
    assert len(str_args) == len(args)

    module = str_args.pop(0)
    if str_args:
        name = str_args.pop(0)
    else:
        name = func_name

    return module, name, func_name


def _parse_import_args(args: ast.arguments) -> List[Tuple[str, str]]:
    """
    Parses the arguments for an @imported method

    Returns (name, wasm type) pairs; only plain positional arguments
    without defaults are supported.
    """
    assert not args.vararg
    assert not args.kwonlyargs
    assert not args.kw_defaults
    assert not args.kwarg
    assert not args.defaults # Maybe support this later on

    arg_list = [
        *args.posonlyargs,
        *args.args,
    ]

    return [
        (arg.arg, _parse_annotation(arg.annotation))
        for arg in arg_list
    ]


def _parse_annotation(ann: Optional[ast.expr]) -> str:
    """
    Parses a typing annotation

    The annotation must be one of the names i32, i64, f32 or f64,
    which is returned as a string.
    """
    assert ann is not None, 'Web Assembly requires type annotations'
    assert isinstance(ann, ast.Name)
    result = ann.id
    assert result in ['i32', 'i64', 'f32', 'f64'], result
    return result