547 lines
16 KiB
Python
547 lines
16 KiB
Python
"""
Code for parsing Python-like source code into a wasm module
"""
|
|
|
|
from typing import Dict, Generator, List, Optional, Union, Tuple
|
|
|
|
import ast
|
|
|
|
from . import wasm
|
|
|
|
# pylint: disable=C0103,R0201
|
|
|
|
# Types of the local variables of the function being compiled:
# maps each local (parameter) name to its type name, e.g. 'i32'
WLocals = Dict[str, str]

# The statement visitors lazily yield the wasm statements they generate
StatementGenerator = Generator[wasm.Statement, None, None]
|
|
|
|
class Visitor:
    """
    Class to visit a Python syntax tree and translate it into a wasm Module

    Since we need to visit the whole tree, there's no point in subclassing
    the built-in ast.NodeVisitor
    """

    # Binary operators supported by visit_BinOp, mapped to the instruction
    # name; it gets prefixed with the expected type, e.g. 'i32' -> 'i32.add'
    _BINOP_INSTRUCTIONS: Dict[type, str] = {
        ast.Add: 'add',
        ast.Sub: 'sub',
        ast.Mult: 'mul',
    }

    # Comparison operators supported by visit_Compare, mapped to the
    # (signed) i32 instruction implementing them
    _COMPARE_INSTRUCTIONS: Dict[type, str] = {
        ast.Lt: 'i32.lt_s',
        ast.LtE: 'i32.le_s',
        ast.Eq: 'i32.eq',
        ast.NotEq: 'i32.ne',
        ast.Gt: 'i32.gt_s',
        ast.GtE: 'i32.ge_s',
    }

    # Inclusive range of the integer constants accepted per wasm type
    _INT_BOUNDS: Dict[str, Tuple[int, int]] = {
        'i32': (-2147483648, 2147483647),
        'i64': (-9223372036854775808, 9223372036854775807),
    }

    def visit_Module(self, node: ast.Module) -> wasm.Module:
        """
        Visits a Python module, which results in a wasm Module

        Works in phases: first all classes and all function signatures are
        registered on the module, then the function bodies are parsed. This
        way a function can call any function in the module, and use any class
        as parameter type, regardless of the order of definition.

        Raises NotImplementedError for any module level statement that is
        neither a function nor a class definition.
        """
        assert not node.type_ignores

        module = wasm.Module()

        # Register the classes first, so function signatures can refer to
        # a class even if it is defined further down in the file
        for stmt in node.body:
            if isinstance(stmt, ast.ClassDef):
                module.classes.append(self.pre_visit_ClassDef(module, stmt))

        # Then register all function signatures, so that a function can call
        # a function that is defined further down in the same file
        function_body_map: Dict[ast.FunctionDef, wasm.Function] = {}
        for stmt in node.body:
            if not isinstance(stmt, ast.FunctionDef):
                continue

            wnode = self.pre_visit_FunctionDef(module, stmt)
            if isinstance(wnode, wasm.Import):
                module.imports.append(wnode)
            else:
                module.functions.append(wnode)
                function_body_map[stmt] = wnode

        # Finally parse the function bodies
        for stmt in node.body:
            if isinstance(stmt, ast.FunctionDef):
                if stmt in function_body_map:
                    self.parse_FunctionDef_body(module, function_body_map[stmt], stmt)
                # else: It's an import, no actual body to parse
                continue

            if isinstance(stmt, ast.ClassDef):
                continue  # Fully handled by pre_visit_ClassDef

            raise NotImplementedError(stmt)

        return module

    def pre_visit_FunctionDef(
        self,
        module: wasm.Module,
        node: ast.FunctionDef,
    ) -> Union[wasm.Import, wasm.Function]:
        """
        Parses the signature of a Python function definition

        A function with the @imported(...) decorator is returned as a wasm
        Import; its body must consist of `...` only.

        Other functions are returned as wasm Functions with an empty body,
        to be filled in by parse_FunctionDef_body. They are marked as
        exported when decorated with @exported.

        Parameters must be annotated with either a wasm type or the name of
        a class already registered on the module. Functions without a return
        annotation get None as result.

        Nested / dynamically created functions are not yet supported
        """
        exported = False

        if node.decorator_list:
            assert 1 == len(node.decorator_list)

            decorator = node.decorator_list[0]
            if isinstance(decorator, ast.Name):
                assert 'exported' == decorator.id
                exported = True
            elif isinstance(decorator, ast.Call):
                call = decorator

                # An imported function only declares its signature
                assert not node.type_comment
                assert 1 == len(node.body)
                assert isinstance(node.body[0], ast.Expr)
                assert isinstance(node.body[0].value, ast.Ellipsis)

                assert isinstance(call.func, ast.Name)
                assert 'imported' == call.func.id

                assert not call.keywords

                mod, name, intname = _parse_import_decorator(node.name, call.args)

                return wasm.Import(mod, name, intname, _parse_import_args(node.args))
            else:
                raise NotImplementedError

        result = None if node.returns is None else _parse_annotation(node.returns)

        assert not node.args.vararg
        assert not node.args.kwonlyargs
        assert not node.args.kw_defaults
        assert not node.args.kwarg
        assert not node.args.defaults

        class_names = {x.name for x in module.classes}

        params = []
        for arg in [*node.args.posonlyargs, *node.args.args]:
            if not isinstance(arg.annotation, ast.Name):
                raise NotImplementedError

            if arg.annotation.id in class_names:
                # Class typed parameters use the class name as their type
                params.append((arg.arg, arg.annotation.id, ))
            else:
                params.append((arg.arg, _parse_annotation(arg.annotation), ))

        return wasm.Function(node.name, exported, params, result, [])

    def parse_FunctionDef_body(
        self,
        module: wasm.Module,
        func: wasm.Function,
        node: ast.FunctionDef,
    ) -> None:
        """
        Parses the function body, storing the result in func.statements

        The function's parameters are the only known locals.
        """
        wlocals: WLocals = dict(func.params)

        statements: List[wasm.Statement] = []
        for py_stmt in node.body:
            statements.extend(
                self.visit_stmt(module, node, wlocals, py_stmt)
            )

        func.statements = statements

    def pre_visit_ClassDef(
        self,
        module: wasm.Module,
        node: ast.ClassDef,
    ) -> wasm.Class:
        """
        Parses a class definition into a wasm Class

        Only plain classes (no bases, keywords or decorators) are supported.
        The body may only contain i32 annotated members, optionally with an
        integer constant as default value, e.g.:

            class Foo:
                bar: i32
                baz: i32 = 5

        Raises NotImplementedError for anything else.
        """
        del module

        if node.bases or node.keywords or node.decorator_list:
            raise NotImplementedError

        members: List[wasm.ClassMember] = []

        for stmt in node.body:
            if not isinstance(stmt, ast.AnnAssign):
                raise NotImplementedError

            if not isinstance(stmt.target, ast.Name):
                raise NotImplementedError

            if not isinstance(stmt.annotation, ast.Name):
                raise NotImplementedError

            if stmt.annotation.id != 'i32':
                raise NotImplementedError

            if stmt.value is None:
                default = None
            else:
                if not isinstance(stmt.value, ast.Constant):
                    raise NotImplementedError

                if not isinstance(stmt.value.value, int):
                    raise NotImplementedError

                default = wasm.Constant(stmt.value.value)

            members.append(wasm.ClassMember(
                stmt.target.id, stmt.annotation.id, default
            ))

        return wasm.Class(node.name, members)

    def visit_stmt(
        self,
        module: wasm.Module,
        func: ast.FunctionDef,
        wlocals: WLocals,
        stmt: ast.stmt,
    ) -> StatementGenerator:
        """
        Visits a statement node within a function

        Supported are return statements, if statements and calls to
        functions without a result used as statement.
        """
        if isinstance(stmt, ast.Return):
            return self.visit_Return(module, func, wlocals, stmt)

        if isinstance(stmt, ast.Expr):
            assert isinstance(stmt.value, ast.Call)

            # A call used as statement must not leave a value behind
            return self.visit_Call(module, wlocals, "None", stmt.value)

        if isinstance(stmt, ast.If):
            return self.visit_If(module, func, wlocals, stmt)

        raise NotImplementedError(stmt)

    def visit_Return(
        self,
        module: wasm.Module,
        func: ast.FunctionDef,
        wlocals: WLocals,
        stmt: ast.Return,
    ) -> StatementGenerator:
        """
        Visits a Return node

        The returned value is evaluated as the function's annotated return
        type. A bare `return` is only allowed in a function without a
        return annotation.
        """
        if stmt.value is None:
            assert func.returns is None, \
                'Function {} must return a value'.format(func.name)

            yield wasm.Statement('return')
            return

        return_type = _parse_annotation(func.returns)

        yield from self.visit_expr(module, wlocals, return_type, stmt.value)
        yield wasm.Statement('return')

    def visit_If(
        self,
        module: wasm.Module,
        func: ast.FunctionDef,
        wlocals: WLocals,
        stmt: ast.If,
    ) -> StatementGenerator:
        """
        Visits an If node

        The test is evaluated as 'bool', followed by an if / else / end
        construct. The else part is always emitted, even when empty.
        """
        yield from self.visit_expr(
            module,
            wlocals,
            'bool',
            stmt.test,
        )

        yield wasm.Statement('if')

        for py_stmt in stmt.body:
            yield from self.visit_stmt(module, func, wlocals, py_stmt)

        yield wasm.Statement('else')

        for py_stmt in stmt.orelse:
            yield from self.visit_stmt(module, func, wlocals, py_stmt)

        yield wasm.Statement('end')

    def visit_expr(
        self,
        module: wasm.Module,
        wlocals: WLocals,
        exp_type: str,
        node: ast.expr,
    ) -> StatementGenerator:
        """
        Visit an expression node

        exp_type is the type the expression is expected to result in.
        """
        if isinstance(node, ast.BinOp):
            return self.visit_BinOp(module, wlocals, exp_type, node)

        if isinstance(node, ast.UnaryOp):
            return self.visit_UnaryOp(module, wlocals, exp_type, node)

        if isinstance(node, ast.Compare):
            return self.visit_Compare(module, wlocals, exp_type, node)

        if isinstance(node, ast.Call):
            return self.visit_Call(module, wlocals, exp_type, node)

        if isinstance(node, ast.Constant):
            return self.visit_Constant(exp_type, node)

        if isinstance(node, ast.Name):
            return self.visit_Name(wlocals, exp_type, node)

        if isinstance(node, ast.Attribute):
            # NOTE(review): attribute access currently generates no
            # instructions at all, so the resulting wasm is incomplete
            return []  # TODO

        raise NotImplementedError(node)

    def visit_UnaryOp(
        self,
        module: wasm.Module,
        wlocals: WLocals,
        exp_type: str,
        node: ast.UnaryOp,
    ) -> StatementGenerator:
        """
        Visits a UnaryOp node as (part of) an expression

        Supports unary plus on any expression, and unary minus on a numeric
        literal (which is folded into a negative constant).
        """
        if isinstance(node.op, ast.UAdd):
            return self.visit_expr(module, wlocals, exp_type, node.operand)

        if isinstance(node.op, ast.USub):
            operand = node.operand
            if not isinstance(operand, ast.Constant):
                raise NotImplementedError(operand)

            if not isinstance(operand.value, (int, float)):
                raise NotImplementedError(operand.value)

            return self.visit_Constant(exp_type, ast.Constant(-operand.value))

        raise NotImplementedError(node.op)

    def visit_BinOp(
        self,
        module: wasm.Module,
        wlocals: WLocals,
        exp_type: str,
        node: ast.BinOp,
    ) -> StatementGenerator:
        """
        Visits a BinOp node as (part of) an expression

        Both operands are evaluated as exp_type, followed by the matching
        instruction of that type (add, sub or mul).
        """
        instruction = self._BINOP_INSTRUCTIONS.get(type(node.op))
        if instruction is None:
            raise NotImplementedError(node.op)

        yield from self.visit_expr(module, wlocals, exp_type, node.left)
        yield from self.visit_expr(module, wlocals, exp_type, node.right)
        yield wasm.Statement('{}.{}'.format(exp_type, instruction))

    def visit_Compare(
        self,
        module: wasm.Module,
        wlocals: WLocals,
        exp_type: str,
        node: ast.Compare,
    ) -> StatementGenerator:
        """
        Visits a Compare node as (part of) an expression

        Only single (non-chained) comparisons of two i32 values are
        supported; the result is a 'bool'.
        """
        assert 'bool' == exp_type

        if 1 != len(node.ops) or 1 != len(node.comparators):
            raise NotImplementedError

        instruction = self._COMPARE_INSTRUCTIONS.get(type(node.ops[0]))
        if instruction is None:
            raise NotImplementedError(node.ops)

        yield from self.visit_expr(module, wlocals, 'i32', node.left)
        yield from self.visit_expr(module, wlocals, 'i32', node.comparators[0])
        yield wasm.Statement(instruction)

    def visit_Call(
        self,
        module: wasm.Module,
        wlocals: WLocals,
        exp_type: str,
        node: ast.Call,
    ) -> StatementGenerator:
        """
        Visits a Call node as (part of) an expression

        The called name is looked up in the module's functions, imports and
        classes. Calling a class takes one argument per class member and
        results in a value of the class' type.

        The arguments are evaluated in order, followed by a `call`.
        """
        assert isinstance(node.func, ast.Name)
        assert not node.keywords

        called_name = node.func.id

        search_list: List[Union[wasm.Function, wasm.Import, wasm.Class]]
        search_list = [
            *module.functions,
            *module.imports,
            *module.classes,
        ]

        called_func_list = [
            x
            for x in search_list
            if x.name == called_name
        ]

        assert 1 == len(called_func_list), \
            'Could not find function {}'.format(node.func.id)

        called = called_func_list[0]
        if isinstance(called, wasm.Class):
            called_params = [
                (x.name, x.type, )
                for x in called.members
            ]
            called_result: Optional[str] = called.name
        else:
            called_params = called.params
            called_result = called.result

        # Functions without a return annotation have None as result, while
        # visit_stmt expects calls used as statement to have type 'None'
        if called_result is None:
            called_result = 'None'

        assert exp_type == called_result

        assert len(called_params) == len(node.args), \
            '{}:{} Function {} requires {} arguments, but {} are supplied'.format(
                node.lineno, node.col_offset,
                called_name, len(called_params), len(node.args),
            )

        for (_, exp_expr_type), expr in zip(called_params, node.args):
            yield from self.visit_expr(module, wlocals, exp_expr_type, expr)

        yield wasm.Statement(
            'call',
            '${}'.format(called_name),
        )

    def visit_Constant(self, exp_type: str, node: ast.Constant) -> StatementGenerator:
        """
        Visits a Constant node as (part of) an expression

        Integer types accept int literals within their range; float types
        accept int and float literals. Raises NotImplementedError for any
        other expected type.
        """
        if exp_type in self._INT_BOUNDS:
            lower, upper = self._INT_BOUNDS[exp_type]

            assert isinstance(node.value, int)
            assert lower <= node.value <= upper

            # int() turns booleans into 1 / 0, str(True) would give 'True'
            yield wasm.Statement('{}.const'.format(exp_type), str(int(node.value)))
            return

        if exp_type in ('f32', 'f64'):
            # Integer literals are fine too, e.g. `return 0` in a f64 function
            assert isinstance(node.value, (int, float))

            # TODO: Size check?

            yield wasm.Statement('{}.const'.format(exp_type), float(node.value).hex())
            return

        raise NotImplementedError(exp_type)

    def visit_Name(self, wlocals: WLocals, exp_type: str, node: ast.Name) -> StatementGenerator:
        """
        Visits a Name node as (part of) an expression

        Loads the local; an i32 local used as i64 is sign extended, an f32
        local used as f64 is promoted. Otherwise the types must match.
        """
        assert node.id in wlocals

        if exp_type == 'i64' and wlocals[node.id] == 'i32':
            yield wasm.Statement('local.get', '${}'.format(node.id))
            yield wasm.Statement('i64.extend_i32_s')
            return

        if exp_type == 'f64' and wlocals[node.id] == 'f32':
            yield wasm.Statement('local.get', '${}'.format(node.id))
            yield wasm.Statement('f64.promote_f32')
            return

        assert exp_type == wlocals[node.id]
        yield wasm.Statement(
            'local.get',
            '${}'.format(node.id),
        )
