diff --git a/py2wasm/ourlang.py b/py2wasm/ourlang.py new file mode 100644 index 0000000..2898cc2 --- /dev/null +++ b/py2wasm/ourlang.py @@ -0,0 +1,713 @@ +""" +Contains the syntax tree for ourlang +""" +from typing import Any, Dict, List, NoReturn, Union, Tuple + +import ast + +class OurType: + """ + Type base class + """ + __slots__ = () + + def render(self) -> str: + """ + Renders the type back to source code format + + This'll look like Python code. + """ + raise NotImplementedError(self, 'render') + + def alloc_size(self) -> int: + """ + When allocating this type in memory, how many bytes do we need to reserve? + """ + raise NotImplementedError(self, 'alloc_size') + +class OurTypeNone(OurType): + """ + The None (or Void) type + """ + __slots__ = () + + def render(self) -> str: + return 'None' + +class OurTypeInt32(OurType): + """ + The Integer type, signed and 32 bits wide + """ + __slots__ = () + + def render(self) -> str: + return 'i32' + + def alloc_size(self) -> int: + return 4 + +class OurTypeInt64(OurType): + """ + The Integer type, signed and 64 bits wide + """ + __slots__ = () + + def render(self) -> str: + return 'i64' + +class OurTypeFloat32(OurType): + """ + The Float type, 32 bits wide + """ + __slots__ = () + + def render(self) -> str: + return 'f32' + +class OurTypeFloat64(OurType): + """ + The Float type, 64 bits wide + """ + __slots__ = () + + def render(self) -> str: + return 'f64' + +class Expression: + """ + An expression within a statement + """ + __slots__ = () + + def render(self) -> str: + """ + Renders the expression back to source code format + + This'll look like Python code. + """ + raise NotImplementedError(self, 'render') + +class Constant(Expression): + """ + An constant value expression within a statement + """ + __slots__ = ('type', ) + + type: OurType + + def __init__(self, type_: OurType) -> None: + self.type = type_ + +class ConstantInt32(Constant): + """ + An Int32 constant value expression within a statement + """ + __slots__ = ('value', ) + + value: int + + def __init__(self, type_: OurTypeInt32, value: int) -> None: + super().__init__(type_) + self.value = value + + def render(self) -> str: + return str(self.value) + +class ConstantInt64(Constant): + """ + An Int64 constant value expression within a statement + """ + __slots__ = ('value', ) + + value: int + + def __init__(self, type_: OurTypeInt64, value: int) -> None: + super().__init__(type_) + self.value = value + + def render(self) -> str: + return str(self.value) + +class ConstantFloat32(Constant): + """ + An Float32 constant value expression within a statement + """ + __slots__ = ('value', ) + + value: float + + def __init__(self, type_: OurTypeFloat32, value: float) -> None: + super().__init__(type_) + self.value = value + + def render(self) -> str: + return str(self.value) + +class ConstantFloat64(Constant): + """ + An Float64 constant value expression within a statement + """ + __slots__ = ('value', ) + + value: float + + def __init__(self, type_: OurTypeFloat64, value: float) -> None: + super().__init__(type_) + self.value = value + + def render(self) -> str: + return str(self.value) + +class VariableReference(Expression): + """ + An variable reference expression within a statement + """ + __slots__ = ('type', 'name', ) + + type: OurType + name: str + + def __init__(self, type_: OurType, name: str) -> None: + self.type = type_ + self.name = name + + def render(self) -> str: + return str(self.name) + +class BinaryOp(Expression): + """ + A binary operator expression within a statement + """ + __slots__ = ('operator', 'left', 'right', ) + + operator: str + left: Expression + right: Expression + + def __init__(self, operator: str, left: Expression, right: Expression) -> None: + self.operator = operator + self.left = left + self.right = right + + def render(self) -> str: + return f'{self.left.render()} {self.operator} {self.right.render()}' + +class UnaryOp(Expression): + """ + A unary operator expression within a statement + """ + __slots__ = ('operator', 'right', ) + + operator: str + right: Expression + + def __init__(self, operator: str, right: Expression) -> None: + self.operator = operator + self.right = right + + def render(self) -> str: + return f'{self.operator}{self.right.render()}' + +class FunctionCall(Expression): + """ + A function call expression within a statement + """ + __slots__ = ('function', 'arguments', ) + + function: 'Function' + arguments: List[Expression] + + def __init__(self, function: 'Function') -> None: + self.function = function + self.arguments = [] + + def render(self) -> str: + args = ', '.join( + arg.render() + for arg in self.arguments + ) + + return f'{self.function.name}({args})' + +class Statement: + """ + A statement within a function + """ + __slots__ = () + + def render(self) -> List[str]: + """ + Renders the type back to source code format + + This'll look like Python code. + """ + raise NotImplementedError(self, 'render') + +class StatementReturn(Statement): + """ + A return statement within a function + """ + __slots__ = ('value', ) + + def __init__(self, value: Expression) -> None: + self.value = value + + def render(self) -> List[str]: + """ + Renders the type back to source code format + + This'll look like Python code. + """ + return [f'return {self.value.render()}'] + +class StatementIf(Statement): + """ + An if statement within a function + """ + __slots__ = ('test', 'statements', 'else_statements', ) + + test: Expression + statements: List[Statement] + else_statements: List[Statement] + + def __init__(self, test: Expression) -> None: + self.test = test + self.statements = [] + self.else_statements = [] + + def render(self) -> List[str]: + """ + Renders the type back to source code format + + This'll look like Python code. + """ + result = [f'if {self.test.render()}:'] + + for stmt in self.statements: + result.extend( + f' {line}' if line else '' + for line in stmt.render() + ) + + result.append('') + + return result + +class StatementPass(Statement): + """ + A pass statement + """ + __slots__ = () + + def render(self) -> List[str]: + return ['pass'] + +class Function: + """ + A function processes input and produces output + """ + __slots__ = ('name', 'lineno', 'exported', 'statements', 'returns', 'posonlyargs', ) + + name: str + lineno: int + exported: bool + statements: List[Statement] + returns: OurType + posonlyargs: List[Tuple[str, OurType]] + + def __init__(self, name: str, lineno: int) -> None: + self.name = name + self.lineno = lineno + self.exported = False + self.statements = [] + self.returns = OurTypeNone() + self.posonlyargs = [] + + def render(self) -> str: + """ + Renders the function back to source code format + + This'll look like Python code. + """ + result = '' + if self.exported: + result += '@exported\n' + + args = ', '.join( + f'{x}: {y.render()}' + for x, y in self.posonlyargs + ) + + result += f'def {self.name}({args}) -> {self.returns.render()}:\n' + for stmt in self.statements: + for line in stmt.render(): + result += f' {line}\n' if line else '\n' + return result + +class StructMember: + """ + Represents a struct member + """ + def __init__(self, name: str, type_: OurType, offset: int) -> None: + self.name = name + self.type = type_ + self.offset = offset + +class Struct: + """ + A struct has named properties + """ + __slots__ = ('name', 'lineno', 'members', ) + + name: str + lineno: int + members: List[StructMember] + + def __init__(self, name: str, lineno: int) -> None: + self.name = name + self.lineno = lineno + self.members = [] + + def render(self) -> str: + """ + Renders the function back to source code format + + This'll look like Python code. + """ + return '?' + +class Module: + """ + A module is a file and consists of functions + """ + __slots__ = ('types', 'functions', 'structs', ) + + types: Dict[str, OurType] + functions: Dict[str, Function] + structs: Dict[str, Struct] + + def __init__(self) -> None: + self.types = { + 'i32': OurTypeInt32(), + 'i64': OurTypeInt64(), + 'f32': OurTypeFloat32(), + 'f64': OurTypeFloat64(), + } + self.functions = {} + self.structs = {} + + def render(self) -> str: + """ + Renders the module back to source code format + + This'll look like Python code. + """ + result = '' + for function in self.functions.values(): + if result: + result += '\n' + result += function.render() + return result + +class StaticError(Exception): + """ + An error found during static analysis + """ + +OurLocals = Dict[str, OurType] + +class OurVisitor: + """ + Class to visit a Python syntax tree and create an ourlang syntax tree + """ + + # pylint: disable=C0103,C0116,C0301,R0201,R0912 + + def __init__(self) -> None: + pass + + def visit_Module(self, node: ast.Module) -> Module: + module = Module() + + _not_implemented(not node.type_ignores, 'Module.type_ignores') + + for stmt in node.body: + res = self.visit_Module_stmt(module, stmt) + + if isinstance(res, Function): + if res.name in module.functions: + raise StaticError( + f'{res.name} already defined on line {module.functions[res.name].lineno}' + ) + + module.functions[res.name] = res + + if isinstance(res, Struct): + if res.name in module.structs: + raise StaticError( + f'{res.name} already defined on line {module.structs[res.name].lineno}' + ) + + module.structs[res.name] = res + + return module + + def visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, Struct]: + if isinstance(node, ast.FunctionDef): + return self.visit_Module_FunctionDef(module, node) + + if isinstance(node, ast.ClassDef): + return self.visit_Module_ClassDef(module, node) + + raise NotImplementedError(f'{node} on Module') + + def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> Function: + function = Function(node.name, node.lineno) + + _not_implemented(not node.args.posonlyargs, 'FunctionDef.args.posonlyargs') + + for arg in node.args.args: + if not arg.annotation: + _raise_static_error(node, 'Type is required') + + function.posonlyargs.append(( + arg.arg, + self.visit_type(module, arg.annotation), + )) + + _not_implemented(not node.args.vararg, 'FunctionDef.args.vararg') + _not_implemented(not node.args.kwonlyargs, 'FunctionDef.args.kwonlyargs') + _not_implemented(not node.args.kw_defaults, 'FunctionDef.args.kw_defaults') + _not_implemented(not node.args.kwarg, 'FunctionDef.args.kwarg') + _not_implemented(not node.args.defaults, 'FunctionDef.args.defaults') + + # Do stmts at the end so we have the return value + + for decorator in node.decorator_list: + if not isinstance(decorator, ast.Name): + _raise_static_error(decorator, 'Function decorators must be string') + if not isinstance(decorator.ctx, ast.Load): + _raise_static_error(decorator, 'Must be load context') + _not_implemented(decorator.id != 'exports', 'Custom decorators') + function.exported = True + + if node.returns: + function.returns = self.visit_type(module, node.returns) + + _not_implemented(not node.type_comment, 'FunctionDef.type_comment') + + # Deferred parsing + + our_locals = dict(function.posonlyargs) + + for stmt in node.body: + function.statements.append( + self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt) + ) + + return function + + def visit_Module_ClassDef(self, module: Module, node: ast.ClassDef) -> Struct: + struct = Struct(node.name, node.lineno) + + _not_implemented(not node.bases, 'ClassDef.bases') + _not_implemented(not node.keywords, 'ClassDef.keywords') + _not_implemented(not node.decorator_list, 'ClassDef.decorator_list') + + offset = 0 + + for stmt in node.body: + if not isinstance(stmt, ast.AnnAssign): + raise NotImplementedError(f'Class with {stmt} nodes') + + if not isinstance(stmt.target, ast.Name): + raise NotImplementedError('Class with default values') + + if not stmt.value is None: + raise NotImplementedError('Class with default values') + + if stmt.simple != 1: + raise NotImplementedError('Class with non-simple arguments') + + member = StructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset) + + struct.members.append(member) + offset += member.type.alloc_size() + + return struct + + def visit_Module_FunctionDef_stmt(self, module: Module, function: Function, our_locals: OurLocals, node: ast.stmt) -> Statement: + if isinstance(node, ast.Return): + if node.value is None: + # TODO: Implement methods without return values + _raise_static_error(node, 'Return must have an argument') + + return StatementReturn( + self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value) + ) + + if isinstance(node, ast.If): + result = StatementIf( + self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test) + ) + + for stmt in node.body: + result.statements.append( + self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt) + ) + + for stmt in node.body: + result.else_statements.append( + self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt) + ) + + return result + + raise NotImplementedError(f'{node} as stmt in FunctionDef') + + def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.expr) -> Expression: + if isinstance(node, ast.BinOp): + if isinstance(node.op, ast.Add): + operator = '+' + elif isinstance(node.op, ast.Sub): + operator = '-' + else: + raise NotImplementedError(f'Operator {node.op}') + + # Assume the type doesn't change when descending into a binary operator + # e.g. you can do `"hello" * 3` with the code below (yet) + + return BinaryOp( + operator, + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), + ) + + if isinstance(node, ast.UnaryOp): + if isinstance(node.op, ast.UAdd): + operator = '+' + elif isinstance(node.op, ast.USub): + operator = '-' + else: + raise NotImplementedError(f'Operator {node.op}') + + return UnaryOp( + operator, + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), + ) + + if isinstance(node, ast.Compare): + if 1 < len(node.ops): + raise NotImplementedError('Multiple operators') + + if isinstance(node.ops[0], ast.Gt): + operator = '>' + else: + raise NotImplementedError(f'Operator {node.ops}') + + # Assume the type doesn't change when descending into a binary operator + # e.g. you can do `"hello" * 3` with the code below (yet) + + return BinaryOp( + operator, + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), + ) + + if isinstance(node, ast.Call): + if node.keywords: + _raise_static_error(node, 'Keyword calling not supported') # Yet? + + if not isinstance(node.func, ast.Name): + raise NotImplementedError(f'Calling methods that are not a name {node.func}') + if not isinstance(node.func.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if node.func.id not in module.functions: + _raise_static_error(node, 'Call to undefined function') + + func = module.functions[node.func.id] + if func.returns != exp_type: + _raise_static_error(node, f'Function does not return {exp_type.render()}') + + result = FunctionCall(func) + result.arguments.extend( + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, expr) + for expr in node.args + ) + return result + + if isinstance(node, ast.Constant): + return self.visit_Module_FunctionDef_Constant( + module, function, exp_type, node, + ) + + if isinstance(node, ast.Name): + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if node.id in our_locals: + return VariableReference(our_locals[node.id], node.id) + + raise NotImplementedError(f'{node} as expr in FunctionDef') + + def visit_Module_FunctionDef_Constant(self, module: Module, function: Function, exp_type: OurType, node: ast.Constant) -> Expression: + del module + del function + + _not_implemented(node.kind is None, 'Constant.kind') + + if isinstance(exp_type, OurTypeInt32): + if not isinstance(node.value, int): + _raise_static_error(node, 'Expected integer value') + + # FIXME: Range check + + return ConstantInt32(exp_type, node.value) + + if isinstance(exp_type, OurTypeInt64): + if not isinstance(node.value, int): + _raise_static_error(node, 'Expected integer value') + + # FIXME: Range check + + return ConstantInt64(exp_type, node.value) + + if isinstance(exp_type, OurTypeFloat32): + if not isinstance(node.value, float): + _raise_static_error(node, 'Expected float value') + + # FIXME: Range check + + return ConstantFloat32(exp_type, node.value) + + if isinstance(exp_type, OurTypeFloat64): + if not isinstance(node.value, float): + _raise_static_error(node, 'Expected float value') + + # FIXME: Range check + + return ConstantFloat64(exp_type, node.value) + + raise NotImplementedError(f'{node} as const for type {exp_type.render()}') + + def visit_type(self, module: Module, node: ast.expr) -> OurType: + if isinstance(node, ast.Name): + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if node.id not in module.types: + _raise_static_error(node, 'Unrecognized type') + + return module.types[node.id] + + raise NotImplementedError(f'{node} as type') + +def _not_implemented(check: Any, msg: str) -> None: + if not check: + raise NotImplementedError(msg) + +def _raise_static_error(node: Union[ast.mod, ast.stmt, ast.expr], msg: str) -> NoReturn: + raise StaticError( + f'Static error on line {node.lineno}: {msg}' + ) diff --git a/py2wasm/utils.py b/py2wasm/utils.py index c1e9b72..e93d0b5 100644 --- a/py2wasm/utils.py +++ b/py2wasm/utils.py @@ -4,7 +4,8 @@ Utility functions import ast -from py2wasm.python import Visitor +from .ourlang import OurVisitor, Module +from .python import Visitor def process(source: str, input_name: str) -> str: """ @@ -16,3 +17,12 @@ def process(source: str, input_name: str) -> str: module = visitor.visit_Module(res) return module.generate() + +def our_process(source: str, input_name: str) -> Module: + """ + Processes the python code into web assembly code + """ + res = ast.parse(source, input_name) + + our_visitor = OurVisitor() + return our_visitor.visit_Module(res) diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index e69de29..e6b179c 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.register_assert_rewrite('tests.integration.helpers')