"""
Contains the syntax tree for ourlang
"""
from typing import Any, Dict, List, Optional, NoReturn, Union, Tuple

import ast

try:
    from typing import Final
except ImportError:  # pragma: no cover - typing.Final exists from Python 3.8 on
    from typing_extensions import Final

# Names of float operations that ourlang exposes as single-argument builtin
# functions, e.g. `sqrt(x)`. UnaryOp renders these in call syntax.
WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', )


class OurType:
    """
    Type base class
    """
    __slots__ = ()

    def render(self) -> str:
        """
        Renders the type back to source code format

        This'll look like Python code.
        """
        raise NotImplementedError(self, 'render')

    def alloc_size(self) -> int:
        """
        When allocating this type in memory, how many bytes do we need to reserve?
        """
        raise NotImplementedError(self, 'alloc_size')


class OurTypeNone(OurType):
    """
    The None (or Void) type
    """
    __slots__ = ()

    def render(self) -> str:
        return 'None'


class OurTypeInt32(OurType):
    """
    The Integer type, signed and 32 bits wide
    """
    __slots__ = ()

    def render(self) -> str:
        return 'i32'

    def alloc_size(self) -> int:
        return 4


class OurTypeInt64(OurType):
    """
    The Integer type, signed and 64 bits wide
    """
    __slots__ = ()

    def render(self) -> str:
        return 'i64'

    def alloc_size(self) -> int:
        return 8


class OurTypeFloat32(OurType):
    """
    The Float type, 32 bits wide
    """
    __slots__ = ()

    def render(self) -> str:
        return 'f32'

    def alloc_size(self) -> int:
        return 4


class OurTypeFloat64(OurType):
    """
    The Float type, 64 bits wide
    """
    __slots__ = ()

    def render(self) -> str:
        return 'f64'

    def alloc_size(self) -> int:
        return 8


class TupleMember:
    """
    Represents a tuple member

    idx is the member's position within the tuple; offset is the byte
    offset of the member from the start of the tuple's allocation.
    """
    def __init__(self, idx: int, type_: OurType, offset: int) -> None:
        self.idx = idx
        self.type = type_
        self.offset = offset


class OurTypeTuple(OurType):
    """
    The tuple type: a fixed number of members, each with its own type
    """
    __slots__ = ('members', )

    members: List[TupleMember]

    def __init__(self) -> None:
        self.members = []

    def render(self) -> str:
        mems = ', '.join(x.type.render() for x in self.members)
        return f'({mems}, )'

    def render_internal_name(self) -> str:
        """
        Renders a name usable as a lookup key, e.g. 'tuple@i32@f64'

        Nested tuples are not supported yet; their rendering contains
        spaces, which the assert below detects.
        """
        mems = '@'.join(x.type.render() for x in self.members)
        assert ' ' not in mems, 'Not implement yet: subtuples'
        return f'tuple@{mems}'

    def alloc_size(self) -> int:
        return sum(
            x.type.alloc_size()
            for x in self.members
        )


class Expression:
    """
    An expression within a statement
    """
    __slots__ = ('type', )

    type: OurType

    def __init__(self, type_: OurType) -> None:
        self.type = type_

    def render(self) -> str:
        """
        Renders the expression back to source code format

        This'll look like Python code.
        """
        raise NotImplementedError(self, 'render')


class Constant(Expression):
    """
    An constant value expression within a statement
    """
    __slots__ = ()


class ConstantInt32(Constant):
    """
    An Int32 constant value expression within a statement
    """
    __slots__ = ('value', )

    value: int

    def __init__(self, type_: OurTypeInt32, value: int) -> None:
        super().__init__(type_)
        self.value = value

    def render(self) -> str:
        return str(self.value)


class ConstantInt64(Constant):
    """
    An Int64 constant value expression within a statement
    """
    __slots__ = ('value', )

    value: int

    def __init__(self, type_: OurTypeInt64, value: int) -> None:
        super().__init__(type_)
        self.value = value

    def render(self) -> str:
        return str(self.value)


class ConstantFloat32(Constant):
    """
    An Float32 constant value expression within a statement
    """
    __slots__ = ('value', )

    value: float

    def __init__(self, type_: OurTypeFloat32, value: float) -> None:
        super().__init__(type_)
        self.value = value

    def render(self) -> str:
        return str(self.value)


