""" Code for parsing the source (python-alike) code """

from typing import Dict, Generator, List, Optional, Union, Tuple

import ast

from . import wasm

# pylint: disable=C0103,R0201

StatementGenerator = Generator[wasm.Statement, None, None]

WLocals = Dict[str, str]


class Visitor:
    """
    Class to visit a Python syntax tree

    Since we need to visit the whole tree, there's no point in subclassing
    the built-in visitor
    """

    def visit_Module(self, node: ast.Module) -> wasm.Module:
        """
        Visits a Python module, which results in a wasm Module

        Only top-level function definitions are supported; any other
        top-level statement raises NotImplementedError.

        Functions are added to the module in source order, and calls are
        resolved against what has been added so far, so a function can only
        call functions (or imports) defined above it.
        """
        assert not node.type_ignores

        module = wasm.Module()

        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                wnode = self.visit_FunctionDef(module, stmt)

                if isinstance(wnode, wasm.Import):
                    module.imports.append(wnode)
                else:
                    module.functions.append(wnode)
                continue

            raise NotImplementedError(stmt)

        return module

    def visit_FunctionDef(
            self,
            module: wasm.Module,
            node: ast.FunctionDef,
    ) -> Union[wasm.Import, wasm.Function]:
        """
        Visits a top-level Python function definition

        - With an @imported(module[, name]) decorator it is returned as a
          wasm Import; its body must consist of a single `...`.
        - With the @exported decorator it is returned as an exported
          wasm Function.
        - Without a decorator it is returned as a non-exported wasm Function.

        Nested / dynamically created functions are not yet supported
        """
        exported = False

        if node.decorator_list:
            assert 1 == len(node.decorator_list)

            decorator = node.decorator_list[0]
            if isinstance(decorator, ast.Name):
                assert 'exported' == decorator.id
                exported = True
            elif isinstance(decorator, ast.Call):
                call = decorator

                # An imported function only declares a signature
                assert not node.type_comment
                assert 1 == len(node.body)
                assert isinstance(node.body[0], ast.Expr)
                assert isinstance(node.body[0].value, ast.Ellipsis)

                assert isinstance(call.func, ast.Name)
                assert 'imported' == call.func.id
                assert not call.keywords

                mod, name, intname = _parse_import_decorator(node.name, call.args)

                return wasm.Import(mod, name, intname, _parse_import_args(node.args))
            else:
                raise NotImplementedError

        # A missing return annotation means the function returns nothing;
        # it is recorded as None here.
        result = None if node.returns is None else _parse_annotation(node.returns)

        # Same validation and (name, type) list as for imported functions
        params = _parse_import_args(node.args)

        # Parameters are the only locals for now
        wlocals: WLocals = dict(params)

        statements: List[wasm.Statement] = []
        for py_stmt in node.body:
            statements.extend(
                self.visit_stmt(module, node, wlocals, py_stmt)
            )

        return wasm.Function(node.name, exported, params, result, statements)

    def visit_stmt(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.stmt,
    ) -> StatementGenerator:
        """
        Visits a statement node within a function

        Supported are `return <expr>` and bare call expressions; the latter
        must call a function that returns nothing.
        """
        if isinstance(stmt, ast.Return):
            return self.visit_Return(module, func, wlocals, stmt)

        if isinstance(stmt, ast.Expr):
            assert isinstance(stmt.value, ast.Call)

            return self.visit_Call(module, wlocals, "None", stmt.value)

        raise NotImplementedError(stmt)

    def visit_Return(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.Return,
    ) -> StatementGenerator:
        """
        Visits a Return statement node

        The returned expression is evaluated with the type from the
        enclosing function's return annotation; a bare `return` is rejected.
        """
        assert stmt.value is not None

        return_type = _parse_annotation(func.returns)

        return self.visit_expr(module, wlocals, return_type, stmt.value)

    def visit_expr(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.expr,
    ) -> StatementGenerator:
        """
        Visit an expression node

        exp_type is the wasm type the expression must produce.
        """
        if isinstance(node, ast.BinOp):
            return self.visit_BinOp(module, wlocals, exp_type, node)

        if isinstance(node, ast.Call):
            return self.visit_Call(module, wlocals, exp_type, node)

        if isinstance(node, ast.Constant):
            return self.visit_Constant(exp_type, node)

        if isinstance(node, ast.Name):
            return self.visit_Name(wlocals, exp_type, node)

        raise NotImplementedError(node)

    def visit_BinOp(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.BinOp,
    ) -> StatementGenerator:
        """
        Visits a BinOp node as (part of) an expression

        Both operands are evaluated with the expected type, pushed left
        then right, followed by the typed operator instruction.
        """
        yield from self.visit_expr(module, wlocals, exp_type, node.left)
        yield from self.visit_expr(module, wlocals, exp_type, node.right)

        if isinstance(node.op, ast.Add):
            yield wasm.Statement('{}.add'.format(exp_type))
            return

        raise NotImplementedError(node.op)

    def visit_Call(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.Call,
    ) -> StatementGenerator:
        """
        Visits a Call node as (part of) an expression

        The callee must be a function or import already added to the module.
        Its result type must equal exp_type ('None' for a statement-level
        call). Each argument is evaluated with the type of its parameter,
        followed by a `call` instruction.
        """
        assert isinstance(node.func, ast.Name)
        assert not node.keywords

        called_name = node.func.id

        called_func_list = [
            x
            for x in module.functions
            if x.name == called_name
        ]
        if called_func_list:
            assert len(called_func_list) == 1
            called_params = called_func_list[0].params
            called_result = called_func_list[0].result
        else:
            # NOTE(review): imports are matched on `.name`, which may be the
            # external name from the decorator rather than the Python-side
            # name used in the emitted `call` below. Confirm how wasm.Import
            # stores the internal name.
            called_import_list = [
                x
                for x in module.imports
                if x.name == called_name
            ]
            if not called_import_list:
                # Raised explicitly rather than via a failing assert, so the
                # error does not turn into an unbound local under `python -O`
                raise AssertionError(
                    'Could not find function {}'.format(called_name)
                )

            assert len(called_import_list) == 1
            called_params = called_import_list[0].params
            called_result = called_import_list[0].result

        if called_result is None:
            # visit_FunctionDef records "no return annotation" as None,
            # while statement-level calls expect the type string 'None'
            called_result = 'None'

        assert exp_type == called_result

        assert len(called_params) == len(node.args), \
            '{}:{} Function {} requires {} arguments, but {} are supplied'.format(
                node.lineno, node.col_offset,
                called_name, len(called_params), len(node.args),
            )

        for (_, exp_expr_type), expr in zip(called_params, node.args):
            yield from self.visit_expr(module, wlocals, exp_expr_type, expr)

        yield wasm.Statement(
            'call',
            '${}'.format(called_name),
        )

    def visit_Constant(self, exp_type: str, node: ast.Constant) -> StatementGenerator:
        """
        Visits a Constant node as (part of) an expression

        Only i32 integer constants within the signed 32-bit range are
        supported.
        """
        if 'i32' == exp_type:
            assert isinstance(node.value, int)
            # bool is a subclass of int; str(True) would emit the invalid
            # instruction `i32.const True`
            assert not isinstance(node.value, bool)
            assert -2147483648 <= node.value <= 2147483647

            yield wasm.Statement(
                '{}.const'.format(exp_type),
                str(node.value)
            )
            return

        raise NotImplementedError(exp_type)

    def visit_Name(self, wlocals: WLocals, exp_type: str, node: ast.Name) -> StatementGenerator:
        """
        Visits a Name node as (part of) an expression

        The name must be a known local of the expected type.
        """
        assert node.id in wlocals
        assert exp_type == wlocals[node.id]

        yield wasm.Statement(
            'local.get',
            '${}'.format(node.id),
        )


