"""
Code for parsing the source (python-alike) code
"""

from typing import Dict, Generator, List, Optional, Union, Tuple

import ast

from . import wasm

# pylint: disable=C0103,R0201

StatementGenerator = Generator[wasm.Statement, None, None]

WLocals = Dict[str, wasm.OurType]

# Up to Python 3.8 a subscript index is wrapped in ast.Index; from 3.9 on
# the index expression is stored directly on Subscript.slice, and ast.Index
# is deprecated. getattr keeps this working should the class disappear.
_AST_INDEX = getattr(ast, 'Index', None)


def _subscript_index(node: ast.Subscript) -> ast.AST:
    """
    Returns the index expression of a subscript, e.g. the `1` in `x[1]`

    Unwraps the ast.Index node that Python 3.8 puts around it, so callers
    behave the same on Python 3.8 and on Python 3.9+.
    """
    slice_ = node.slice
    if _AST_INDEX is not None and isinstance(slice_, _AST_INDEX):
        return slice_.value  # type: ignore
    return slice_


class Visitor:
    """
    Class to visit a Python syntax tree

    Since we need to visit the whole tree, there's no point in subclassing
    the builtin visitor
    """
    def __init__(self) -> None:
        # Maps type names, as used in annotations, to our types.
        # visit_Module adds the classes defined in the module to this map.
        self._type_map: WLocals = {
            'None': wasm.OurTypeNone(),

            'bool': wasm.OurTypeBool(),

            'i32': wasm.OurTypeInt32(),
            'i64': wasm.OurTypeInt64(),
            'f32': wasm.OurTypeFloat32(),
            'f64': wasm.OurTypeFloat64(),
        }

    def _get_type(self, name: Union[None, str, ast.expr]) -> wasm.OurType:
        """
        Resolves a type annotation to one of our types

        `name` may be None (no annotation; resolves to the None type), a
        type name as string, an ast.Name, or a `Tuple[...]` subscript. The
        latter results in a new tuple type whose members are laid out
        consecutively in declaration order.

        Raises NotImplementedError for any other annotation node, and
        KeyError for an unknown type name.
        """
        if name is None:
            name = 'None'

        if isinstance(name, ast.expr):
            if isinstance(name, ast.Name):
                name = name.id
            elif (isinstance(name, ast.Subscript)
                  and isinstance(name.value, ast.Name)
                  and name.value.id == 'Tuple'):
                assert isinstance(name.ctx, ast.Load)
                slice_value = _subscript_index(name)
                assert isinstance(slice_value, ast.Tuple)

                args: List[wasm.TupleMember] = []
                offset = 0
                for name_arg in slice_value.elts:
                    arg = wasm.TupleMember(
                        self._get_type(name_arg),
                        offset
                    )

                    args.append(arg)
                    offset += arg.type.alloc_size()

                return wasm.OurTypeTuple(args)
            else:
                raise NotImplementedError(f'_get_type(ast.{type(name).__name__})')

        return self._type_map[name]

    def visit_Module(self, node: ast.Module) -> wasm.Module:
        """
        Visits a Python module, which results in a wasm Module

        Besides the functions from the source, the module always contains
        two builtin functions: ___new_reference___ and sqrt.

        Only function and class definitions are supported at module level;
        anything else raises NotImplementedError.
        """
        assert not node.type_ignores

        module = wasm.Module()

        # Allocates alloc_size bytes of memory and returns their address.
        # The i32 stored at memory address 0 holds the address for the next
        # allocation; it is read, returned, and advanced by alloc_size.
        # There is no matching free yet.
        module.functions.append(wasm.Function(
            '___new_reference___',
            False,
            [
                ('alloc_size', self._get_type('i32')),
            ],
            [
                ('result', self._get_type('i32')),
            ],
            self._get_type('i32'),
            [
                wasm.Statement('i32.const', '0'),
                wasm.Statement('i32.const', '0'),
                wasm.Statement('i32.load'),
                wasm.Statement('local.tee', '$result', comment='Address for this call'),
                wasm.Statement('local.get', '$alloc_size'),
                wasm.Statement('i32.add'),
                wasm.Statement('i32.store', comment='Address for the next call'),
                wasm.Statement('local.get', '$result'),
            ]
        ))

        # We create functions for these special instructions
        # We'll let a future inline optimizer clean this up for us
        # TODO: Either a) make a sqrt32 and a sqrt64;
        #           or b) decide to make the whole language auto-cast
        module.functions.append(wasm.Function(
            'sqrt',
            False,
            [
                ('z', self._type_map['f32']),
            ],
            [],
            self._type_map['f32'],
            [
                wasm.Statement('local.get', '$z'),
                wasm.Statement('f32.sqrt'),
            ]
        ))

        # Register all class types first, so that function signatures can
        # use a class that is defined further down in the file
        for stmt in node.body:
            if isinstance(stmt, ast.ClassDef):
                wclass = self.pre_visit_ClassDef(module, stmt)
                self._type_map[wclass.name] = wclass

        # Do a check first for all function definitions
        # to get their types. Otherwise you cannot call
        # a method that you haven't defined just yet,
        # even if it is in the same file
        function_body_map: Dict[ast.FunctionDef, wasm.Function] = {}
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                wnode = self.pre_visit_FunctionDef(module, stmt)

                if isinstance(wnode, wasm.Import):
                    module.imports.append(wnode)
                else:
                    module.functions.append(wnode)
                    function_body_map[stmt] = wnode

        # No other pre visits to do

        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                if stmt in function_body_map:
                    self.parse_FunctionDef_body(module, function_body_map[stmt], stmt)
                # else: It's an import, no actual body to parse
                continue

            if isinstance(stmt, ast.ClassDef):
                continue

            raise NotImplementedError(stmt)

        return module

    def pre_visit_FunctionDef(
            self,
            module: wasm.Module,
            node: ast.FunctionDef,
    ) -> Union[wasm.Import, wasm.Function]:
        """
        Determines the signature of a Python function definition

        A function with the @imported(module[, name]) decorator is returned
        as import; its body must consist of a single `...`.

        A function with the @exported decorator is returned as an exported
        function; other functions are returned as non-exported functions.
        The statements of a returned function are left empty; they are
        filled in by parse_FunctionDef_body.

        Nested / dynamically created functions are not yet supported
        """
        exported = False

        if node.decorator_list:
            assert 1 == len(node.decorator_list)

            decorator = node.decorator_list[0]
            if isinstance(decorator, ast.Name):
                assert 'exported' == decorator.id
                exported = True
            elif isinstance(decorator, ast.Call):
                call = decorator

                assert not node.type_comment
                assert 1 == len(node.body)
                assert isinstance(node.body[0], ast.Expr)
                # The body must be `...`. Since Python 3.8 that parses as an
                # ast.Constant; ast.Ellipsis is deprecated since 3.8 and was
                # removed in Python 3.14.
                body_value = node.body[0].value
                assert isinstance(body_value, ast.Constant)
                assert body_value.value is Ellipsis

                assert isinstance(call.func, ast.Name)
                assert 'imported' == call.func.id

                assert not call.keywords

                mod, name, intname = self._parse_import_decorator(node.name, call.args)

                return wasm.Import(mod, name, intname, self._parse_import_args(node.args))
            else:
                raise NotImplementedError

        result = None if node.returns is None else self._get_type(node.returns)

        assert not node.args.vararg
        assert not node.args.kwonlyargs
        assert not node.args.kw_defaults
        assert not node.args.kwarg
        assert not node.args.defaults

        params: List[wasm.Param] = []
        for arg in [*node.args.posonlyargs, *node.args.args]:
            params.append((arg.arg, self._get_type(arg.annotation), ))

        locals_: List[wasm.Param] = [
            # Scratch local that holds the address returned by
            # ___new_reference___ while a class instance or tuple is built
            ('___new_reference___addr', self._get_type('i32')),
        ]

        return wasm.Function(node.name, exported, params, locals_, result, [])

    def parse_FunctionDef_body(
            self,
            module: wasm.Module,
            func: wasm.Function,
            node: ast.FunctionDef,
    ) -> None:
        """
        Parses the function body and stores the result in func.statements

        The names visible to the body are the function's parameters.
        """
        wlocals: WLocals = dict(func.params)

        statements: List[wasm.Statement] = []
        for py_stmt in node.body:
            statements.extend(
                self.visit_stmt(module, node, wlocals, py_stmt)
            )

        func.statements = statements

    def pre_visit_ClassDef(
            self,
            module: wasm.Module,
            node: ast.ClassDef,
    ) -> wasm.OurTypeClass:
        """
        Converts a class definition into a class type

        Only plain `name: i32` members are supported, optionally with an
        int constant as default. Members are laid out consecutively in
        declaration order. Bases, keywords and decorators are not supported.
        """
        del module

        if node.bases or node.keywords or node.decorator_list:
            raise NotImplementedError

        members: List[wasm.ClassMember] = []
        offset = 0

        for stmt in node.body:
            if not isinstance(stmt, ast.AnnAssign):
                raise NotImplementedError

            if not isinstance(stmt.target, ast.Name):
                raise NotImplementedError

            if not isinstance(stmt.annotation, ast.Name):
                raise NotImplementedError('Cannot recurse classes yet')

            if stmt.annotation.id != 'i32':
                raise NotImplementedError

            if stmt.value is None:
                default = None
            else:
                if not isinstance(stmt.value, ast.Constant):
                    raise NotImplementedError

                if not isinstance(stmt.value.value, int):
                    raise NotImplementedError

                default = wasm.Constant(stmt.value.value)

            member = wasm.ClassMember(
                stmt.target.id, self._get_type(stmt.annotation), offset, default
            )

            members.append(member)
            offset += member.type.alloc_size()

        return wasm.OurTypeClass(node.name, members)

    def visit_stmt(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.stmt,
    ) -> StatementGenerator:
        """
        Visits a statement node within a function

        Supported are return, if and bare calls (whose result type must be
        None); anything else raises NotImplementedError.
        """
        if isinstance(stmt, ast.Return):
            return self.visit_Return(module, func, wlocals, stmt)

        if isinstance(stmt, ast.Expr):
            assert isinstance(stmt.value, ast.Call)

            return self.visit_Call(module, wlocals, self._get_type('None'), stmt.value)

        if isinstance(stmt, ast.If):
            return self.visit_If(module, func, wlocals, stmt)

        raise NotImplementedError(stmt)

    def visit_Return(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.Return,
    ) -> StatementGenerator:
        """
        Visits a Return node

        The returned expression is compiled against the function's declared
        return type. A bare `return` without value is not supported.
        """
        assert stmt.value is not None

        return_type = self._get_type(func.returns)

        yield from self.visit_expr(module, wlocals, return_type, stmt.value)
        yield wasm.Statement('return')

    def visit_If(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.If,
    ) -> StatementGenerator:
        """
        Visits an If node

        The test is compiled as bool expression; the body and the (possibly
        empty) else branch become the if / else parts of a wasm if block.
        """
        yield from self.visit_expr(
            module,
            wlocals,
            self._get_type('bool'),
            stmt.test,
        )

        yield wasm.Statement('if')

        for py_stmt in stmt.body:
            yield from self.visit_stmt(module, func, wlocals, py_stmt)

        yield wasm.Statement('else')

        for py_stmt in stmt.orelse:
            yield from self.visit_stmt(module, func, wlocals, py_stmt)

        yield wasm.Statement('end')

    def visit_expr(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.expr,
    ) -> StatementGenerator:
        """
        Visit an expression node

        Dispatches on the node type; exp_type is the type the surrounding
        code expects the expression to leave on the stack.
        """
        if isinstance(node, ast.BinOp):
            return self.visit_BinOp(module, wlocals, exp_type, node)

        if isinstance(node, ast.UnaryOp):
            return self.visit_UnaryOp(module, wlocals, exp_type, node)

        if isinstance(node, ast.Compare):
            return self.visit_Compare(module, wlocals, exp_type, node)

        if isinstance(node, ast.Call):
            return self.visit_Call(module, wlocals, exp_type, node)

        if isinstance(node, ast.Constant):
            return self.visit_Constant(exp_type, node)

        if isinstance(node, ast.Attribute):
            return self.visit_Attribute(module, wlocals, exp_type, node)

        if isinstance(node, ast.Subscript):
            return self.visit_Subscript(module, wlocals, exp_type, node)

        if isinstance(node, ast.Name):
            return self.visit_Name(wlocals, exp_type, node)

        if isinstance(node, ast.Tuple):
            return self.visit_Tuple(wlocals, exp_type, node)

        raise NotImplementedError(node)

    def visit_UnaryOp(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.UnaryOp,
    ) -> StatementGenerator:
        """
        Visits a UnaryOp node as (part of) an expression

        Unary plus is a no-op. Unary minus is only supported on a numeric
        constant, which is folded into a negative constant.
        """
        if isinstance(node.op, ast.UAdd):
            return self.visit_expr(module, wlocals, exp_type, node.operand)

        if isinstance(node.op, ast.USub):
            if not isinstance(node.operand, ast.Constant):
                raise NotImplementedError(node.operand)

            if not isinstance(node.operand.value, (int, float)):
                raise NotImplementedError(node.operand.value)

            return self.visit_Constant(exp_type, ast.Constant(-node.operand.value))

        raise NotImplementedError(node.op)

    def visit_BinOp(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.BinOp,
    ) -> StatementGenerator:
        """
        Visits a BinOp node as (part of) an expression

        Both operands are compiled with the expected type of the whole
        expression; supported are +, * and -.
        """
        yield from self.visit_expr(module, wlocals, exp_type, node.left)
        yield from self.visit_expr(module, wlocals, exp_type, node.right)

        if isinstance(node.op, ast.Add):
            yield wasm.Statement('{}.add'.format(exp_type.to_wasm()))
            return

        if isinstance(node.op, ast.Mult):
            yield wasm.Statement('{}.mul'.format(exp_type.to_wasm()))
            return

        if isinstance(node.op, ast.Sub):
            yield wasm.Statement('{}.sub'.format(exp_type.to_wasm()))
            return

        raise NotImplementedError(node.op)

    def visit_Compare(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.Compare,
    ) -> StatementGenerator:
        """
        Visits a Compare node as (part of) an expression

        Only a single <, == or > between two i32 operands is supported;
        chained comparisons raise NotImplementedError.
        """
        assert isinstance(exp_type, wasm.OurTypeBool)

        if 1 != len(node.ops) or 1 != len(node.comparators):
            raise NotImplementedError

        yield from self.visit_expr(module, wlocals, self._get_type('i32'), node.left)
        yield from self.visit_expr(module, wlocals, self._get_type('i32'), node.comparators[0])

        if isinstance(node.ops[0], ast.Lt):
            yield wasm.Statement('i32.lt_s')
            return

        if isinstance(node.ops[0], ast.Eq):
            yield wasm.Statement('i32.eq')
            return

        if isinstance(node.ops[0], ast.Gt):
            yield wasm.Statement('i32.gt_s')
            return

        raise NotImplementedError(node.ops)

    def visit_Call(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.Call,
    ) -> StatementGenerator:
        """
        Visits a Call node as (part of) an expression

        Calling a class name instantiates that class; any other name must
        match exactly one function or import of the module.
        """
        assert isinstance(node.func, ast.Name)
        assert not node.keywords

        called_name = node.func.id

        if called_name in self._type_map:
            klass = self._type_map[called_name]
            if isinstance(klass, wasm.OurTypeClass):
                return self.visit_Call_class(module, wlocals, exp_type, node, klass)

        search_list: List[Union[wasm.Function, wasm.Import]]
        search_list = [
            *module.functions,
            *module.imports,
        ]

        called_func_list = [
            x
            for x in search_list
            if x.name == called_name
        ]

        assert 1 == len(called_func_list), \
            'Could not find function {}'.format(node.func.id)

        return self.visit_Call_func(module, wlocals, exp_type, node, called_func_list[0])

    def visit_Call_class(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.Call,
            cls: wasm.OurTypeClass,
    ) -> StatementGenerator:
        """
        Visits a Call node as (part of) an expression

        This instantiates the class: memory for it is allocated through
        ___new_reference___, the positional arguments (constants only) are
        stored into the members in declaration order, and the address of
        the new instance is left on the stack.

        Members without a matching argument are not written here; their
        declared default is not applied (yet).
        """
        assert len(node.args) <= len(cls.members), \
            f'{cls.name} has {len(cls.members)} members, but {len(node.args)} arguments are supplied'

        # TODO: free call

        yield wasm.Statement('i32.const', str(cls.alloc_size()))
        yield wasm.Statement('call', '$___new_reference___')

        yield wasm.Statement('local.set', '$___new_reference___addr')

        for member, arg in zip(cls.members, node.args):
            if not isinstance(arg, ast.Constant):
                raise NotImplementedError

            yield wasm.Statement('local.get', '$___new_reference___addr')
            yield wasm.Statement(f'{member.type.to_wasm()}.const', str(arg.value))
            yield wasm.Statement(f'{member.type.to_wasm()}.store', 'offset=' + str(member.offset))

        yield wasm.Statement('local.get', '$___new_reference___addr')

    def visit_Call_func(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.Call,
            func: Union[wasm.Function, wasm.Import],
    ) -> StatementGenerator:
        """
        Visits a Call node as (part of) an expression

        The function's result type must equal exp_type, and the number of
        arguments must match its parameters. Each argument is compiled
        against the type of its parameter.
        """
        assert isinstance(node.func, ast.Name)

        called_name = node.func.id
        called_params = func.params
        called_result = func.result

        assert exp_type == called_result, 'Function does not match expected type'

        assert len(called_params) == len(node.args), \
            '{}:{} Function {} requires {} arguments, but {} are supplied'.format(
                node.lineno, node.col_offset,
                called_name, len(called_params), len(node.args),
            )

        for (_, exp_expr_type), expr in zip(called_params, node.args):
            yield from self.visit_expr(module, wlocals, exp_expr_type, expr)

        yield wasm.Statement(
            'call',
            '${}'.format(called_name),
        )

    def visit_Constant(self, exp_type: wasm.OurType, node: ast.Constant) -> StatementGenerator:
        """
        Visits a Constant node as (part of) an expression

        Integer types accept int constants within their range (True / False
        are written as 1 / 0). Float types accept int and float constants,
        which are written in hexadecimal float notation.
        """
        if isinstance(exp_type, wasm.OurTypeInt32):
            assert isinstance(node.value, int)
            assert -2147483648 <= node.value <= 2147483647

            # int() so that a bool constant is written as 1 / 0, not True / False
            yield wasm.Statement('i32.const', str(int(node.value)))
            return

        if isinstance(exp_type, wasm.OurTypeInt64):
            assert isinstance(node.value, int)
            assert -9223372036854775808 <= node.value <= 9223372036854775807

            yield wasm.Statement('i64.const', str(int(node.value)))
            return

        if isinstance(exp_type, wasm.OurTypeFloat32):
            assert isinstance(node.value, (int, float)) and not isinstance(node.value, bool)

            # TODO: Size check?

            yield wasm.Statement('f32.const', float(node.value).hex())
            return

        if isinstance(exp_type, wasm.OurTypeFloat64):
            assert isinstance(node.value, (int, float)) and not isinstance(node.value, bool)

            # TODO: Size check?

            yield wasm.Statement('f64.const', float(node.value).hex())
            return

        raise NotImplementedError(exp_type)

    def _load_member(
            self,
            local_name: str,
            member_type: wasm.OurType,
            offset: int,
            exp_type: wasm.OurType,
    ) -> StatementGenerator:
        """
        Loads a class or tuple member from memory

        The address of the instance is read from local `local_name`. The
        member is loaded using its own type, so that the right number of
        bytes is read, and is then widened to exp_type in the same way
        visit_Name widens locals (i32 to i64, f32 to f64).
        """
        yield wasm.Statement('local.get', '$' + local_name)
        yield wasm.Statement(member_type.to_wasm() + '.load', 'offset=' + str(offset))

        if (isinstance(exp_type, wasm.OurTypeInt64)
                and isinstance(member_type, wasm.OurTypeInt32)):
            yield wasm.Statement('i64.extend_i32_s')
        elif (isinstance(exp_type, wasm.OurTypeFloat64)
              and isinstance(member_type, wasm.OurTypeFloat32)):
            yield wasm.Statement('f64.promote_f32')

    def visit_Attribute(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.Attribute,
    ) -> StatementGenerator:
        """
        Visits an Attribute node as (part of) an expression

        Only reading a member of a class instance held in a local
        (`name.member`) is supported.
        """
        if not isinstance(node.value, ast.Name):
            raise NotImplementedError
        if not isinstance(node.ctx, ast.Load):
            raise NotImplementedError

        cls = wlocals[node.value.id]

        assert isinstance(cls, wasm.OurTypeClass), f'Cannot take property of {cls}'

        member_list = [
            x
            for x in cls.members
            if x.name == node.attr
        ]

        assert len(member_list) == 1
        member = member_list[0]

        yield from self._load_member(node.value.id, member.type, member.offset, exp_type)

    def visit_Subscript(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.Subscript,
    ) -> StatementGenerator:
        """
        Visits an Subscript node as (part of) an expression

        Only reading a tuple held in a local with a constant int index
        (`name[0]`) is supported.
        """
        if not isinstance(node.value, ast.Name):
            raise NotImplementedError

        index = _subscript_index(node)
        if not isinstance(index, ast.Constant):
            raise NotImplementedError
        if not isinstance(index.value, int):
            raise NotImplementedError
        if not isinstance(node.ctx, ast.Load):
            raise NotImplementedError

        tpl_idx = index.value

        tpl = wlocals[node.value.id]

        assert isinstance(tpl, wasm.OurTypeTuple), f'Cannot take index of {tpl}'

        assert 0 <= tpl_idx < len(tpl.members), \
            f'Tuple index out of bounds; {node.value.id} has {len(tpl.members)} members'

        member = tpl.members[tpl_idx]

        yield from self._load_member(node.value.id, member.type, member.offset, exp_type)

    def visit_Tuple(
            self,
            wlocals: WLocals,
            exp_type: wasm.OurType,
            node: ast.Tuple,
    ) -> StatementGenerator:
        """
        Visits an Tuple node as (part of) an expression

        Allocates memory for the tuple through ___new_reference___, stores
        the values (constants only) at their member offsets, and leaves the
        address of the tuple on the stack.
        """
        assert isinstance(exp_type, wasm.OurTypeTuple), 'Expression is not expecting a tuple'

        # TODO: free call

        tpl = exp_type

        assert len(node.elts) == len(tpl.members), \
            f'Tuple has {len(node.elts)} values, but {len(tpl.members)} are expected'

        yield wasm.Statement('nop', comment='Start tuple allocation')

        yield wasm.Statement('i32.const', str(tpl.alloc_size()))
        yield wasm.Statement('call', '$___new_reference___')

        yield wasm.Statement('local.set', '$___new_reference___addr', comment='Allocate for tuple')

        for member, arg in zip(tpl.members, node.elts):
            if not isinstance(arg, ast.Constant):
                raise NotImplementedError('TODO: Non-const tuple members')

            yield wasm.Statement('local.get', '$___new_reference___addr')
            yield wasm.Statement(f'{member.type.to_wasm()}.const', str(arg.value))
            yield wasm.Statement(f'{member.type.to_wasm()}.store', 'offset=' + str(member.offset),
                                 comment='Write tuple value to memory')

        yield wasm.Statement('local.get', '$___new_reference___addr', comment='Store tuple address on stack')

    def visit_Name(self, wlocals: WLocals, exp_type: wasm.OurType, node: ast.Name) -> StatementGenerator:
        """
        Visits a Name node as (part of) an expression

        The local must have exp_type, except that an i32 local is
        sign-extended where an i64 is expected, and an f32 local is
        promoted where an f64 is expected.
        """
        assert node.id in wlocals

        if (isinstance(exp_type, wasm.OurTypeInt64)
                and isinstance(wlocals[node.id], wasm.OurTypeInt32)):
            yield wasm.Statement('local.get', '${}'.format(node.id))
            yield wasm.Statement('i64.extend_i32_s')
            return

        if (isinstance(exp_type, wasm.OurTypeFloat64)
                and isinstance(wlocals[node.id], wasm.OurTypeFloat32)):
            yield wasm.Statement('local.get', '${}'.format(node.id))
            yield wasm.Statement('f64.promote_f32')
            return

        assert exp_type == wlocals[node.id], (exp_type, wlocals[node.id], )
        yield wasm.Statement(
            'local.get',
            '${}'.format(node.id),
        )

    def _parse_import_decorator(self, func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]:
        """
        Parses the arguments of an @imported(module[, name]) decorator

        Both arguments must be string constants. When name is omitted, the
        Python function name is used.

        Returns (module, name, func_name).
        """
        assert 0 < len(args) < 3

        str_args = [
            arg.value
            for arg in args
            if isinstance(arg, ast.Constant) and isinstance(arg.value, str)
        ]
        assert len(str_args) == len(args)

        module = str_args.pop(0)
        if str_args:
            name = str_args.pop(0)
        else:
            name = func_name

        return module, name, func_name

    def _parse_import_args(self, args: ast.arguments) -> List[wasm.Param]:
        """
        Parses the arguments for an @imported method

        Only plain positional arguments are supported; each is returned
        as (name, type) pair.
        """
        assert not args.vararg
        assert not args.kwonlyargs
        assert not args.kw_defaults
        assert not args.kwarg
        assert not args.defaults # Maybe support this later on

        arg_list = [
            *args.posonlyargs,
            *args.args,
        ]

        return [
            (arg.arg, self._get_type(arg.annotation))
            for arg in arg_list
        ]