diff --git a/compile.py b/compile.py index f2c6119..50a6168 100644 --- a/compile.py +++ b/compile.py @@ -2,149 +2,18 @@ import _ast import ast import sys -class Import: - def __init__(self, module, name, intname): - self.module = module - self.name = name - self.intname = intname - self.params = None - - def generate(self): - return '(import "{}" "{}" (func ${}{}))'.format( - self.module, - self.name, - self.intname, - ''.join(' (param {})'.format(x) for x in self.params) - ) - -class Statement: - def __init__(self, name, *args): - self.name = name - self.args = args - - def generate(self): - return '{} {}'.format(self.name, ' '.join(self.args)) - -class Function: - def __init__(self, name, exported=True): - self.name = name - self.exported = exported # TODO: Use __all__! - self.statements = [] - - def generate(self): - return '(func {}\n {})'.format( - ('(export "{}")' if self.exported else '${}').format(self.name), - '\n '.join(x.generate() for x in self.statements), - ) - -class Visitor(ast.NodeVisitor): - def __init__(self): - self._stack = [] - self.imports = [] - self.functions = [] - - def visit_ImportFrom(self, node): - for alias in node.names: - self.imports.append(Import( - node.module, - alias.name, - alias.asname, - )) - - def visit_FunctionDef(self, node): - if node.decorator_list: - # TODO: Support normal decorators - assert 1 == len(node.decorator_list) - call = node.decorator_list[0] - assert 'external' == call.func.id - assert 1 == len(call.args) - assert isinstance(call.args[0].value, str) - - import_ = Import( - call.args[0].value, - node.name, - node.name, - ) - - import_.params = [ - arg.annotation.id - for arg in node.args.args - ] - - self.imports.append(import_) - return - - func = Function( - node.name, - ) - - self._stack.append(func) - self.generic_visit(node) - self._stack.pop() - - self.functions.append(func) - - def visit_Expr(self, node): - self.generic_visit(node) - - def visit_Call(self, node): - self.generic_visit(node) - - func = self._stack[-1] - func.statements.append( - Statement('call', '$' + node.func.id) - ) - - def visit_BinOp(self, node): - self.generic_visit(node) - - func = self._stack[-1] - - if 'Add' == node.op.__class__.__name__: - func.statements.append( - Statement('i32.add') - ) - elif 'Mult' == node.op.__class__.__name__: - func.statements.append( - Statement('i32.mul') - ) - else: - err(node.op) - - def visit_Constant(self, node): - if not self._stack: - # Constant outside of any function - imp = self.imports[-1] - prefix = imp.name + '(' - val = node.value.strip() - - if val.startswith(prefix) and val.endswith(')'): - imp.params = val[len(prefix):-1].split(',') - else: - func = self._stack[-1] - if isinstance(node.value, int): - func.statements.append( - Statement('i32.const', str(node.value)) - ) - - self.generic_visit(node) - - def generate(self): - return '(module\n {}\n {})'.format( - '\n '.join(x.generate() for x in self.imports), - '\n '.join(x.generate() for x in self.functions), - ) - -def err(msg: str) -> None: - sys.stderr.write('{}\n'.format(msg)) +from py2wasm.wasm import Module +from py2wasm.python import Visitor def process(input: str, input_name: str) -> str: res = ast.parse(input, input_name) - visitor = Visitor() + module = Module() + + visitor = Visitor(module) visitor.visit(res) - return visitor.generate() + return module.generate() def main(source: str, sink: str) -> int: with open(source, 'r') as fil: diff --git a/py2wasm/__init__.py b/py2wasm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/py2wasm/python.py b/py2wasm/python.py new file mode 100644 index 0000000..6be9b1e --- /dev/null +++ b/py2wasm/python.py @@ -0,0 +1,94 @@ +import ast + +from .wasm import Function, Import, Module, Statement + +class Visitor(ast.NodeVisitor): + def __init__(self, module: Module) -> None: + self._stack = [] + self.module = module + + # def visit_ImportFrom(self, node): + # for alias in node.names: + # self.imports.append(Import( + # node.module, + # alias.name, + # alias.asname, + # )) + + def visit_FunctionDef(self, node): + if node.decorator_list: + # TODO: Support normal decorators + assert 1 == len(node.decorator_list) + call = node.decorator_list[0] + assert 'external' == call.func.id + assert 1 == len(call.args) + assert isinstance(call.args[0].value, str) + + import_ = Import( + call.args[0].value, + node.name, + node.name, + ) + + import_.params = [ + arg.annotation.id + for arg in node.args.args + ] + + self.module.imports.append(import_) + return + + func = Function( + node.name, + ) + + self._stack.append(func) + self.generic_visit(node) + self._stack.pop() + + self.module.functions.append(func) + + def visit_Expr(self, node): + self.generic_visit(node) + + def visit_Call(self, node): + self.generic_visit(node) + + func = self._stack[-1] + func.statements.append( + Statement('call', '$' + node.func.id) + ) + + def visit_BinOp(self, node): + self.generic_visit(node) + + func = self._stack[-1] + + if 'Add' == node.op.__class__.__name__: + func.statements.append( + Statement('i32.add') + ) + elif 'Mult' == node.op.__class__.__name__: + func.statements.append( + Statement('i32.mul') + ) + else: + raise NotImplementedError + + def visit_Constant(self, node): + if not self._stack: + # Constant outside of any function + imp = self.imports[-1] + prefix = imp.name + '(' + val = node.value.strip() + + if val.startswith(prefix) and val.endswith(')'): + imp.params = val[len(prefix):-1].split(',') + else: + func = self._stack[-1] + if isinstance(node.value, int): + func.statements.append( + Statement('i32.const', str(node.value)) + ) + + self.generic_visit(node) diff --git a/py2wasm/wasm.py b/py2wasm/wasm.py new file mode 100644 index 0000000..f0a3b32 --- /dev/null +++ b/py2wasm/wasm.py @@ -0,0 +1,45 @@ +class Import: + def __init__(self, module, name, intname): + self.module = module + self.name = name + self.intname = intname + self.params = None + + def generate(self): + return '(import "{}" "{}" (func ${}{}))'.format( + self.module, + self.name, + self.intname, + ''.join(' (param {})'.format(x) for x in self.params) + ) + +class Statement: + def __init__(self, name, *args): + self.name = name + self.args = args + + def generate(self): + return '{} {}'.format(self.name, ' '.join(self.args)) + +class Function: + def __init__(self, name, exported=True): + self.name = name + self.exported = exported # TODO: Use __all__! + self.statements = [] + + def generate(self): + return '(func {}\n {})'.format( + ('(export "{}")' if self.exported else '${}').format(self.name), + '\n '.join(x.generate() for x in self.statements), + ) + +class Module: + def __init__(self, imports = None, functions = None) -> None: + self.imports = imports or [] + self.functions = functions or [] + + def generate(self): + return '(module\n {}\n {})'.format( + '\n '.join(x.generate() for x in self.imports), + '\n '.join(x.generate() for x in self.functions), + )