phasm/py2wasm/python.py
Johan B.W. de Vries e972b37149 If statements \o/
2021-08-07 14:34:50 +02:00

410 lines
12 KiB
Python

"""
Code for parsing the source (python-alike) code
"""
from typing import Dict, Generator, List, Optional, Union, Tuple
import ast
from . import wasm
# pylint: disable=C0103,R0201
# Return type of the statement/expression visitors below: a generator that
# yields wasm.Statement objects in emission order.
StatementGenerator = Generator[wasm.Statement, None, None]
# Maps the name of a function-local variable (currently only parameters) to
# the name of its type, e.g. {'a': 'i32'}.
WLocals = Dict[str, str]
class Visitor:
    """
    Class to visit a Python syntax tree

    Since we need to visit the whole tree, there's no point in subclassing
    the builtin visitor
    """
    def visit_Module(self, node: ast.Module) -> wasm.Module:
        """
        Visits a Python module, which results in a wasm Module

        Only top-level function definitions are supported; any other
        top-level statement raises NotImplementedError.
        """
        assert not node.type_ignores

        module = wasm.Module()

        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                wnode = self.visit_FunctionDef(module, stmt)

                if isinstance(wnode, wasm.Import):
                    module.imports.append(wnode)
                else:
                    module.functions.append(wnode)

                continue

            raise NotImplementedError(stmt)

        return module

    def visit_FunctionDef(
            self,
            module: wasm.Module,
            node: ast.FunctionDef,
            ) -> Union[wasm.Import, wasm.Function]:
        """
        A Python function definition with the @imported(...) decorator is
        returned as import; its body must consist of a single `...`.

        Other functions are returned as functions, which are marked as
        exported when they carry the @exported decorator.

        Nested / dynamically created functions are not yet supported
        """
        exported = False

        if node.decorator_list:
            assert 1 == len(node.decorator_list)

            decorator = node.decorator_list[0]
            if isinstance(decorator, ast.Name):
                assert 'exported' == decorator.id
                exported = True
            elif isinstance(decorator, ast.Call):
                call = decorator

                assert not node.type_comment
                assert 1 == len(node.body)
                assert isinstance(node.body[0], ast.Expr)

                # ast.Ellipsis is deprecated since Python 3.8 and removed in
                # 3.14; a literal `...` is an ast.Constant holding Ellipsis.
                body_value = node.body[0].value
                assert isinstance(body_value, ast.Constant) and body_value.value is Ellipsis

                assert isinstance(call.func, ast.Name)
                assert 'imported' == call.func.id
                assert not call.keywords

                mod, name, intname = _parse_import_decorator(node.name, call.args)

                return wasm.Import(mod, name, intname, _parse_import_args(node.args))
            else:
                raise NotImplementedError

        result = None if node.returns is None else _parse_annotation(node.returns)

        # Regular functions have the same restrictions on their arguments as
        # imported ones, so the same helper validates and parses them.
        params = _parse_import_args(node.args)

        wlocals: WLocals = dict(params)

        statements: List[wasm.Statement] = []
        for py_stmt in node.body:
            statements.extend(
                self.visit_stmt(module, node, wlocals, py_stmt)
            )

        return wasm.Function(node.name, exported, params, result, statements)

    def visit_stmt(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.stmt,
            ) -> StatementGenerator:
        """
        Visits a statement node within a function

        Supported are return statements, bare call expressions and if
        statements; anything else raises NotImplementedError.
        """
        if isinstance(stmt, ast.Return):
            return self.visit_Return(module, func, wlocals, stmt)

        if isinstance(stmt, ast.Expr):
            assert isinstance(stmt.value, ast.Call)

            # NOTE(review): the expected type is the *string* "None", while
            # visit_FunctionDef stores the None object as the result of a
            # function without return annotation - confirm what
            # wasm.Function / wasm.Import expose as `.result`.
            return self.visit_Call(module, wlocals, "None", stmt.value)

        if isinstance(stmt, ast.If):
            return self.visit_If(module, func, wlocals, stmt)

        raise NotImplementedError(stmt)

    def visit_Return(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.Return,
            ) -> StatementGenerator:
        """
        Visits a Return node

        The returned expression is compiled against the return annotation
        of the enclosing function; a bare `return` is not supported.
        """
        assert stmt.value is not None

        return_type = _parse_annotation(func.returns)

        yield from self.visit_expr(module, wlocals, return_type, stmt.value)
        yield wasm.Statement('return')

    def visit_If(
            self,
            module: wasm.Module,
            func: ast.FunctionDef,
            wlocals: WLocals,
            stmt: ast.If,
            ) -> StatementGenerator:
        """
        Visits an If node

        Emits the test expression (typed as 'bool'), followed by an
        if / else / end sequence. The 'else' arm is always emitted, even
        when the Python statement has no else branch.
        """
        yield from self.visit_expr(
            module,
            wlocals,
            'bool',
            stmt.test,
        )

        yield wasm.Statement('if')

        for py_stmt in stmt.body:
            yield from self.visit_stmt(module, func, wlocals, py_stmt)

        yield wasm.Statement('else')

        for py_stmt in stmt.orelse:
            yield from self.visit_stmt(module, func, wlocals, py_stmt)

        yield wasm.Statement('end')

    def visit_expr(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.expr,
            ) -> StatementGenerator:
        """
        Visit an expression node

        Dispatches on the node class; `exp_type` is the name of the type
        the expression is expected to evaluate to. Unsupported expression
        kinds raise NotImplementedError.
        """
        if isinstance(node, ast.BinOp):
            return self.visit_BinOp(module, wlocals, exp_type, node)

        if isinstance(node, ast.UnaryOp):
            return self.visit_UnaryOp(module, wlocals, exp_type, node)

        if isinstance(node, ast.Compare):
            return self.visit_Compare(module, wlocals, exp_type, node)

        if isinstance(node, ast.Call):
            return self.visit_Call(module, wlocals, exp_type, node)

        if isinstance(node, ast.Constant):
            return self.visit_Constant(exp_type, node)

        if isinstance(node, ast.Name):
            return self.visit_Name(wlocals, exp_type, node)

        raise NotImplementedError(node)

    def visit_UnaryOp(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.UnaryOp,
            ) -> StatementGenerator:
        """
        Visits a UnaryOp node as (part of) an expression

        Only unary plus and minus applied to a constant are supported; the
        operator is folded into the constant. Anything else raises
        NotImplementedError.
        """
        del module
        del wlocals

        if not isinstance(node.operand, ast.Constant):
            raise NotImplementedError

        if isinstance(node.op, ast.UAdd):
            return self.visit_Constant(exp_type, node.operand)

        if isinstance(node.op, ast.USub):
            value = node.operand.value
            if isinstance(value, bool) or not isinstance(value, int):
                raise NotImplementedError(node.operand)

            # Negate before handing over to visit_Constant, so its range
            # check sees the real value; this is what makes -2147483648
            # acceptable even though 2147483648 on its own is not.
            negated = ast.copy_location(ast.Constant(value=-value), node)
            return self.visit_Constant(exp_type, negated)

        raise NotImplementedError(node.op)

    def visit_BinOp(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.BinOp,
            ) -> StatementGenerator:
        """
        Visits a BinOp node as (part of) an expression

        Both operands are compiled against `exp_type`, after which the
        operator instruction is emitted. Only addition is supported.
        """
        yield from self.visit_expr(module, wlocals, exp_type, node.left)
        yield from self.visit_expr(module, wlocals, exp_type, node.right)

        if isinstance(node.op, ast.Add):
            yield wasm.Statement('{}.add'.format(exp_type))
            return

        raise NotImplementedError(node.op)

    def visit_Compare(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.Compare,
            ) -> StatementGenerator:
        """
        Visits a Compare node as (part of) an expression

        Only a single `<` or `>` comparison is supported; both sides are
        compiled as i32 and compared as signed values.
        """
        assert 'bool' == exp_type

        if 1 != len(node.ops) or 1 != len(node.comparators):
            raise NotImplementedError

        yield from self.visit_expr(module, wlocals, 'i32', node.left)
        yield from self.visit_expr(module, wlocals, 'i32', node.comparators[0])

        if isinstance(node.ops[0], ast.Lt):
            yield wasm.Statement('i32.lt_s')
            return

        if isinstance(node.ops[0], ast.Gt):
            yield wasm.Statement('i32.gt_s')
            return

        raise NotImplementedError(node.ops)

    def visit_Call(
            self,
            module: wasm.Module,
            wlocals: WLocals,
            exp_type: str,
            node: ast.Call,
            ) -> StatementGenerator:
        """
        Visits a Call node as (part of) an expression

        The callee is looked up by name among the functions already
        registered on `module`, and then among its imports. The arguments
        are compiled against the callee's parameter types before the call
        instruction is emitted.

        Raises AssertionError when the callee is unknown or ambiguous, when
        its result type differs from `exp_type` or when the number of
        arguments does not match.
        """
        assert isinstance(node.func, ast.Name)
        assert not node.keywords

        called_name = node.func.id

        candidates = [
            x
            for x in module.functions
            if x.name == called_name
        ]
        if not candidates:
            # NOTE(review): imports are matched on `.name`, while
            # visit_FunctionDef hands wasm.Import both an external name and
            # the Python level function name - confirm `.name` is the one
            # that is used at call sites.
            candidates = [
                x
                for x in module.imports
                if x.name == called_name
            ]

        if not candidates:
            # Raised explicitly rather than through an assert statement, so
            # that running with -O still reports the unknown function
            # instead of failing later on an unbound local.
            raise AssertionError('Could not find function {}'.format(called_name))

        assert len(candidates) == 1
        called_params = candidates[0].params
        called_result = candidates[0].result

        assert exp_type == called_result
        assert len(called_params) == len(node.args), \
            '{}:{} Function {} requires {} arguments, but {} are supplied'.format(
                node.lineno, node.col_offset,
                called_name, len(called_params), len(node.args),
            )

        for (_, exp_expr_type), expr in zip(called_params, node.args):
            yield from self.visit_expr(module, wlocals, exp_expr_type, expr)

        yield wasm.Statement(
            'call',
            '${}'.format(called_name),
        )

    def visit_Constant(self, exp_type: str, node: ast.Constant) -> StatementGenerator:
        """
        Visits a Constant node as (part of) an expression

        Only i32 constants are supported; the value must be an integer
        within the signed 32 bit range.
        """
        if 'i32' == exp_type:
            # bool is a subclass of int, but str(True) is not a valid
            # integer literal, so it is rejected explicitly.
            assert isinstance(node.value, int) and not isinstance(node.value, bool)
            assert -2147483648 <= node.value <= 2147483647

            yield wasm.Statement(
                '{}.const'.format(exp_type),
                str(node.value)
            )
            return

        raise NotImplementedError(exp_type)

    def visit_Name(self, wlocals: WLocals, exp_type: str, node: ast.Name) -> StatementGenerator:
        """
        Visits a Name node as (part of) an expression

        The name must refer to a known local whose type equals `exp_type`.
        """
        assert node.id in wlocals
        assert exp_type == wlocals[node.id]

        yield wasm.Statement(
            'local.get',
            '${}'.format(node.id),
        )