class ConstantFloat64(Constant):
    """
    An Float64 constant value expression within a statement
    """
    __slots__ = ('value', )

    value: float

    def __init__(self, type_: OurTypeFloat64, value: float) -> None:
        super().__init__(type_)
        self.value = value

    def render(self) -> str:
        return str(self.value)


class VariableReference(Expression):
    """
    An variable reference expression within a statement
    """
    __slots__ = ('name', )

    name: str

    def __init__(self, type_: OurType, name: str) -> None:
        super().__init__(type_)
        self.name = name

    def render(self) -> str:
        return str(self.name)


class BinaryOp(Expression):
    """
    A binary operator expression within a statement
    """
    __slots__ = ('operator', 'left', 'right', )

    operator: str
    left: Expression
    right: Expression

    def __init__(self, type_: OurType, operator: str, left: Expression, right: Expression) -> None:
        super().__init__(type_)

        self.operator = operator
        self.left = left
        self.right = right

    def render(self) -> str:
        return f'{self.left.render()} {self.operator} {self.right.render()}'


class UnaryOp(Expression):
    """
    A unary operator expression within a statement

    Builtin float operations (see WEBASSEMBLY_BUILDIN_FLOAT_OPS) are also
    represented as a UnaryOp, and render in function call syntax.
    """
    __slots__ = ('operator', 'right', )

    operator: str
    right: Expression

    def __init__(self, type_: OurType, operator: str, right: Expression) -> None:
        super().__init__(type_)

        self.operator = operator
        self.right = right

    def render(self) -> str:
        if self.operator in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
            return f'{self.operator}({self.right.render()})'

        return f'{self.operator}{self.right.render()}'


class FunctionCall(Expression):
    """
    A function call expression within a statement
    """
    __slots__ = ('function', 'arguments', )

    function: 'Function'
    arguments: List[Expression]

    def __init__(self, function: 'Function') -> None:
        super().__init__(function.returns)

        self.function = function
        self.arguments = []

    def render(self) -> str:
        args = ', '.join(
            arg.render()
            for arg in self.arguments
        )

        # Constructors have generated internal names; render them the way
        # they were written in the source instead
        if isinstance(self.function, StructConstructor):
            return f'{self.function.struct.name}({args})'

        if isinstance(self.function, TupleConstructor):
            return f'({args}, )'

        return f'{self.function.name}({args})'


class AccessStructMember(Expression):
    """
    Access a struct member for reading of writing
    """
    __slots__ = ('varref', 'member', )

    varref: VariableReference
    member: 'StructMember'

    def __init__(self, varref: VariableReference, member: 'StructMember') -> None:
        super().__init__(member.type)

        self.varref = varref
        self.member = member

    def render(self) -> str:
        return f'{self.varref.render()}.{self.member.name}'


class AccessTupleMember(Expression):
    """
    Access a tuple member for reading of writing
    """
    __slots__ = ('varref', 'member', )

    varref: VariableReference
    member: TupleMember

    def __init__(self, varref: VariableReference, member: TupleMember, ) -> None:
        super().__init__(member.type)

        self.varref = varref
        self.member = member

    def render(self) -> str:
        return f'{self.varref.render()}[{self.member.idx}]'


class Statement:
    """
    A statement within a function
    """
    __slots__ = ()

    def render(self) -> List[str]:
        """
        Renders the statement back to source code format

        Returns one string per source line, without indentation for the
        enclosing block. This'll look like Python code.
        """
        raise NotImplementedError(self, 'render')


class StatementReturn(Statement):
    """
    A return statement within a function
    """
    __slots__ = ('value', )

    value: Expression

    def __init__(self, value: Expression) -> None:
        self.value = value

    def render(self) -> List[str]:
        """
        Renders the statement back to source code format

        This'll look like Python code.
        """
        return [f'return {self.value.render()}']