|
|
|
|
def _parse_import_decorator(func_name: str, args: List[ast.expr]) -> Tuple[str, str, str]:
    """
    Parses the arguments of an @imported(...) decorator

    The decorator takes one or two string literals: the module to import
    from and, optionally, the name of the import within that module. When
    the name is omitted, the Python function name is used instead.

    Returns a (module, name, internal name) tuple, where the internal name
    is always the Python function name.
    """
    assert 0 < len(args) < 3

    strings: List[str] = []
    for arg in args:
        if isinstance(arg, ast.Constant) and isinstance(arg.value, str):
            strings.append(arg.value)

    # Every argument has to be a string literal
    assert len(strings) == len(args)

    module_name, *rest = strings
    import_name = rest[0] if rest else func_name

    return module_name, import_name, func_name
|
|
|
|
def _parse_import_args(args: ast.arguments) -> List[Tuple[str, str]]:
    """
    Parses the parameter list of an @imported function

    Only plain (positional-only or positional-or-keyword) parameters
    without defaults are supported; each is parsed with _parse_annotation.

    Returns a list of (parameter name, type) tuples.
    """
    unsupported = (
        args.vararg,
        args.kwonlyargs,
        args.kw_defaults,
        args.kwarg,
        args.defaults,  # Maybe support this later on
    )
    for feature in unsupported:
        assert not feature

    return [
        (param.arg, _parse_annotation(param.annotation))
        for param in args.posonlyargs + args.args
    ]
|
|
|
|
def _parse_annotation(ann: Optional[ast.expr]) -> str:
    """
    Parses a type annotation into a wasm value type name

    Only a bare name of one of the types i32, i64, f32 or f64 is accepted;
    anything else (including a missing annotation) fails an assertion.
    """
    assert ann is not None, 'Web Assembly requires type annotations'
    assert isinstance(ann, ast.Name)
    assert ann.id in ('i32', 'i64', 'f32', 'f64'), ann.id
    return ann.id
|