MVP #1

Merged
jbwdevries merged 73 commits from idea_crc32 into master 2022-08-21 12:59:21 +00:00
4 changed files with 145 additions and 137 deletions
Showing only changes of commit edd12e5b7c - Show all commits

View File

@ -2,149 +2,18 @@ import _ast
import ast
import sys
class Import:
def __init__(self, module, name, intname):
self.module = module
self.name = name
self.intname = intname
self.params = None
def generate(self):
return '(import "{}" "{}" (func ${}{}))'.format(
self.module,
self.name,
self.intname,
''.join(' (param {})'.format(x) for x in self.params)
)
class Statement:
def __init__(self, name, *args):
self.name = name
self.args = args
def generate(self):
return '{} {}'.format(self.name, ' '.join(self.args))
class Function:
def __init__(self, name, exported=True):
self.name = name
self.exported = exported # TODO: Use __all__!
self.statements = []
def generate(self):
return '(func {}\n {})'.format(
('(export "{}")' if self.exported else '${}').format(self.name),
'\n '.join(x.generate() for x in self.statements),
)
class Visitor(ast.NodeVisitor):
def __init__(self):
self._stack = []
self.imports = []
self.functions = []
def visit_ImportFrom(self, node):
for alias in node.names:
self.imports.append(Import(
node.module,
alias.name,
alias.asname,
))
def visit_FunctionDef(self, node):
if node.decorator_list:
# TODO: Support normal decorators
assert 1 == len(node.decorator_list)
call = node.decorator_list[0]
assert 'external' == call.func.id
assert 1 == len(call.args)
assert isinstance(call.args[0].value, str)
import_ = Import(
call.args[0].value,
node.name,
node.name,
)
import_.params = [
arg.annotation.id
for arg in node.args.args
]
self.imports.append(import_)
return
func = Function(
node.name,
)
self._stack.append(func)
self.generic_visit(node)
self._stack.pop()
self.functions.append(func)
def visit_Expr(self, node):
self.generic_visit(node)
def visit_Call(self, node):
self.generic_visit(node)
func = self._stack[-1]
func.statements.append(
Statement('call', '$' + node.func.id)
)
def visit_BinOp(self, node):
self.generic_visit(node)
func = self._stack[-1]
if 'Add' == node.op.__class__.__name__:
func.statements.append(
Statement('i32.add')
)
elif 'Mult' == node.op.__class__.__name__:
func.statements.append(
Statement('i32.mul')
)
else:
err(node.op)
def visit_Constant(self, node):
if not self._stack:
# Constant outside of any function
imp = self.imports[-1]
prefix = imp.name + '('
val = node.value.strip()
if val.startswith(prefix) and val.endswith(')'):
imp.params = val[len(prefix):-1].split(',')
else:
func = self._stack[-1]
if isinstance(node.value, int):
func.statements.append(
Statement('i32.const', str(node.value))
)
self.generic_visit(node)
def generate(self):
return '(module\n {}\n {})'.format(
'\n '.join(x.generate() for x in self.imports),
'\n '.join(x.generate() for x in self.functions),
)
def err(msg: str) -> None:
sys.stderr.write('{}\n'.format(msg))
from py2wasm.wasm import Module
from py2wasm.python import Visitor
def process(input: str, input_name: str) -> str:
res = ast.parse(input, input_name)
visitor = Visitor()
module = Module()
visitor = Visitor(module)
visitor.visit(res)
return visitor.generate()
return module.generate()
def main(source: str, sink: str) -> int:
with open(source, 'r') as fil:

0
py2wasm/__init__.py Normal file
View File

94
py2wasm/python.py Normal file
View File

@ -0,0 +1,94 @@
import ast
from .wasm import Function, Import, Module, Statement
class Visitor(ast.NodeVisitor):
def __init__(self, module: Module) -> None:
self._stack = []
self.module = module
# def visit_ImportFrom(self, node):
# for alias in node.names:
# self.imports.append(Import(
# node.module,
# alias.name,
# alias.asname,
# ))
def visit_FunctionDef(self, node):
if node.decorator_list:
# TODO: Support normal decorators
assert 1 == len(node.decorator_list)
call = node.decorator_list[0]
assert 'external' == call.func.id
assert 1 == len(call.args)
assert isinstance(call.args[0].value, str)
import_ = Import(
call.args[0].value,
node.name,
node.name,
)
import_.params = [
arg.annotation.id
for arg in node.args.args
]
self.module.imports.append(import_)
return
func = Function(
node.name,
)
self._stack.append(func)
self.generic_visit(node)
self._stack.pop()
self.module.functions.append(func)
def visit_Expr(self, node):
self.generic_visit(node)
def visit_Call(self, node):
self.generic_visit(node)
func = self._stack[-1]
func.statements.append(
Statement('call', '$' + node.func.id)
)
def visit_BinOp(self, node):
self.generic_visit(node)
func = self._stack[-1]
if 'Add' == node.op.__class__.__name__:
func.statements.append(
Statement('i32.add')
)
elif 'Mult' == node.op.__class__.__name__:
func.statements.append(
Statement('i32.mul')
)
else:
raise NotImplementedError
def visit_Constant(self, node):
if not self._stack:
# Constant outside of any function
imp = self.imports[-1]
prefix = imp.name + '('
val = node.value.strip()
if val.startswith(prefix) and val.endswith(')'):
imp.params = val[len(prefix):-1].split(',')
else:
func = self._stack[-1]
if isinstance(node.value, int):
func.statements.append(
Statement('i32.const', str(node.value))
)
self.generic_visit(node)

45
py2wasm/wasm.py Normal file
View File

@ -0,0 +1,45 @@
class Import:
def __init__(self, module, name, intname):
self.module = module
self.name = name
self.intname = intname
self.params = None
def generate(self):
return '(import "{}" "{}" (func ${}{}))'.format(
self.module,
self.name,
self.intname,
''.join(' (param {})'.format(x) for x in self.params)
)
class Statement:
def __init__(self, name, *args):
self.name = name
self.args = args
def generate(self):
return '{} {}'.format(self.name, ' '.join(self.args))
class Function:
def __init__(self, name, exported=True):
self.name = name
self.exported = exported # TODO: Use __all__!
self.statements = []
def generate(self):
return '(func {}\n {})'.format(
('(export "{}")' if self.exported else '${}').format(self.name),
'\n '.join(x.generate() for x in self.statements),
)
class Module:
def __init__(self, imports = None, functions = None) -> None:
self.imports = imports or []
self.functions = functions or []
def generate(self):
return '(module\n {}\n {})'.format(
'\n '.join(x.generate() for x in self.imports),
'\n '.join(x.generate() for x in self.functions),
)