def _parse_import_decorator(func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]:
    """
    Extracts the import target from the arguments of an import decorator

    One or two string literals are accepted: the module to import from
    and, optionally, the name within that module. When the name is left
    out, the name of the decorated function is used instead.

    Returns a (module, name, func_name) tuple.
    """
    assert 0 < len(args) < 3

    # Collect the string literals; every argument has to be one
    literals = []
    for arg in args:
        if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
            literals.append(arg.value)

    assert len(literals) == len(args)

    module = literals[0]
    name = literals[1] if len(literals) > 1 else func_name

    return module, name, func_name
def _parse_import_args(args: ast.arguments) -> List[Tuple[str, str]]:
    """
    Turns the arguments of an @imported function into (name, type) pairs

    Only plain positional arguments with a type annotation are accepted;
    varargs, keyword-only arguments and default values are rejected.
    """
    # Defaults are listed last; maybe support those later on
    for unsupported in ('vararg', 'kwonlyargs', 'kw_defaults', 'kwarg', 'defaults'):
        assert not getattr(args, unsupported)

    parsed: List[Tuple[str, str]] = []
    for arg in [*args.posonlyargs, *args.args]:
        parsed.append((arg.arg, _parse_annotation(arg.annotation)))

    return parsed
def _parse_annotation(ann: Optional[ast.expr]) -> str:
    """
    Returns the type name that a typing annotation refers to

    The annotation is mandatory and has to be a plain name out of the
    set of supported types, which currently only holds i32.
    """
    assert ann is not None, 'Web Assembly requires type annotations'
    assert isinstance(ann, ast.Name)

    type_name = ann.id
    assert type_name in ['i32']

    return type_name