From 89ad648f348d842d53135cbb8d94254752e86cb1 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 9 Jul 2022 14:04:40 +0200 Subject: [PATCH] Moved rendering to codestyle, parsing to parser Also, removed name argument when parsing, wasn't used. --- phasm/__main__.py | 8 +- phasm/codestyle.py | 195 ++++++ phasm/compiler.py | 109 ++-- phasm/exceptions.py | 8 + phasm/ourlang.py | 718 +--------------------- phasm/parser.py | 577 +++++++++++++++++ phasm/typing.py | 61 +- phasm/utils.py | 16 - phasm/wasm.py | 3 - pylintrc | 2 +- tests/integration/helpers.py | 16 +- tests/integration/test_fib.py | 2 +- tests/integration/test_runtime_checks.py | 2 +- tests/integration/test_simple.py | 78 ++- tests/integration/test_static_checking.py | 12 +- 15 files changed, 951 insertions(+), 856 deletions(-) create mode 100644 phasm/codestyle.py create mode 100644 phasm/exceptions.py create mode 100644 phasm/parser.py delete mode 100644 phasm/utils.py diff --git a/phasm/__main__.py b/phasm/__main__.py index 7c65724..bc94cdb 100644 --- a/phasm/__main__.py +++ b/phasm/__main__.py @@ -4,8 +4,8 @@ Functions for using this module from CLI import sys -from .utils import our_process -from .compiler import module +from .parser import phasm_parse +from .compiler import phasm_compile def main(source: str, sink: str) -> int: """ @@ -15,8 +15,8 @@ def main(source: str, sink: str) -> int: with open(source, 'r') as fil: code_py = fil.read() - our_module = our_process(code_py, source) - wasm_module = module(our_module) + our_module = phasm_parse(code_py) + wasm_module = phasm_compile(our_module) code_wat = wasm_module.to_wat() with open(sink, 'w') as fil: diff --git a/phasm/codestyle.py b/phasm/codestyle.py new file mode 100644 index 0000000..5d217f5 --- /dev/null +++ b/phasm/codestyle.py @@ -0,0 +1,195 @@ +""" +This module generates source code based on the parsed AST + +It's intented to be a "any color, as long as it's black" kind of renderer +""" +from typing import Generator + +from . import ourlang +from . import typing + +def phasm_render(inp: ourlang.Module) -> str: + """ + Public method for rendering a Phasm module into Phasm code + """ + return module(inp) + +Statements = Generator[str, None, None] + +def type_(inp: typing.TypeBase) -> str: + """ + Render: Type (name) + """ + if isinstance(inp, typing.TypeNone): + return 'None' + + if isinstance(inp, typing.TypeBool): + return 'bool' + + if isinstance(inp, typing.TypeUInt8): + return 'u8' + + if isinstance(inp, typing.TypeInt32): + return 'i32' + + if isinstance(inp, typing.TypeInt64): + return 'i64' + + if isinstance(inp, typing.TypeFloat32): + return 'f32' + + if isinstance(inp, typing.TypeFloat64): + return 'f64' + + if isinstance(inp, typing.TypeBytes): + return 'bytes' + + if isinstance(inp, typing.TypeTuple): + mems = ', '.join( + type_(x.type) + for x in inp.members + ) + + return f'({mems}, )' + + if isinstance(inp, typing.TypeStruct): + return inp.name + + raise NotImplementedError(type_, inp) + +def struct_definition(inp: typing.TypeStruct) -> str: + """ + Render: TypeStruct's definition + """ + result = f'class {inp.name}:\n' + for mem in inp.members: + result += f' {mem.name}: {type_(mem.type)}\n' + + return result + +def expression(inp: ourlang.Expression) -> str: + """ + Render: A Phasm expression + """ + if isinstance(inp, (ourlang.ConstantUInt8, ourlang.ConstantInt32, ourlang.ConstantInt64, )): + return str(inp.value) + + if isinstance(inp, (ourlang.ConstantFloat32, ourlang.ConstantFloat64, )): + # These might not round trip if the original constant + # could not fit in the given float type + return str(inp.value) + + if isinstance(inp, ourlang.VariableReference): + return str(inp.name) + + if isinstance(inp, ourlang.UnaryOp): + if ( + inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS + or inp.operator in ourlang.WEBASSEMBLY_BUILDIN_BYTES_OPS): + return f'{inp.operator}({expression(inp.right)})' + + return f'{inp.operator}{expression(inp.right)}' + + if isinstance(inp, ourlang.BinaryOp): + return f'{expression(inp.left)} {inp.operator} {expression(inp.right)}' + + if isinstance(inp, ourlang.FunctionCall): + args = ', '.join( + expression(arg) + for arg in inp.arguments + ) + + if isinstance(inp.function, ourlang.StructConstructor): + return f'{inp.function.struct.name}({args})' + + if isinstance(inp.function, ourlang.TupleConstructor): + return f'({args}, )' + + return f'{inp.function.name}({args})' + + if isinstance(inp, ourlang.AccessBytesIndex): + return f'{expression(inp.varref)}[{expression(inp.index)}]' + + if isinstance(inp, ourlang.AccessStructMember): + return f'{expression(inp.varref)}.{inp.member.name}' + + if isinstance(inp, ourlang.AccessTupleMember): + return f'{expression(inp.varref)}[{inp.member.idx}]' + + raise NotImplementedError(expression, inp) + +def statement(inp: ourlang.Statement) -> Statements: + """ + Render: A list of Phasm statements + """ + if isinstance(inp, ourlang.StatementPass): + yield 'pass' + return + + if isinstance(inp, ourlang.StatementReturn): + yield f'return {expression(inp.value)}' + return + + if isinstance(inp, ourlang.StatementIf): + yield f'if {expression(inp.test)}:' + + for stmt in inp.statements: + for line in statement(stmt): + yield f' {line}' if line else '' + + yield '' + return + + raise NotImplementedError(statement, inp) + +def function(inp: ourlang.Function) -> str: + """ + Render: Function body + + Imported functions only have "pass" as a body. Later on we might replace + this by the function documentation, if any. + """ + result = '' + if inp.exported: + result += '@exported\n' + if inp.imported: + result += '@imported\n' + + args = ', '.join( + f'{x}: {type_(y)}' + for x, y in inp.posonlyargs + ) + + result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' + + if inp.imported: + result += ' pass\n' + else: + for stmt in inp.statements: + for line in statement(stmt): + result += f' {line}\n' if line else '\n' + + return result + + +def module(inp: ourlang.Module) -> str: + """ + Render: Module + """ + result = '' + + for struct in inp.structs.values(): + if result: + result += '\n' + result += struct_definition(struct) + + for func in inp.functions.values(): + if func.lineno < 0: + # Buildin (-2) or auto generated (-1) + continue + + if result: + result += '\n' + result += function(func) + + return result diff --git a/phasm/compiler.py b/phasm/compiler.py index 979543e..4ad4bd2 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -1,7 +1,7 @@ """ This module contains the code to convert parsed Ourlang into WebAssembly code """ -from typing import Generator, Tuple +from typing import Generator from . import ourlang from . import typing @@ -9,7 +9,28 @@ from . import wasm Statements = Generator[wasm.Statement, None, None] +LOAD_STORE_TYPE_MAP = { + typing.TypeUInt8: 'i32', + typing.TypeInt32: 'i32', + typing.TypeInt64: 'i64', + typing.TypeFloat32: 'f32', + typing.TypeFloat64: 'f64', +} +""" +When generating code, we sometimes need to load or store simple values +""" + +def phasm_compile(inp: ourlang.Module) -> wasm.Module: + """ + Public method for compiling a parsed Phasm module into + a WebAssembly module + """ + return module(inp) + def type_(inp: typing.TypeBase) -> wasm.WasmType: + """ + Compile: type + """ if isinstance(inp, typing.TypeNone): return wasm.WasmTypeNone() @@ -60,6 +81,9 @@ I64_OPERATOR_MAP = { # TODO: Introduce UInt32 type } def expression(inp: ourlang.Expression) -> Statements: + """ + Compile: Any expression + """ if isinstance(inp, ourlang.ConstantUInt8): yield wasm.Statement('i32.const', str(inp.value)) return @@ -150,42 +174,40 @@ def expression(inp: ourlang.Expression) -> Statements: return if isinstance(inp, ourlang.AccessStructMember): - if isinstance(inp.member.type, typing.TypeUInt8): - mtyp = 'i32' - else: - # FIXME: Properly implement this - # inp.type.render() is also a hack that doesn't really work consistently - if not isinstance(inp.member.type, ( - typing.TypeInt32, typing.TypeFloat32, - typing.TypeInt64, typing.TypeFloat64, - )): - raise NotImplementedError - mtyp = inp.member.type.render() + mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) + if mtyp is None: + # In the future might extend this by having structs or tuples + # as members of struct or tuples + raise NotImplementedError(expression, inp, inp.member) yield from expression(inp.varref) yield wasm.Statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) return if isinstance(inp, ourlang.AccessTupleMember): - # FIXME: Properly implement this - # inp.type.render() is also a hack that doesn't really work consistently - if not isinstance(inp.type, ( - typing.TypeInt32, typing.TypeFloat32, - typing.TypeInt64, typing.TypeFloat64, - )): - raise NotImplementedError(inp, inp.type) + mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) + if mtyp is None: + # In the future might extend this by having structs or tuples + # as members of struct or tuples + raise NotImplementedError(expression, inp, inp.member) yield from expression(inp.varref) - yield wasm.Statement(inp.type.render() + '.load', 'offset=' + str(inp.member.offset)) + yield wasm.Statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) return raise NotImplementedError(expression, inp) def statement_return(inp: ourlang.StatementReturn) -> Statements: + """ + Compile: Return statement + """ yield from expression(inp.value) yield wasm.Statement('return') def statement_if(inp: ourlang.StatementIf) -> Statements: + """ + Compile: If statement + """ yield from expression(inp.test) yield wasm.Statement('if') @@ -201,6 +223,9 @@ def statement_if(inp: ourlang.StatementIf) -> Statements: yield wasm.Statement('end') def statement(inp: ourlang.Statement) -> Statements: + """ + Compile: any statement + """ if isinstance(inp, ourlang.StatementReturn): yield from statement_return(inp) return @@ -214,10 +239,16 @@ def statement(inp: ourlang.Statement) -> Statements: raise NotImplementedError(statement, inp) -def function_argument(inp: Tuple[str, typing.TypeBase]) -> wasm.Param: +def function_argument(inp: ourlang.FunctionParam) -> wasm.Param: + """ + Compile: function argument + """ return (inp[0], type_(inp[1]), ) def import_(inp: ourlang.Function) -> wasm.Import: + """ + Compile: imported function + """ assert inp.imported return wasm.Import( @@ -232,6 +263,9 @@ def import_(inp: ourlang.Function) -> wasm.Import: ) def function(inp: ourlang.Function) -> wasm.Function: + """ + Compile: function + """ assert not inp.imported if isinstance(inp, ourlang.TupleConstructor): @@ -269,6 +303,9 @@ def function(inp: ourlang.Function) -> wasm.Function: ) def module(inp: ourlang.Module) -> wasm.Module: + """ + Compile: module + """ result = wasm.Module() result.imports = [ @@ -350,17 +387,15 @@ def _generate_tuple_constructor(inp: ourlang.TupleConstructor) -> Statements: yield wasm.Statement('local.set', '$___new_reference___addr') for member in inp.tuple.members: - # FIXME: Properly implement this - # inp.type.render() is also a hack that doesn't really work consistently - if not isinstance(member.type, ( - typing.TypeInt32, typing.TypeFloat32, - typing.TypeInt64, typing.TypeFloat64, - )): - raise NotImplementedError + mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) + if mtyp is None: + # In the future might extend this by having structs or tuples + # as members of struct or tuples + raise NotImplementedError(expression, inp, member) yield wasm.Statement('local.get', '$___new_reference___addr') yield wasm.Statement('local.get', f'$arg{member.idx}') - yield wasm.Statement(f'{member.type.render()}.store', 'offset=' + str(member.offset)) + yield wasm.Statement(f'{mtyp}.store', 'offset=' + str(member.offset)) yield wasm.Statement('local.get', '$___new_reference___addr') @@ -371,17 +406,11 @@ def _generate_struct_constructor(inp: ourlang.StructConstructor) -> Statements: yield wasm.Statement('local.set', '$___new_reference___addr') for member in inp.struct.members: - if isinstance(member.type, typing.TypeUInt8): - mtyp = 'i32' - else: - # FIXME: Properly implement this - # inp.type.render() is also a hack that doesn't really work consistently - if not isinstance(member.type, ( - typing.TypeInt32, typing.TypeFloat32, - typing.TypeInt64, typing.TypeFloat64, - )): - raise NotImplementedError - mtyp = member.type.render() + mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) + if mtyp is None: + # In the future might extend this by having structs or tuples + # as members of struct or tuples + raise NotImplementedError(expression, inp, member) yield wasm.Statement('local.get', '$___new_reference___addr') yield wasm.Statement('local.get', f'${member.name}') diff --git a/phasm/exceptions.py b/phasm/exceptions.py new file mode 100644 index 0000000..b459c22 --- /dev/null +++ b/phasm/exceptions.py @@ -0,0 +1,8 @@ +""" +Exceptions for the phasm compiler +""" + +class StaticError(Exception): + """ + An error found during static analysis + """ diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 64bdedc..1903328 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -1,13 +1,12 @@ """ Contains the syntax tree for ourlang """ -from typing import Any, Dict, List, Optional, NoReturn, Union, Tuple - -import ast +from typing import Dict, List, Tuple from typing_extensions import Final WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) +WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) from .typing import ( TypeBase, @@ -32,14 +31,6 @@ class Expression: def __init__(self, type_: TypeBase) -> None: self.type = type_ - def render(self) -> str: - """ - Renders the expression back to source code format - - This'll look like Python code. - """ - raise NotImplementedError(self, 'render') - class Constant(Expression): """ An constant value expression within a statement @@ -58,9 +49,6 @@ class ConstantUInt8(Constant): super().__init__(type_) self.value = value - def render(self) -> str: - return str(self.value) - class ConstantInt32(Constant): """ An Int32 constant value expression within a statement @@ -73,9 +61,6 @@ class ConstantInt32(Constant): super().__init__(type_) self.value = value - def render(self) -> str: - return str(self.value) - class ConstantInt64(Constant): """ An Int64 constant value expression within a statement @@ -88,9 +73,6 @@ class ConstantInt64(Constant): super().__init__(type_) self.value = value - def render(self) -> str: - return str(self.value) - class ConstantFloat32(Constant): """ An Float32 constant value expression within a statement @@ -103,9 +85,6 @@ class ConstantFloat32(Constant): super().__init__(type_) self.value = value - def render(self) -> str: - return str(self.value) - class ConstantFloat64(Constant): """ An Float64 constant value expression within a statement @@ -118,9 +97,6 @@ class ConstantFloat64(Constant): super().__init__(type_) self.value = value - def render(self) -> str: - return str(self.value) - class VariableReference(Expression): """ An variable reference expression within a statement @@ -133,8 +109,20 @@ class VariableReference(Expression): super().__init__(type_) self.name = name - def render(self) -> str: - return str(self.name) +class UnaryOp(Expression): + """ + A unary operator expression within a statement + """ + __slots__ = ('operator', 'right', ) + + operator: str + right: Expression + + def __init__(self, type_: TypeBase, operator: str, right: Expression) -> None: + super().__init__(type_) + + self.operator = operator + self.right = right class BinaryOp(Expression): """ @@ -153,30 +141,6 @@ class BinaryOp(Expression): self.left = left self.right = right - def render(self) -> str: - return f'{self.left.render()} {self.operator} {self.right.render()}' - -class UnaryOp(Expression): - """ - A unary operator expression within a statement - """ - __slots__ = ('operator', 'right', ) - - operator: str - right: Expression - - def __init__(self, type_: TypeBase, operator: str, right: Expression) -> None: - super().__init__(type_) - - self.operator = operator - self.right = right - - def render(self) -> str: - if self.operator in WEBASSEMBLY_BUILDIN_FLOAT_OPS or self.operator == 'len': - return f'{self.operator}({self.right.render()})' - - return f'{self.operator}{self.right.render()}' - class FunctionCall(Expression): """ A function call expression within a statement @@ -192,20 +156,6 @@ class FunctionCall(Expression): self.function = function self.arguments = [] - def render(self) -> str: - args = ', '.join( - arg.render() - for arg in self.arguments - ) - - if isinstance(self.function, StructConstructor): - return f'{self.function.struct.name}({args})' - - if isinstance(self.function, TupleConstructor): - return f'({args}, )' - - return f'{self.function.name}({args})' - class AccessBytesIndex(Expression): """ Access a bytes index for reading @@ -221,9 +171,6 @@ class AccessBytesIndex(Expression): self.varref = varref self.index = index - def render(self) -> str: - return f'{self.varref.render()}[{self.index.render()}]' - class AccessStructMember(Expression): """ Access a struct member for reading of writing @@ -239,9 +186,6 @@ class AccessStructMember(Expression): self.varref = varref self.member = member - def render(self) -> str: - return f'{self.varref.render()}.{self.member.name}' - class AccessTupleMember(Expression): """ Access a tuple member for reading of writing @@ -257,22 +201,17 @@ class AccessTupleMember(Expression): self.varref = varref self.member = member - def render(self) -> str: - return f'{self.varref.render()}[{self.member.idx}]' - class Statement: """ A statement within a function """ __slots__ = () - def render(self) -> List[str]: - """ - Renders the type back to source code format - - This'll look like Python code. - """ - raise NotImplementedError(self, 'render') +class StatementPass(Statement): + """ + A pass statement + """ + __slots__ = () class StatementReturn(Statement): """ @@ -283,14 +222,6 @@ class StatementReturn(Statement): def __init__(self, value: Expression) -> None: self.value = value - def render(self) -> List[str]: - """ - Renders the type back to source code format - - This'll look like Python code. - """ - return [f'return {self.value.render()}'] - class StatementIf(Statement): """ An if statement within a function @@ -306,32 +237,7 @@ class StatementIf(Statement): self.statements = [] self.else_statements = [] - def render(self) -> List[str]: - """ - Renders the type back to source code format - - This'll look like Python code. - """ - result = [f'if {self.test.render()}:'] - - for stmt in self.statements: - result.extend( - f' {line}' if line else '' - for line in stmt.render() - ) - - result.append('') - - return result - -class StatementPass(Statement): - """ - A pass statement - """ - __slots__ = () - - def render(self) -> List[str]: - return ['pass'] +FunctionParam = Tuple[str, TypeBase] class Function: """ @@ -345,7 +251,7 @@ class Function: imported: bool statements: List[Statement] returns: TypeBase - posonlyargs: List[Tuple[str, TypeBase]] + posonlyargs: List[FunctionParam] def __init__(self, name: str, lineno: int) -> None: self.name = name @@ -356,35 +262,12 @@ class Function: self.returns = TypeNone() self.posonlyargs = [] - def render(self) -> str: - """ - Renders the function back to source code format - - This'll look like Python code. - """ - statements = self.statements - - result = '' - if self.exported: - result += '@exported\n' - if self.imported: - result += '@imported\n' - statements = [StatementPass()] - - args = ', '.join( - f'{x}: {y.render()}' - for x, y in self.posonlyargs - ) - - result += f'def {self.name}({args}) -> {self.returns.render()}:\n' - for stmt in statements: - for line in stmt.render(): - result += f' {line}\n' if line else '\n' - return result - class StructConstructor(Function): """ The constructor method for a struct + + A function will generated to instantiate a struct. The arguments + will be the defaults """ __slots__ = ('struct', ) @@ -442,552 +325,3 @@ class Module: } self.functions = {} self.structs = {} - - def render(self) -> str: - """ - Renders the module back to source code format - - This'll look like Python code. - """ - result = '' - - for struct in self.structs.values(): - if result: - result += '\n' - result += struct.render_definition() - - for function in self.functions.values(): - if function.lineno < 0: - # Buildin (-2) or auto generated (-1) - continue - - if result: - result += '\n' - result += function.render() - - return result - -class StaticError(Exception): - """ - An error found during static analysis - """ - -OurLocals = Dict[str, TypeBase] - -class OurVisitor: - """ - Class to visit a Python syntax tree and create an ourlang syntax tree - """ - - # pylint: disable=C0103,C0116,C0301,R0201,R0912 - - def __init__(self) -> None: - pass - - def visit_Module(self, node: ast.Module) -> Module: - module = Module() - - _not_implemented(not node.type_ignores, 'Module.type_ignores') - - # Second pass for the types - - for stmt in node.body: - res = self.pre_visit_Module_stmt(module, stmt) - - if isinstance(res, Function): - if res.name in module.functions: - raise StaticError( - f'{res.name} already defined on line {module.functions[res.name].lineno}' - ) - - module.functions[res.name] = res - - if isinstance(res, TypeStruct): - if res.name in module.structs: - raise StaticError( - f'{res.name} already defined on line {module.structs[res.name].lineno}' - ) - - module.structs[res.name] = res - constructor = StructConstructor(res) - module.functions[constructor.name] = constructor - - # Second pass for the function bodies - - for stmt in node.body: - self.visit_Module_stmt(module, stmt) - - return module - - def pre_visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, TypeStruct]: - if isinstance(node, ast.FunctionDef): - return self.pre_visit_Module_FunctionDef(module, node) - - if isinstance(node, ast.ClassDef): - return self.pre_visit_Module_ClassDef(module, node) - - raise NotImplementedError(f'{node} on Module') - - def pre_visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> Function: - function = Function(node.name, node.lineno) - - _not_implemented(not node.args.posonlyargs, 'FunctionDef.args.posonlyargs') - - for arg in node.args.args: - if not arg.annotation: - _raise_static_error(node, 'Type is required') - - function.posonlyargs.append(( - arg.arg, - self.visit_type(module, arg.annotation), - )) - - _not_implemented(not node.args.vararg, 'FunctionDef.args.vararg') - _not_implemented(not node.args.kwonlyargs, 'FunctionDef.args.kwonlyargs') - _not_implemented(not node.args.kw_defaults, 'FunctionDef.args.kw_defaults') - _not_implemented(not node.args.kwarg, 'FunctionDef.args.kwarg') - _not_implemented(not node.args.defaults, 'FunctionDef.args.defaults') - - # Do stmts at the end so we have the return value - - for decorator in node.decorator_list: - if not isinstance(decorator, ast.Name): - _raise_static_error(decorator, 'Function decorators must be string') - if not isinstance(decorator.ctx, ast.Load): - _raise_static_error(decorator, 'Must be load context') - _not_implemented(decorator.id in ('exported', 'imported'), 'Custom decorators') - - if decorator.id == 'exported': - function.exported = True - else: - function.imported = True - - if node.returns: - function.returns = self.visit_type(module, node.returns) - - _not_implemented(not node.type_comment, 'FunctionDef.type_comment') - - return function - - def pre_visit_Module_ClassDef(self, module: Module, node: ast.ClassDef) -> TypeStruct: - struct = TypeStruct(node.name, node.lineno) - - _not_implemented(not node.bases, 'ClassDef.bases') - _not_implemented(not node.keywords, 'ClassDef.keywords') - _not_implemented(not node.decorator_list, 'ClassDef.decorator_list') - - offset = 0 - - for stmt in node.body: - if not isinstance(stmt, ast.AnnAssign): - raise NotImplementedError(f'Class with {stmt} nodes') - - if not isinstance(stmt.target, ast.Name): - raise NotImplementedError('Class with default values') - - if not stmt.value is None: - raise NotImplementedError('Class with default values') - - if stmt.simple != 1: - raise NotImplementedError('Class with non-simple arguments') - - member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset) - - struct.members.append(member) - offset += member.type.alloc_size() - - return struct - - def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None: - if isinstance(node, ast.FunctionDef): - self.visit_Module_FunctionDef(module, node) - return - - if isinstance(node, ast.ClassDef): - return - - raise NotImplementedError(f'{node} on Module') - - def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: - function = module.functions[node.name] - - our_locals = dict(function.posonlyargs) - - for stmt in node.body: - function.statements.append( - self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt) - ) - - def visit_Module_FunctionDef_stmt(self, module: Module, function: Function, our_locals: OurLocals, node: ast.stmt) -> Statement: - if isinstance(node, ast.Return): - if node.value is None: - # TODO: Implement methods without return values - _raise_static_error(node, 'Return must have an argument') - - return StatementReturn( - self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value) - ) - - if isinstance(node, ast.If): - result = StatementIf( - self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test) - ) - - for stmt in node.body: - result.statements.append( - self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt) - ) - - for stmt in node.orelse: - result.else_statements.append( - self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt) - ) - - return result - - if isinstance(node, ast.Pass): - return StatementPass() - - raise NotImplementedError(f'{node} as stmt in FunctionDef') - - def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.expr) -> Expression: - if isinstance(node, ast.BinOp): - if isinstance(node.op, ast.Add): - operator = '+' - elif isinstance(node.op, ast.Sub): - operator = '-' - elif isinstance(node.op, ast.Mult): - operator = '*' - else: - raise NotImplementedError(f'Operator {node.op}') - - # Assume the type doesn't change when descending into a binary operator - # e.g. you can do `"hello" * 3` with the code below (yet) - - return BinaryOp( - exp_type, - operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), - ) - - if isinstance(node, ast.UnaryOp): - if isinstance(node.op, ast.UAdd): - operator = '+' - elif isinstance(node.op, ast.USub): - operator = '-' - else: - raise NotImplementedError(f'Operator {node.op}') - - return UnaryOp( - exp_type, - operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), - ) - - if isinstance(node, ast.Compare): - if 1 < len(node.ops): - raise NotImplementedError('Multiple operators') - - if isinstance(node.ops[0], ast.Gt): - operator = '>' - elif isinstance(node.ops[0], ast.Eq): - operator = '==' - elif isinstance(node.ops[0], ast.Lt): - operator = '<' - else: - raise NotImplementedError(f'Operator {node.ops}') - - # Assume the type doesn't change when descending into a binary operator - # e.g. you can do `"hello" * 3` with the code below (yet) - - return BinaryOp( - exp_type, - operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), - ) - - if isinstance(node, ast.Call): - return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node) - - if isinstance(node, ast.Constant): - return self.visit_Module_FunctionDef_Constant( - module, function, exp_type, node, - ) - - if isinstance(node, ast.Attribute): - return self.visit_Module_FunctionDef_Attribute( - module, function, our_locals, exp_type, node, - ) - - if isinstance(node, ast.Subscript): - return self.visit_Module_FunctionDef_Subscript( - module, function, our_locals, exp_type, node, - ) - - if isinstance(node, ast.Name): - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if node.id not in our_locals: - _raise_static_error(node, 'Undefined variable') - - act_type = our_locals[node.id] - if exp_type != act_type: - _raise_static_error(node, f'Expected {exp_type.render()}, {node.id} is actually {act_type.render()}') - - return VariableReference(act_type, node.id) - - if isinstance(node, ast.Tuple): - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if not isinstance(exp_type, TypeTuple): - _raise_static_error(node, f'Expression is expecting a {exp_type.render()}, not a tuple') - - if len(exp_type.members) != len(node.elts): - _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given') - - tuple_constructor = TupleConstructor(exp_type) - - func = module.functions[tuple_constructor.name] - - result = FunctionCall(func) - result.arguments = [ - self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node) - for arg_node, mem in zip(node.elts, exp_type.members) - ] - return result - - raise NotImplementedError(f'{node} as expr in FunctionDef') - - def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Call) -> Union[FunctionCall, UnaryOp]: - if node.keywords: - _raise_static_error(node, 'Keyword calling not supported') # Yet? - - if not isinstance(node.func, ast.Name): - raise NotImplementedError(f'Calling methods that are not a name {node.func}') - if not isinstance(node.func.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if node.func.id in module.structs: - struct = module.structs[node.func.id] - struct_constructor = StructConstructor(struct) - - func = module.functions[struct_constructor.name] - elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: - if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - - if 1 != len(node.args): - _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') - - return UnaryOp( - exp_type, - 'sqrt', - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), - ) - elif node.func.id == 'len': - if not isinstance(exp_type, TypeInt32): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - - if 1 != len(node.args): - _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') - - return UnaryOp( - exp_type, - 'len', - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), - ) - else: - if node.func.id not in module.functions: - _raise_static_error(node, 'Call to undefined function') - - func = module.functions[node.func.id] - - if func.returns != exp_type: - _raise_static_error(node, f'Expected {exp_type.render()}, {func.name} actually returns {func.returns.render()}') - - if len(func.posonlyargs) != len(node.args): - _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') - - result = FunctionCall(func) - result.arguments.extend( - self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr) - for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs) - ) - return result - - def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression: - if not isinstance(node.value, ast.Name): - _raise_static_error(node, 'Must reference a name') - - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if not node.value.id in our_locals: - _raise_static_error(node, f'Undefined variable {node.value.id}') - - node_typ = our_locals[node.value.id] - if not isinstance(node_typ, TypeStruct): - _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') - - member = node_typ.get_member(node.attr) - if member is None: - _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') - - if exp_type != member.type: - _raise_static_error(node, f'Expected {exp_type.render()}, {node.value.id}.{member.name} is actually {member.type.render()}') - - return AccessStructMember( - VariableReference(node_typ, node.value.id), - member, - ) - - def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression: - if not isinstance(node.value, ast.Name): - _raise_static_error(node, 'Must reference a name') - - if not isinstance(node.slice, ast.Index): - _raise_static_error(node, 'Must subscript using an index') - - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if not node.value.id in our_locals: - _raise_static_error(node, f'Undefined variable {node.value.id}') - - node_typ = our_locals[node.value.id] - - slice_expr = self.visit_Module_FunctionDef_expr( - module, function, our_locals, module.types['i32'], node.slice.value, - ) - - if isinstance(node_typ, TypeBytes): - return AccessBytesIndex( - module.types['u8'], - VariableReference(node_typ, node.value.id), - slice_expr, - ) - - if isinstance(node_typ, TypeTuple): - if not isinstance(slice_expr, ConstantInt32): - _raise_static_error(node, 'Must subscript using a constant index') - - idx = slice_expr.value - - if len(node_typ.members) <= idx: - _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') - - member = node_typ.members[idx] - if exp_type != member.type: - _raise_static_error(node, f'Expected {exp_type.render()}, {node.value.id}[{idx}] is actually {member.type.render()}') - - return AccessTupleMember( - VariableReference(node_typ, node.value.id), - member, - ) - - _raise_static_error(node, f'Cannot take index of {node_typ.render()} {node.value.id}') - - def visit_Module_FunctionDef_Constant(self, module: Module, function: Function, exp_type: TypeBase, node: ast.Constant) -> Expression: - del module - del function - - _not_implemented(node.kind is None, 'Constant.kind') - - if isinstance(exp_type, TypeUInt8): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - # FIXME: Range check - - return ConstantUInt8(exp_type, node.value) - - if isinstance(exp_type, TypeInt32): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - # FIXME: Range check - - return ConstantInt32(exp_type, node.value) - - if isinstance(exp_type, TypeInt64): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - # FIXME: Range check - - return ConstantInt64(exp_type, node.value) - - if isinstance(exp_type, TypeFloat32): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat32(exp_type, node.value) - - if isinstance(exp_type, TypeFloat64): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat64(exp_type, node.value) - - raise NotImplementedError(f'{node} as const for type {exp_type.render()}') - - def visit_type(self, module: Module, node: ast.expr) -> TypeBase: - if isinstance(node, ast.Constant): - if node.value is None: - return module.types['None'] - - _raise_static_error(node, f'Unrecognized type {node.value}') - - if isinstance(node, ast.Name): - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if node.id in module.types: - return module.types[node.id] - - if node.id in module.structs: - return module.structs[node.id] - - _raise_static_error(node, f'Unrecognized type {node.id}') - - if isinstance(node, ast.Tuple): - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - result = TypeTuple() - - offset = 0 - - for idx, elt in enumerate(node.elts): - member = TypeTupleMember(idx, self.visit_type(module, elt), offset) - - result.members.append(member) - offset += member.type.alloc_size() - - key = result.render_internal_name() - - if key not in module.types: - module.types[key] = result - constructor = TupleConstructor(result) - module.functions[constructor.name] = constructor - - return module.types[key] - - raise NotImplementedError(f'{node} as type') - -def _not_implemented(check: Any, msg: str) -> None: - if not check: - raise NotImplementedError(msg) - -def _raise_static_error(node: Union[ast.mod, ast.stmt, ast.expr], msg: str) -> NoReturn: - raise StaticError( - f'Static error on line {node.lineno}: {msg}' - ) diff --git a/phasm/parser.py b/phasm/parser.py new file mode 100644 index 0000000..dda06e6 --- /dev/null +++ b/phasm/parser.py @@ -0,0 +1,577 @@ +""" +Parses the source code from the plain text into a syntax tree +""" +from typing import Any, Dict, NoReturn, Union + +import ast + +from .typing import ( + TypeBase, + TypeUInt8, + TypeInt32, + TypeInt64, + TypeFloat32, + TypeFloat64, + TypeBytes, + TypeStruct, + TypeStructMember, + TypeTuple, + TypeTupleMember, +) + +from . import codestyle +from .exceptions import StaticError +from .ourlang import ( + WEBASSEMBLY_BUILDIN_FLOAT_OPS, + + Module, + Function, + + Expression, + AccessBytesIndex, AccessStructMember, AccessTupleMember, + BinaryOp, + ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64, ConstantUInt8, + FunctionCall, + StructConstructor, TupleConstructor, + UnaryOp, VariableReference, + + Statement, + StatementIf, StatementPass, StatementReturn, +) + +def phasm_parse(source: str) -> Module: + """ + Public method for parsing Phasm code into a Phasm Module + """ + res = ast.parse(source, '') + + our_visitor = OurVisitor() + return our_visitor.visit_Module(res) + +OurLocals = Dict[str, TypeBase] + +class OurVisitor: + """ + Class to visit a Python syntax tree and create an ourlang syntax tree + + We're (ab)using the Python AST parser to give us a leg up + + At some point, we may deviate from Python syntax. If nothing else, + we probably won't keep up with the Python syntax changes. + """ + + # pylint: disable=C0103,C0116,C0301,R0201,R0912 + + def __init__(self) -> None: + pass + + def visit_Module(self, node: ast.Module) -> Module: + module = Module() + + _not_implemented(not node.type_ignores, 'Module.type_ignores') + + # Second pass for the types + + for stmt in node.body: + res = self.pre_visit_Module_stmt(module, stmt) + + if isinstance(res, Function): + if res.name in module.functions: + raise StaticError( + f'{res.name} already defined on line {module.functions[res.name].lineno}' + ) + + module.functions[res.name] = res + + if isinstance(res, TypeStruct): + if res.name in module.structs: + raise StaticError( + f'{res.name} already defined on line {module.structs[res.name].lineno}' + ) + + module.structs[res.name] = res + constructor = StructConstructor(res) + module.functions[constructor.name] = constructor + + # Second pass for the function bodies + + for stmt in node.body: + self.visit_Module_stmt(module, stmt) + + return module + + def pre_visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, TypeStruct]: + if isinstance(node, ast.FunctionDef): + return self.pre_visit_Module_FunctionDef(module, node) + + if isinstance(node, ast.ClassDef): + return self.pre_visit_Module_ClassDef(module, node) + + raise NotImplementedError(f'{node} on Module') + + def pre_visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> Function: + function = Function(node.name, node.lineno) + + _not_implemented(not node.args.posonlyargs, 'FunctionDef.args.posonlyargs') + + for arg in node.args.args: + if not arg.annotation: + _raise_static_error(node, 'Type is required') + + function.posonlyargs.append(( + arg.arg, + self.visit_type(module, arg.annotation), + )) + + _not_implemented(not node.args.vararg, 'FunctionDef.args.vararg') + _not_implemented(not node.args.kwonlyargs, 'FunctionDef.args.kwonlyargs') + _not_implemented(not node.args.kw_defaults, 'FunctionDef.args.kw_defaults') + _not_implemented(not node.args.kwarg, 'FunctionDef.args.kwarg') + _not_implemented(not node.args.defaults, 'FunctionDef.args.defaults') + + # Do stmts at the end so we have the return value + + for decorator in node.decorator_list: + if not isinstance(decorator, ast.Name): + _raise_static_error(decorator, 'Function decorators must be string') + if not isinstance(decorator.ctx, ast.Load): + _raise_static_error(decorator, 'Must be load context') + _not_implemented(decorator.id in ('exported', 'imported'), 'Custom decorators') + + if decorator.id == 'exported': + function.exported = True + else: + function.imported = True + + if node.returns: + function.returns = self.visit_type(module, node.returns) + + _not_implemented(not node.type_comment, 'FunctionDef.type_comment') + + return function + + def pre_visit_Module_ClassDef(self, module: Module, node: ast.ClassDef) -> TypeStruct: + struct = TypeStruct(node.name, node.lineno) + + _not_implemented(not node.bases, 'ClassDef.bases') + _not_implemented(not node.keywords, 'ClassDef.keywords') + _not_implemented(not node.decorator_list, 'ClassDef.decorator_list') + + offset = 0 + + for stmt in node.body: + if not isinstance(stmt, ast.AnnAssign): + raise NotImplementedError(f'Class with {stmt} nodes') + + if not isinstance(stmt.target, ast.Name): + raise NotImplementedError('Class with default values') + + if not stmt.value is None: + raise NotImplementedError('Class with default values') + + if stmt.simple != 1: + raise NotImplementedError('Class with non-simple arguments') + + member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset) + + struct.members.append(member) + offset += member.type.alloc_size() + + return struct + + def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None: + if isinstance(node, ast.FunctionDef): + self.visit_Module_FunctionDef(module, node) + return + + if isinstance(node, ast.ClassDef): + return + + raise NotImplementedError(f'{node} on Module') + + def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: + function = module.functions[node.name] + + our_locals = dict(function.posonlyargs) + + for stmt in node.body: + function.statements.append( + self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt) + ) + + def visit_Module_FunctionDef_stmt(self, module: Module, function: Function, our_locals: OurLocals, node: ast.stmt) -> Statement: + if isinstance(node, ast.Return): + if node.value is None: + # TODO: Implement methods without return values + _raise_static_error(node, 'Return must have an argument') + + return StatementReturn( + self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value) + ) + + if isinstance(node, ast.If): + result = StatementIf( + self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test) + ) + + for stmt in node.body: + result.statements.append( + self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt) + ) + + for stmt in node.orelse: + result.else_statements.append( + self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt) + ) + + return result + + if isinstance(node, ast.Pass): + return StatementPass() + + raise NotImplementedError(f'{node} as stmt in FunctionDef') + + def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.expr) -> Expression: + if isinstance(node, ast.BinOp): + if isinstance(node.op, ast.Add): + operator = '+' + elif isinstance(node.op, ast.Sub): + operator = '-' + elif isinstance(node.op, ast.Mult): + operator = '*' + else: + raise NotImplementedError(f'Operator {node.op}') + + # Assume the type doesn't change when descending into a binary operator + # e.g. you can do `"hello" * 3` with the code below (yet) + + return BinaryOp( + exp_type, + operator, + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), + ) + + if isinstance(node, ast.UnaryOp): + if isinstance(node.op, ast.UAdd): + operator = '+' + elif isinstance(node.op, ast.USub): + operator = '-' + else: + raise NotImplementedError(f'Operator {node.op}') + + return UnaryOp( + exp_type, + operator, + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), + ) + + if isinstance(node, ast.Compare): + if 1 < len(node.ops): + raise NotImplementedError('Multiple operators') + + if isinstance(node.ops[0], ast.Gt): + operator = '>' + elif isinstance(node.ops[0], ast.Eq): + operator = '==' + elif isinstance(node.ops[0], ast.Lt): + operator = '<' + else: + raise NotImplementedError(f'Operator {node.ops}') + + # Assume the type doesn't change when descending into a binary operator + # e.g. you can do `"hello" * 3` with the code below (yet) + + return BinaryOp( + exp_type, + operator, + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), + ) + + if isinstance(node, ast.Call): + return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node) + + if isinstance(node, ast.Constant): + return self.visit_Module_FunctionDef_Constant( + module, function, exp_type, node, + ) + + if isinstance(node, ast.Attribute): + return self.visit_Module_FunctionDef_Attribute( + module, function, our_locals, exp_type, node, + ) + + if isinstance(node, ast.Subscript): + return self.visit_Module_FunctionDef_Subscript( + module, function, our_locals, exp_type, node, + ) + + if isinstance(node, ast.Name): + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if node.id not in our_locals: + _raise_static_error(node, 'Undefined variable') + + act_type = our_locals[node.id] + if exp_type != act_type: + _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') + + return VariableReference(act_type, node.id) + + if isinstance(node, ast.Tuple): + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if not isinstance(exp_type, TypeTuple): + _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple') + + if len(exp_type.members) != len(node.elts): + _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given') + + tuple_constructor = TupleConstructor(exp_type) + + func = module.functions[tuple_constructor.name] + + result = FunctionCall(func) + result.arguments = [ + self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node) + for arg_node, mem in zip(node.elts, exp_type.members) + ] + return result + + raise NotImplementedError(f'{node} as expr in FunctionDef') + + def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Call) -> Union[FunctionCall, UnaryOp]: + if node.keywords: + _raise_static_error(node, 'Keyword calling not supported') # Yet? + + if not isinstance(node.func, ast.Name): + raise NotImplementedError(f'Calling methods that are not a name {node.func}') + if not isinstance(node.func.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if node.func.id in module.structs: + struct = module.structs[node.func.id] + struct_constructor = StructConstructor(struct) + + func = module.functions[struct_constructor.name] + elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )): + _raise_static_error(node, f'Cannot make {node.func.id} result in {codestyle.type_(exp_type)}') + + if 1 != len(node.args): + _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') + + return UnaryOp( + exp_type, + 'sqrt', + self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), + ) + elif node.func.id == 'len': + if not isinstance(exp_type, TypeInt32): + _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') + + if 1 != len(node.args): + _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') + + return UnaryOp( + exp_type, + 'len', + self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), + ) + else: + if node.func.id not in module.functions: + _raise_static_error(node, 'Call to undefined function') + + func = module.functions[node.func.id] + + if func.returns != exp_type: + _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') + + if len(func.posonlyargs) != len(node.args): + _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') + + result = FunctionCall(func) + result.arguments.extend( + self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr) + for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs) + ) + return result + + def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression: + del module + del function + + if not isinstance(node.value, ast.Name): + _raise_static_error(node, 'Must reference a name') + + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if not node.value.id in our_locals: + _raise_static_error(node, f'Undefined variable {node.value.id}') + + node_typ = our_locals[node.value.id] + if not isinstance(node_typ, TypeStruct): + _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') + + member = node_typ.get_member(node.attr) + if member is None: + _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') + + if exp_type != member.type: + _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}') + + return AccessStructMember( + VariableReference(node_typ, node.value.id), + member, + ) + + def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression: + if not isinstance(node.value, ast.Name): + _raise_static_error(node, 'Must reference a name') + + if not isinstance(node.slice, ast.Index): + _raise_static_error(node, 'Must subscript using an index') + + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if not node.value.id in our_locals: + _raise_static_error(node, f'Undefined variable {node.value.id}') + + node_typ = our_locals[node.value.id] + + slice_expr = self.visit_Module_FunctionDef_expr( + module, function, our_locals, module.types['i32'], node.slice.value, + ) + + if isinstance(node_typ, TypeBytes): + return AccessBytesIndex( + module.types['u8'], + VariableReference(node_typ, node.value.id), + slice_expr, + ) + + if isinstance(node_typ, TypeTuple): + if not isinstance(slice_expr, ConstantInt32): + _raise_static_error(node, 'Must subscript using a constant index') + + idx = slice_expr.value + + if len(node_typ.members) <= idx: + _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') + + member = node_typ.members[idx] + if exp_type != member.type: + _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(member.type)}') + + return AccessTupleMember( + VariableReference(node_typ, node.value.id), + member, + ) + + _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') + + def visit_Module_FunctionDef_Constant(self, module: Module, function: Function, exp_type: TypeBase, node: ast.Constant) -> Expression: + del module + del function + + _not_implemented(node.kind is None, 'Constant.kind') + + if isinstance(exp_type, TypeUInt8): + if not isinstance(node.value, int): + _raise_static_error(node, 'Expected integer value') + + # FIXME: Range check + + return ConstantUInt8(exp_type, node.value) + + if isinstance(exp_type, TypeInt32): + if not isinstance(node.value, int): + _raise_static_error(node, 'Expected integer value') + + # FIXME: Range check + + return ConstantInt32(exp_type, node.value) + + if isinstance(exp_type, TypeInt64): + if not isinstance(node.value, int): + _raise_static_error(node, 'Expected integer value') + + # FIXME: Range check + + return ConstantInt64(exp_type, node.value) + + if isinstance(exp_type, TypeFloat32): + if not isinstance(node.value, (float, int, )): + _raise_static_error(node, 'Expected float value') + + # FIXME: Range check + + return ConstantFloat32(exp_type, node.value) + + if isinstance(exp_type, TypeFloat64): + if not isinstance(node.value, (float, int, )): + _raise_static_error(node, 'Expected float value') + + # FIXME: Range check + + return ConstantFloat64(exp_type, node.value) + + raise NotImplementedError(f'{node} as const for type {exp_type}') + + def visit_type(self, module: Module, node: ast.expr) -> TypeBase: + if isinstance(node, ast.Constant): + if node.value is None: + return module.types['None'] + + _raise_static_error(node, f'Unrecognized type {node.value}') + + if isinstance(node, ast.Name): + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + if node.id in module.types: + return module.types[node.id] + + if node.id in module.structs: + return module.structs[node.id] + + _raise_static_error(node, f'Unrecognized type {node.id}') + + if isinstance(node, ast.Tuple): + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + result = TypeTuple() + + offset = 0 + + for idx, elt in enumerate(node.elts): + member = TypeTupleMember(idx, self.visit_type(module, elt), offset) + + result.members.append(member) + offset += member.type.alloc_size() + + key = result.render_internal_name() + + if key not in module.types: + module.types[key] = result + constructor = TupleConstructor(result) + module.functions[constructor.name] = constructor + + return module.types[key] + + raise NotImplementedError(f'{node} as type') + +def _not_implemented(check: Any, msg: str) -> None: + if not check: + raise NotImplementedError(msg) + +def _raise_static_error(node: Union[ast.mod, ast.stmt, ast.expr], msg: str) -> NoReturn: + raise StaticError( + f'Static error on line {node.lineno}: {msg}' + ) diff --git a/phasm/typing.py b/phasm/typing.py index 4252fdb..bca76b2 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -9,14 +9,6 @@ class TypeBase: """ __slots__ = () - def render(self) -> str: - """ - Renders the type back to source code format - - This'll look like Python code. - """ - raise NotImplementedError(self, 'render') - def alloc_size(self) -> int: """ When allocating this type in memory, how many bytes do we need to reserve? @@ -29,27 +21,18 @@ class TypeNone(TypeBase): """ __slots__ = () - def render(self) -> str: - return 'None' - class TypeBool(TypeBase): """ The boolean type """ __slots__ = () - def render(self) -> str: - return 'bool' - class TypeUInt8(TypeBase): """ The Integer type, unsigned and 8 bits wide """ __slots__ = () - def render(self) -> str: - return 'u8' - def alloc_size(self) -> int: return 4 # Int32 under the hood @@ -59,9 +42,6 @@ class TypeInt32(TypeBase): """ __slots__ = () - def render(self) -> str: - return 'i32' - def alloc_size(self) -> int: return 4 @@ -71,9 +51,6 @@ class TypeInt64(TypeBase): """ __slots__ = () - def render(self) -> str: - return 'i64' - def alloc_size(self) -> int: return 8 @@ -83,9 +60,6 @@ class TypeFloat32(TypeBase): """ __slots__ = () - def render(self) -> str: - return 'f32' - def alloc_size(self) -> int: return 4 @@ -95,9 +69,6 @@ class TypeFloat64(TypeBase): """ __slots__ = () - def render(self) -> str: - return 'f64' - def alloc_size(self) -> int: return 8 @@ -107,9 +78,6 @@ class TypeBytes(TypeBase): """ __slots__ = () - def render(self) -> str: - return 'bytes' - class TypeTupleMember: """ Represents a tuple member @@ -130,12 +98,11 @@ class TypeTuple(TypeBase): def __init__(self) -> None: self.members = [] - def render(self) -> str: - mems = ', '.join(x.type.render() for x in self.members) - return f'({mems}, )' - def render_internal_name(self) -> str: - mems = '@'.join(x.type.render() for x in self.members) + """ + Generates an internal name for this tuple + """ + mems = '@'.join('?' for x in self.members) assert ' ' not in mems, 'Not implement yet: subtuples' return f'tuple@{mems}' @@ -179,26 +146,6 @@ class TypeStruct(TypeBase): return None - def render(self) -> str: - """ - Renders the type back to source code format - - This'll look like Python code. - """ - return self.name - - def render_definition(self) -> str: - """ - Renders the definition back to source code format - - This'll look like Python code. - """ - result = f'class {self.name}:\n' - for mem in self.members: - result += f' {mem.name}: {mem.type.render()}\n' - - return result - def alloc_size(self) -> int: return sum( x.type.alloc_size() diff --git a/phasm/utils.py b/phasm/utils.py deleted file mode 100644 index cc9e0a4..0000000 --- a/phasm/utils.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Utility functions -""" - -import ast - -from .ourlang import OurVisitor, Module - -def our_process(source: str, input_name: str) -> Module: - """ - Processes the python code into web assembly code - """ - res = ast.parse(source, input_name) - - our_visitor = OurVisitor() - return our_visitor.visit_Module(res) diff --git a/phasm/wasm.py b/phasm/wasm.py index 3e5cff4..5fa2506 100644 --- a/phasm/wasm.py +++ b/phasm/wasm.py @@ -168,9 +168,6 @@ class ModuleMemory(WatSerializable): self.data = data def to_wat(self) -> str: - """ - Renders this memory as WebAssembly Text - """ data = ''.join( f'\\{x:02x}' for x in self.data diff --git a/pylintrc b/pylintrc index eb6abcb..bfa8c51 100644 --- a/pylintrc +++ b/pylintrc @@ -1,2 +1,2 @@ [MASTER] -disable=C0122,R0903,R0913 +disable=C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index d19c697..cfad8ed 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -14,8 +14,9 @@ import wasmer_compiler_cranelift import wasmtime -from phasm.utils import our_process -from phasm.compiler import module +from phasm.codestyle import phasm_render +from phasm.compiler import phasm_compile +from phasm.parser import phasm_parse DASHES = '-' * 16 @@ -62,9 +63,8 @@ class Suite: """ WebAssembly test suite """ - def __init__(self, code_py, test_name): + def __init__(self, code_py): self.code_py = code_py - self.test_name = test_name def run_code(self, *args, runtime='pywasm3', imports=None): """ @@ -73,15 +73,15 @@ class Suite: Returned is an object with the results set """ - our_module = our_process(self.code_py, self.test_name) + phasm_module = phasm_parse(self.code_py) # Check if code formatting works - assert self.code_py == '\n' + our_module.render() # \n for formatting in tests + assert self.code_py == '\n' + phasm_render(phasm_module) # \n for formatting in tests # Compile - wasm_module = module(our_module) + wasm_module = phasm_compile(phasm_module) - # Render as text + # Render as WebAssembly text code_wat = wasm_module.to_wat() sys.stderr.write(f'{DASHES} Assembly {DASHES}\n') diff --git a/tests/integration/test_fib.py b/tests/integration/test_fib.py index b27297d..20e7e63 100644 --- a/tests/integration/test_fib.py +++ b/tests/integration/test_fib.py @@ -25,6 +25,6 @@ def testEntry() -> i32: return fib(40) """ - result = Suite(code_py, 'test_fib').run_code() + result = Suite(code_py).run_code() assert 102334155 == result.returned_value diff --git a/tests/integration/test_runtime_checks.py b/tests/integration/test_runtime_checks.py index abe7081..6a70032 100644 --- a/tests/integration/test_runtime_checks.py +++ b/tests/integration/test_runtime_checks.py @@ -10,6 +10,6 @@ def testEntry(f: bytes) -> u8: return f[50] """ - result = Suite(code_py, 'test_call').run_code(b'Short', b'Long' * 100) + result = Suite(code_py).run_code(b'Short', b'Long' * 100) assert 0 == result.returned_value diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index 9bb3139..51586af 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -19,7 +19,7 @@ def testEntry() -> {type_}: return 13 """ - result = Suite(code_py, 'test_return').run_code() + result = Suite(code_py).run_code() assert 13 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @@ -33,7 +33,7 @@ def testEntry() -> {type_}: return 10 + 3 """ - result = Suite(code_py, 'test_addition').run_code() + result = Suite(code_py).run_code() assert 13 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @@ -47,7 +47,7 @@ def testEntry() -> {type_}: return 10 - 3 """ - result = Suite(code_py, 'test_addition').run_code() + result = Suite(code_py).run_code() assert 7 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @@ -61,7 +61,7 @@ def testEntry() -> {type_}: return sqrt(25) """ - result = Suite(code_py, 'test_addition').run_code() + result = Suite(code_py).run_code() assert 5 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @@ -75,7 +75,7 @@ def testEntry(a: {type_}) -> {type_}: return a """ - result = Suite(code_py, 'test_return').run_code(125) + result = Suite(code_py).run_code(125) assert 125 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @@ -89,7 +89,7 @@ def testEntry(a: i32) -> i64: return a """ - result = Suite(code_py, 'test_return').run_code(125) + result = Suite(code_py).run_code(125) assert 125 == result.returned_value assert [] == result.log_int32_list @@ -103,7 +103,7 @@ def testEntry(a: i32, b: i64) -> i64: return a + b """ - result = Suite(code_py, 'test_return').run_code(125, 100) + result = Suite(code_py).run_code(125, 100) assert 225 == result.returned_value assert [] == result.log_int32_list @@ -117,7 +117,7 @@ def testEntry(a: f32) -> f64: return a """ - result = Suite(code_py, 'test_return').run_code(125.5) + result = Suite(code_py).run_code(125.5) assert 125.5 == result.returned_value assert [] == result.log_int32_list @@ -131,7 +131,7 @@ def testEntry(a: f32, b: f64) -> f64: return a + b """ - result = Suite(code_py, 'test_return').run_code(125.5, 100.25) + result = Suite(code_py).run_code(125.5, 100.25) assert 225.75 == result.returned_value assert [] == result.log_int32_list @@ -145,7 +145,7 @@ def testEntry() -> i32: return +523 """ - result = Suite(code_py, 'test_addition').run_code() + result = Suite(code_py).run_code() assert 523 == result.returned_value assert [] == result.log_int32_list @@ -159,7 +159,7 @@ def testEntry() -> i32: return -19 """ - result = Suite(code_py, 'test_addition').run_code() + result = Suite(code_py).run_code() assert -19 == result.returned_value assert [] == result.log_int32_list @@ -177,7 +177,7 @@ def testEntry(a: i32) -> i32: """ exp_result = 15 if inp > 10 else 3 - suite = Suite(code_py, 'test_return') + suite = Suite(code_py) result = suite.run_code(inp) assert exp_result == result.returned_value @@ -198,7 +198,7 @@ def testEntry(a: i32) -> i32: return -1 # Required due to function type """ - suite = Suite(code_py, 'test_return') + suite = Suite(code_py) assert 10 == suite.run_code(20).returned_value assert 10 == suite.run_code(10).returned_value @@ -208,6 +208,30 @@ def testEntry(a: i32) -> i32: assert 0 == suite.run_code(0).returned_value assert 0 == suite.run_code(-1).returned_value +@pytest.mark.integration_test +def test_if_nested(): + code_py = """ +@exported +def testEntry(a: i32, b: i32) -> i32: + if a > 11: + if b > 11: + return 3 + + return 2 + + if b > 11: + return 1 + + return 0 +""" + + suite = Suite(code_py) + + assert 3 == suite.run_code(20, 20).returned_value + assert 2 == suite.run_code(20, 10).returned_value + assert 1 == suite.run_code(10, 20).returned_value + assert 0 == suite.run_code(10, 10).returned_value + @pytest.mark.integration_test def test_call_pre_defined(): code_py = """ @@ -219,7 +243,7 @@ def testEntry() -> i32: return helper(10, 3) """ - result = Suite(code_py, 'test_call').run_code() + result = Suite(code_py).run_code() assert 13 == result.returned_value assert [] == result.log_int32_list @@ -235,7 +259,7 @@ def helper(left: i32, right: i32) -> i32: return left - right """ - result = Suite(code_py, 'test_call').run_code() + result = Suite(code_py).run_code() assert 7 == result.returned_value assert [] == result.log_int32_list @@ -252,7 +276,7 @@ def helper(left: {type_}, right: {type_}) -> {type_}: return left - right """ - result = Suite(code_py, 'test_call').run_code() + result = Suite(code_py).run_code() assert 22 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @@ -268,7 +292,7 @@ def testEntry() -> i32: return a """ - result = Suite(code_py, 'test_call').run_code() + result = Suite(code_py).run_code() assert 8947 == result.returned_value assert [] == result.log_int32_list @@ -288,7 +312,7 @@ def helper(cv: CheckedValue) -> {type_}: return cv.value """ - result = Suite(code_py, 'test_call').run_code() + result = Suite(code_py).run_code() assert 23 == result.returned_value assert [] == result.log_int32_list @@ -309,7 +333,7 @@ def helper(shape: Rectangle) -> i32: return shape.height + shape.width + shape.border """ - result = Suite(code_py, 'test_call').run_code() + result = Suite(code_py).run_code() assert 252 == result.returned_value assert [] == result.log_int32_list @@ -330,7 +354,7 @@ def helper(shape1: Rectangle, shape2: Rectangle) -> i32: return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border """ - result = Suite(code_py, 'test_call').run_code() + result = Suite(code_py).run_code() assert 545 == result.returned_value assert [] == result.log_int32_list @@ -347,7 +371,7 @@ def helper(vector: ({type_}, {type_}, {type_}, )) -> {type_}: return vector[0] + vector[1] + vector[2] """ - result = Suite(code_py, 'test_call').run_code() + result = Suite(code_py).run_code() assert 161 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @@ -363,7 +387,7 @@ def helper(v: (f32, f32, f32, )) -> f32: return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]) """ - result = Suite(code_py, 'test_call').run_code() + result = Suite(code_py).run_code() assert 3.74 < result.returned_value < 3.75 assert [] == result.log_int32_list @@ -376,7 +400,7 @@ def testEntry(f: bytes) -> bytes: return f """ - result = Suite(code_py, 'test_call').run_code(b'This is a test') + result = Suite(code_py).run_code(b'This is a test') assert 4 == result.returned_value @@ -388,7 +412,7 @@ def testEntry(f: bytes) -> i32: return len(f) """ - result = Suite(code_py, 'test_call').run_code(b'This is another test') + result = Suite(code_py).run_code(b'This is another test') assert 20 == result.returned_value @@ -400,7 +424,7 @@ def testEntry(f: bytes) -> u8: return f[8] """ - result = Suite(code_py, 'test_call').run_code(b'This is another test') + result = Suite(code_py).run_code(b'This is another test') assert 0x61 == result.returned_value @@ -413,7 +437,7 @@ def testEntry() -> i32x4: return (51, 153, 204, 0, ) """ - result = Suite(code_py, 'test_rgb2hsl').run_code() + result = Suite(code_py).run_code() assert (1, 2, 3, 0) == result.returned_value @@ -432,7 +456,7 @@ def testEntry() -> i32: def helper(mul: int) -> int: return 4238 * mul - result = Suite(code_py, 'test_imported').run_code( + result = Suite(code_py).run_code( runtime='wasmer', imports={ 'helper': helper, diff --git a/tests/integration/test_static_checking.py b/tests/integration/test_static_checking.py index ad43a7c..69d0724 100644 --- a/tests/integration/test_static_checking.py +++ b/tests/integration/test_static_checking.py @@ -1,7 +1,7 @@ import pytest -from phasm.utils import our_process -from phasm.ourlang import StaticError +from phasm.parser import phasm_parse +from phasm.exceptions import StaticError @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) @@ -12,7 +12,7 @@ def helper(a: {type_}) -> (i32, i32, ): """ with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), a is actually {type_}'): - our_process(code_py, 'test') + phasm_parse(code_py) @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) @@ -26,7 +26,7 @@ def testEntry(arg: Struct) -> (i32, i32, ): """ with pytest.raises(StaticError, match=f'Static error on line 6: Expected \\(i32, i32, \\), arg.param is actually {type_}'): - our_process(code_py, 'test') + phasm_parse(code_py) @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) @@ -37,7 +37,7 @@ def testEntry(arg: ({type_}, )) -> (i32, i32, ): """ with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), arg\\[0\\] is actually {type_}'): - our_process(code_py, 'test') + phasm_parse(code_py) @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) @@ -52,4 +52,4 @@ def testEntry() -> (i32, i32, ): """ with pytest.raises(StaticError, match=f'Static error on line 7: Expected \\(i32, i32, \\), helper actually returns {type_}'): - our_process(code_py, 'test') + phasm_parse(code_py)