class StatementIf(Statement):
    """
    An if statement within a function
    """
    __slots__ = ('test', 'statements', 'else_statements', )

    test: Expression
    statements: List[Statement]
    else_statements: List[Statement]

    def __init__(self, test: Expression) -> None:
        self.test = test
        self.statements = []
        self.else_statements = []

    def render(self) -> List[str]:
        """
        Renders the statement back to source code format

        The body, and the else body if there is one, are indented by four
        spaces; a blank line is appended after the statement.
        This'll look like Python code.
        """
        result = [f'if {self.test.render()}:']

        for stmt in self.statements:
            result.extend(
                f'    {line}' if line else ''
                for line in stmt.render()
            )

        if self.else_statements:
            result.append('else:')

            for stmt in self.else_statements:
                result.extend(
                    f'    {line}' if line else ''
                    for line in stmt.render()
                )

        result.append('')

        return result


class StatementPass(Statement):
    """
    A pass statement
    """
    __slots__ = ()

    def render(self) -> List[str]:
        return ['pass']


class Function:
    """
    A function processes input and produces output
    """
    __slots__ = ('name', 'lineno', 'exported', 'buildin', 'statements', 'returns', 'posonlyargs', )

    name: str
    lineno: int  # Source line; negative for buildin (-2) or auto generated (-1)
    exported: bool
    buildin: bool
    statements: List[Statement]
    returns: OurType
    posonlyargs: List[Tuple[str, OurType]]

    def __init__(self, name: str, lineno: int) -> None:
        self.name = name
        self.lineno = lineno
        self.exported = False
        self.buildin = False
        self.statements = []
        self.returns = OurTypeNone()
        self.posonlyargs = []

    def render(self) -> str:
        """
        Renders the function back to source code format

        This'll look like Python code.
        """
        result = ''
        if self.exported:
            result += '@exported\n'

        args = ', '.join(
            f'{x}: {y.render()}'
            for x, y in self.posonlyargs
        )

        result += f'def {self.name}({args}) -> {self.returns.render()}:\n'
        for stmt in self.statements:
            for line in stmt.render():
                result += f'    {line}\n' if line else '\n'
        return result


class StructConstructor(Function):
    """
    The constructor method for a struct

    Takes one positional argument per struct member, in declaration order.
    """
    __slots__ = ('struct', )

    struct: 'Struct'

    def __init__(self, struct: 'Struct') -> None:
        super().__init__(f'@{struct.name}@__init___@', -1)

        self.returns = struct

        for mem in struct.members:
            self.posonlyargs.append((mem.name, mem.type, ))

        self.struct = struct


class TupleConstructor(Function):
    """
    The constructor method for a tuple

    Takes one positional argument per tuple member, named arg0, arg1, ...
    """
    __slots__ = ('tuple', )

    tuple: OurTypeTuple

    def __init__(self, tuple_: OurTypeTuple) -> None:
        name = tuple_.render_internal_name()

        super().__init__(f'@{name}@__init___@', -1)

        self.returns = tuple_

        for mem in tuple_.members:
            self.posonlyargs.append((f'arg{mem.idx}', mem.type, ))

        self.tuple = tuple_


class StructMember:
    """
    Represents a struct member

    offset is the byte offset of the member from the start of the struct's
    allocation.
    """
    def __init__(self, name: str, type_: OurType, offset: int) -> None:
        self.name = name
        self.type = type_
        self.offset = offset


class Struct(OurType):
    """
    A struct has named properties
    """
    __slots__ = ('name', 'lineno', 'members', )

    name: str
    lineno: int
    members: List[StructMember]

    def __init__(self, name: str, lineno: int) -> None:
        self.name = name
        self.lineno = lineno
        self.members = []

    def get_member(self, name: str) -> Optional[StructMember]:
        """
        Returns a member by name, or None if there is no such member
        """
        for mem in self.members:
            if mem.name == name:
                return mem

        return None

    def render(self) -> str:
        """
        Renders the type back to source code format

        This'll look like Python code.
        """
        return self.name

    def render_definition(self) -> str:
        """
        Renders the definition back to source code format

        This'll look like Python code.
        """
        result = f'class {self.name}:\n'
        for mem in self.members:
            result += f'    {mem.name}: {mem.type.render()}\n'

        return result

    def alloc_size(self) -> int:
        return sum(
            x.type.alloc_size()
            for x in self.members
        )