def _parse_import_decorator(func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]:
    """
    Parses the arguments of an @imported decorator

    Accepts one or two string constants: the module name and, optionally,
    the external name (which defaults to func_name).

    Returns (module, external name, internal name); the internal name is
    always func_name.
    """
    assert 0 < len(args) < 3

    str_args = [
        arg.value
        for arg in args
        if isinstance(arg, ast.Constant) and isinstance(arg.value, str)
    ]
    # Every argument must be a string constant
    assert len(str_args) == len(args)

    module = str_args.pop(0)
    if str_args:
        name = str_args.pop(0)
    else:
        name = func_name

    return module, name, func_name


def _parse_import_args(args: ast.arguments) -> List[Tuple[str, str]]:
    """
    Parses the arguments of a function definition

    Used for both @imported and regular functions. Only plain (positional
    or positional-only) annotated arguments are supported.

    Returns a list of (argument name, wasm type) tuples in order.
    """
    assert not args.vararg
    assert not args.kwonlyargs
    assert not args.kw_defaults
    assert not args.kwarg
    assert not args.defaults # Maybe support this later on

    arg_list = [
        *args.posonlyargs,
        *args.args,
    ]

    return [
        (arg.arg, _parse_annotation(arg.annotation))
        for arg in arg_list
    ]


def _parse_annotation(ann: Optional[ast.expr]) -> str:
    """
    Parses a typing annotation

    Only the bare name `i32` is currently supported.
    """
    assert ann is not None, 'Web Assembly requires type annotations'
    assert isinstance(ann, ast.Name)
    result = ann.id
    assert result in ['i32']
    return result