class Module:
    """
    A module is a file and consists of functions
    """
    __slots__ = ('types', 'functions', 'structs', )

    types: Dict[str, OurType]
    functions: Dict[str, Function]
    structs: Dict[str, Struct]

    def __init__(self) -> None:
        self.types = {
            'i32': OurTypeInt32(),
            'i64': OurTypeInt64(),
            'f32': OurTypeFloat32(),
            'f64': OurTypeFloat64(),
        }
        self.functions = {}
        self.structs = {}

    def render(self) -> str:
        """
        Renders the module back to source code format

        Struct definitions come first, then user defined functions; each
        separated by a blank line. This'll look like Python code.
        """
        result = ''

        for struct in self.structs.values():
            if result:
                result += '\n'
            result += struct.render_definition()

        for function in self.functions.values():
            if function.lineno < 0:
                # Buildin (-2) or auto generated (-1)
                continue

            if result:
                result += '\n'
            result += function.render()

        return result


class StaticError(Exception):
    """
    An error found during static analysis
    """


OurLocals = Dict[str, OurType]


class OurVisitor:
    """
    Class to visit a Python syntax tree and create an ourlang syntax tree
    """

    # pylint: disable=C0103,C0116,C0301,R0201,R0912

    def __init__(self) -> None:
        pass

    def visit_Module(self, node: ast.Module) -> Module:
        """
        Converts a parsed Python module into an ourlang Module

        Raises StaticError for invalid programs and NotImplementedError for
        Python constructs that ourlang does not support (yet).
        """
        module = Module()

        _not_implemented(not node.type_ignores, 'Module.type_ignores')

        # First pass for the types and function signatures, so function
        # bodies can refer to functions and structs defined later on
        for stmt in node.body:
            res = self.pre_visit_Module_stmt(module, stmt)

            if isinstance(res, Function):
                if res.name in module.functions:
                    raise StaticError(
                        f'{res.name} already defined on line {module.functions[res.name].lineno}'
                    )

                module.functions[res.name] = res

            if isinstance(res, Struct):
                if res.name in module.structs:
                    raise StaticError(
                        f'{res.name} already defined on line {module.structs[res.name].lineno}'
                    )

                module.structs[res.name] = res
                constructor = StructConstructor(res)
                module.functions[constructor.name] = constructor

        # Second pass for the function bodies
        for stmt in node.body:
            self.visit_Module_stmt(module, stmt)

        return module

    def pre_visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, Struct]:
        if isinstance(node, ast.FunctionDef):
            return self.pre_visit_Module_FunctionDef(module, node)

        if isinstance(node, ast.ClassDef):
            return self.pre_visit_Module_ClassDef(module, node)

        raise NotImplementedError(f'{node} on Module')

    def pre_visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> Function:
        """
        Builds a Function with its signature; the body is visited later
        """
        function = Function(node.name, node.lineno)

        _not_implemented(not node.args.posonlyargs, 'FunctionDef.args.posonlyargs')

        for arg in node.args.args:
            if not arg.annotation:
                _raise_static_error(node, 'Type is required')

            function.posonlyargs.append((
                arg.arg,
                self.visit_type(module, arg.annotation),
            ))

        _not_implemented(not node.args.vararg, 'FunctionDef.args.vararg')
        _not_implemented(not node.args.kwonlyargs, 'FunctionDef.args.kwonlyargs')
        _not_implemented(not node.args.kw_defaults, 'FunctionDef.args.kw_defaults')
        _not_implemented(not node.args.kwarg, 'FunctionDef.args.kwarg')
        _not_implemented(not node.args.defaults, 'FunctionDef.args.defaults')

        # Do stmts at the end so we have the return value

        for decorator in node.decorator_list:
            if not isinstance(decorator, ast.Name):
                _raise_static_error(decorator, 'Function decorators must be string')
            if not isinstance(decorator.ctx, ast.Load):
                _raise_static_error(decorator, 'Must be load context')
            # Only @exported is supported; it is also what Function.render emits
            _not_implemented(decorator.id == 'exported', 'Custom decorators')
            function.exported = True

        if node.returns:
            function.returns = self.visit_type(module, node.returns)

        _not_implemented(not node.type_comment, 'FunctionDef.type_comment')

        return function

    def pre_visit_Module_ClassDef(self, module: Module, node: ast.ClassDef) -> Struct:
        """
        Builds a Struct from a class body of annotated members only

        Member offsets are assigned sequentially, in declaration order.
        """
        struct = Struct(node.name, node.lineno)

        _not_implemented(not node.bases, 'ClassDef.bases')
        _not_implemented(not node.keywords, 'ClassDef.keywords')
        _not_implemented(not node.decorator_list, 'ClassDef.decorator_list')

        offset = 0

        for stmt in node.body:
            if not isinstance(stmt, ast.AnnAssign):
                raise NotImplementedError(f'Class with {stmt} nodes')

            if not isinstance(stmt.target, ast.Name):
                raise NotImplementedError('Class with non-name targets')

            if stmt.value is not None:
                raise NotImplementedError('Class with default values')

            if stmt.simple != 1:
                raise NotImplementedError('Class with non-simple arguments')

            member = StructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset)

            struct.members.append(member)
            offset += member.type.alloc_size()

        return struct

    def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None:
        if isinstance(node, ast.FunctionDef):
            self.visit_Module_FunctionDef(module, node)
            return

        if isinstance(node, ast.ClassDef):
            # Fully handled in the first pass
            return

        raise NotImplementedError(f'{node} on Module')

    def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
        function = module.functions[node.name]

        # The arguments are the only locals for now
        our_locals = dict(function.posonlyargs)

        for stmt in node.body:
            function.statements.append(
                self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt)
            )

    def visit_Module_FunctionDef_stmt(self, module: Module, function: Function, our_locals: OurLocals, node: ast.stmt) -> Statement:
        if isinstance(node, ast.Return):
            if node.value is None:
                # TODO: Implement methods without return values
                _raise_static_error(node, 'Return must have an argument')

            return StatementReturn(
                self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value)
            )

        if isinstance(node, ast.If):
            # NOTE(review): the test expression is typed with the function's
            # return type, so comparisons operate on that type - confirm intended
            result = StatementIf(
                self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test)
            )

            for stmt in node.body:
                result.statements.append(
                    self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt)
                )

            for stmt in node.orelse:
                result.else_statements.append(
                    self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt)
                )

            return result

        raise NotImplementedError(f'{node} as stmt in FunctionDef')

    def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.expr) -> Expression:
        """
        Converts an expression, which is expected to result in exp_type
        """
        if isinstance(node, ast.BinOp):
            if isinstance(node.op, ast.Add):
                operator = '+'
            elif isinstance(node.op, ast.Sub):
                operator = '-'
            elif isinstance(node.op, ast.Mult):
                operator = '*'
            else:
                raise NotImplementedError(f'Operator {node.op}')

            # Assume the type doesn't change when descending into a binary operator
            # e.g. you can't do `"hello" * 3` with the code below (yet)

            return BinaryOp(
                exp_type,
                operator,
                self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left),
                self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right),
            )

        if isinstance(node, ast.UnaryOp):
            if isinstance(node.op, ast.UAdd):
                operator = '+'
            elif isinstance(node.op, ast.USub):
                operator = '-'
            else:
                raise NotImplementedError(f'Operator {node.op}')

            return UnaryOp(
                exp_type,
                operator,
                self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand),
            )

        if isinstance(node, ast.Compare):
            if 1 < len(node.ops):
                raise NotImplementedError('Multiple operators')

            if isinstance(node.ops[0], ast.Gt):
                operator = '>'
            elif isinstance(node.ops[0], ast.Eq):
                operator = '=='
            elif isinstance(node.ops[0], ast.Lt):
                operator = '<'
            else:
                raise NotImplementedError(f'Operator {node.ops}')

            # Assume the type doesn't change when descending into a binary operator
            # e.g. you can't do `"hello" * 3` with the code below (yet)

            return BinaryOp(
                exp_type,
                operator,
                self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left),
                self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]),
            )

        if isinstance(node, ast.Call):
            return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node)

        if isinstance(node, ast.Constant):
            return self.visit_Module_FunctionDef_Constant(
                module, function, exp_type, node,
            )

        if isinstance(node, ast.Attribute):
            return self.visit_Module_FunctionDef_Attribute(
                module, function, our_locals, exp_type, node,
            )

        if isinstance(node, ast.Subscript):
            return self.visit_Module_FunctionDef_Subscript(
                module, function, our_locals, exp_type, node,
            )

        if isinstance(node, ast.Name):
            if not isinstance(node.ctx, ast.Load):
                _raise_static_error(node, 'Must be load context')

            if node.id not in our_locals:
                _raise_static_error(node, f'Undefined variable {node.id}')

            return VariableReference(our_locals[node.id], node.id)

        if isinstance(node, ast.Tuple):
            if not isinstance(node.ctx, ast.Load):
                _raise_static_error(node, 'Must be load context')

            if not isinstance(exp_type, OurTypeTuple):
                _raise_static_error(node, f'Expression is expecting a {exp_type.render()}, not a tuple')

            if len(exp_type.members) != len(node.elts):
                _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given')

            # The constructor was registered by visit_type; build one only
            # to compute its internal name
            tuple_constructor = TupleConstructor(exp_type)

            func = module.functions[tuple_constructor.name]

            result = FunctionCall(func)
            result.arguments = [
                self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node)
                for arg_node, mem in zip(node.elts, exp_type.members)
            ]
            return result

        raise NotImplementedError(f'{node} as expr in FunctionDef')

    def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.Call) -> Union[FunctionCall, UnaryOp]:
        """
        Converts a call to a struct constructor, builtin float op or function
        """
        if node.keywords:
            _raise_static_error(node, 'Keyword calling not supported') # Yet?

        if not isinstance(node.func, ast.Name):
            raise NotImplementedError(f'Calling methods that are not a name {node.func}')
        if not isinstance(node.func.ctx, ast.Load):
            _raise_static_error(node, 'Must be load context')

        if node.func.id in module.structs:
            struct = module.structs[node.func.id]
            struct_constructor = StructConstructor(struct)

            func = module.functions[struct_constructor.name]
        elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
            if not isinstance(exp_type, (OurTypeFloat32, OurTypeFloat64, )):
                _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type.render()}')

            if 1 != len(node.args):
                _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')

            # The operator is the builtin's own name, so e.g. abs(x) stays abs(x)
            return UnaryOp(
                exp_type,
                node.func.id,
                self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]),
            )
        else:
            if node.func.id not in module.functions:
                _raise_static_error(node, 'Call to undefined function')

            func = module.functions[node.func.id]

        if func.returns != exp_type:
            _raise_static_error(node, f'Function {node.func.id} does not return {exp_type.render()}')

        if len(func.posonlyargs) != len(node.args):
            _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')

        result = FunctionCall(func)
        result.arguments.extend(
            self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr)
            for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs)
        )
        return result

    def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.Attribute) -> Expression:
        """
        Converts `name.member` on a local struct variable
        """
        if not isinstance(node.value, ast.Name):
            _raise_static_error(node, 'Must reference a name')

        if not isinstance(node.ctx, ast.Load):
            _raise_static_error(node, 'Must be load context')

        if not node.value.id in our_locals:
            _raise_static_error(node, f'Undefined variable {node.value.id}')

        node_typ = our_locals[node.value.id]
        if not isinstance(node_typ, Struct):
            _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')

        member = node_typ.get_member(node.attr)
        if member is None:
            _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}')

        if exp_type != member.type:
            _raise_static_error(node, f'Expected {exp_type.render()}, got {member.type.render()} instead')

        return AccessStructMember(
            VariableReference(node_typ, node.value.id),
            member,
        )

    def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: OurType, node: ast.Subscript) -> Expression:
        """
        Converts `name[<constant int>]` on a local tuple variable
        """
        if not isinstance(node.value, ast.Name):
            _raise_static_error(node, 'Must reference a name')

        # Before Python 3.9 the index is wrapped in an ast.Index node; from
        # 3.9 on the expression is stored directly and ast.Index is deprecated
        index_node: Any = node.slice
        index_cls = getattr(ast, 'Index', None)
        if index_cls is not None and isinstance(index_node, index_cls):
            index_node = index_node.value

        if isinstance(index_node, ast.Slice) or not isinstance(index_node, ast.expr):
            _raise_static_error(node, 'Must subscript using an index')

        if not isinstance(index_node, ast.Constant):
            _raise_static_error(node, 'Must subscript using a constant index') # FIXME: Implement variable indexes

        index = index_node.value
        # bool is a subclass of int, but x[True] is not a meaningful index
        if isinstance(index, bool) or not isinstance(index, int):
            _raise_static_error(node, 'Must subscript using a constant integer index')

        if not isinstance(node.ctx, ast.Load):
            _raise_static_error(node, 'Must be load context')

        if not node.value.id in our_locals:
            _raise_static_error(node, f'Undefined variable {node.value.id}')

        node_typ = our_locals[node.value.id]
        if not isinstance(node_typ, OurTypeTuple):
            _raise_static_error(node, f'Cannot take index of non-tuple {node.value.id}')

        if len(node_typ.members) <= index:
            _raise_static_error(node, f'Index {index} out of bounds for tuple {node.value.id}')

        member = node_typ.members[index]
        if exp_type != member.type:
            _raise_static_error(node, f'Expected {exp_type.render()}, got {member.type.render()} instead')

        return AccessTupleMember(
            VariableReference(node_typ, node.value.id),
            member,
        )

    def visit_Module_FunctionDef_Constant(self, module: Module, function: Function, exp_type: OurType, node: ast.Constant) -> Expression:
        """
        Converts a literal into a constant of exp_type

        Integers are accepted for float types; booleans are never accepted.
        """
        del module
        del function

        _not_implemented(node.kind is None, 'Constant.kind')

        # bool is a subclass of int, but True / False are not numeric literals
        is_bool = isinstance(node.value, bool)

        if isinstance(exp_type, OurTypeInt32):
            if is_bool or not isinstance(node.value, int):
                _raise_static_error(node, 'Expected integer value')

            # FIXME: Range check

            return ConstantInt32(exp_type, node.value)

        if isinstance(exp_type, OurTypeInt64):
            if is_bool or not isinstance(node.value, int):
                _raise_static_error(node, 'Expected integer value')

            # FIXME: Range check

            return ConstantInt64(exp_type, node.value)

        if isinstance(exp_type, OurTypeFloat32):
            if is_bool or not isinstance(node.value, (float, int, )):
                _raise_static_error(node, 'Expected float value')

            # FIXME: Range check

            return ConstantFloat32(exp_type, node.value)

        if isinstance(exp_type, OurTypeFloat64):
            if is_bool or not isinstance(node.value, (float, int, )):
                _raise_static_error(node, 'Expected float value')

            # FIXME: Range check

            return ConstantFloat64(exp_type, node.value)

        raise NotImplementedError(f'{node} as const for type {exp_type.render()}')

    def visit_type(self, module: Module, node: ast.expr) -> OurType:
        """
        Resolves a type annotation

        Tuple types are registered in module.types (keyed by internal name)
        together with their constructor, so equal tuple types share one
        object.
        """
        if isinstance(node, ast.Name):
            if not isinstance(node.ctx, ast.Load):
                _raise_static_error(node, 'Must be load context')

            if node.id in module.types:
                return module.types[node.id]

            if node.id in module.structs:
                return module.structs[node.id]

            _raise_static_error(node, f'Unrecognized type {node.id}')

        if isinstance(node, ast.Tuple):
            if not isinstance(node.ctx, ast.Load):
                _raise_static_error(node, 'Must be load context')

            result = OurTypeTuple()
            offset = 0

            for idx, elt in enumerate(node.elts):
                member = TupleMember(idx, self.visit_type(module, elt), offset)

                result.members.append(member)
                offset += member.type.alloc_size()

            key = result.render_internal_name()

            if key not in module.types:
                module.types[key] = result
                constructor = TupleConstructor(result)
                module.functions[constructor.name] = constructor

            return module.types[key]

        raise NotImplementedError(f'{node} as type')


def _not_implemented(check: Any, msg: str) -> None:
    """
    Raises NotImplementedError(msg) when check is falsy
    """
    if not check:
        raise NotImplementedError(msg)


def _raise_static_error(node: Union[ast.mod, ast.stmt, ast.expr], msg: str) -> NoReturn:
    """
    Raises a StaticError mentioning the source line of node
    """
    raise StaticError(
        f'Static error on line {node.lineno}: {msg}'
    )