phasm/phasm/parser.py
Johan B.W. de Vries 1d107d1baa Notes
2025-07-19 19:41:47 +02:00

704 lines
26 KiB
Python

"""
Parses the source code from the plain text into a syntax tree
"""
import ast
from typing import Any, Dict, NoReturn, Union
from .build.base import BuildBase
from .build.default import BuildDefault
from .exceptions import StaticError
from .ourlang import (
AccessStructMember,
BinaryOp,
ConstantBytes,
ConstantPrimitive,
ConstantStruct,
ConstantTuple,
Expression,
Function,
FunctionCall,
FunctionParam,
FunctionReference,
Module,
ModuleConstantDef,
ModuleDataBlock,
Statement,
StatementIf,
StatementPass,
StatementReturn,
StructConstructor,
StructDefinition,
Subscript,
TupleInstantiation,
VariableReference,
)
from .type3.typeclasses import Type3ClassMethod
from .type3.types import IntType3, Type3
from .type5 import typeexpr as type5typeexpr
from .wasmgenerator import Generator
def phasm_parse(source: str) -> Module[Generator]:
    """
    Public method for parsing Phasm code into a Phasm Module

    The text is first parsed by Python's own `ast` module, then simplified
    by OptimizerTransformer, and finally converted into ourlang nodes by an
    OurVisitor that uses the default build.
    """
    python_tree = OptimizerTransformer().visit(ast.parse(source, ''))

    visitor = OurVisitor(BuildDefault())
    return visitor.visit_Module(python_tree)
# Names visible inside a function body, mapped to their definition.
# Currently only the function parameters are stored here.
# FIXME: Does it become easier if we add ModuleConstantDef to this dict?
OurLocals = Dict[str, FunctionParam]
class OptimizerTransformer(ast.NodeTransformer):
    """
    This class optimizes the Python AST, to prepare it for parsing
    by the OurVisitor class below.
    """

    def visit_UnaryOp(self, node: ast.UnaryOp) -> Union[ast.UnaryOp, ast.Constant]:
        """
        UnaryOp optimizations

        In the given example:
        ```py
        x = -4
        ```
        Python will parse it as a unary minus operation on the constant four.
        For Phasm purposes, this counts as a literal -4.

        The operand is transformed first, so nested signs such as `-(-4)`
        also fold into a single literal.

        Returns the (mutated) operand Constant when folding applies,
        otherwise the UnaryOp itself.
        """
        # Transform the children first. NodeTransformer does not recurse on
        # its own once a visit_* method is defined, so without this call
        # nested unary ops and any literals inside the operand would never
        # be visited.
        self.generic_visit(node)

        operand = node.operand
        if (
            isinstance(node.op, (ast.UAdd, ast.USub, ))
            and isinstance(operand, ast.Constant)
            and isinstance(operand.value, (int, float, ))
        ):
            if isinstance(node.op, ast.USub):
                operand.value = -operand.value
            return operand

        return node
class OurVisitor[G]:
"""
Class to visit a Python syntax tree and create an ourlang syntax tree
We're (ab)using the Python AST parser to give us a leg up
At some point, we may deviate from Python syntax. If nothing else,
we probably won't keep up with the Python syntax changes.
See OptimizerTransformer for the changes we make after the Python
parsing is done but before the phasm parsing is done.
"""
# pylint: disable=C0103,C0116,C0301,R0201,R0912
def __init__(self, build: BuildBase[G]) -> None:
self.build = build
def visit_Module(self, node: ast.Module) -> Module[G]:
module = Module(self.build)
module.methods.update(self.build.methods)
module.operators.update(self.build.operators)
module.types.update(self.build.types)
module.type5s.update(self.build.type5s)
_not_implemented(not node.type_ignores, 'Module.type_ignores')
# Second pass for the types
for stmt in node.body:
res = self.pre_visit_Module_stmt(module, stmt)
if isinstance(res, ModuleConstantDef):
if res.name in module.constant_defs:
raise StaticError(
f'{res.name} already defined on line {module.constant_defs[res.name].lineno}'
)
module.constant_defs[res.name] = res
if isinstance(res, StructDefinition):
if res.struct_type3.name in module.types:
raise StaticError(
f'{res.struct_type3.name} already defined as type'
)
module.types[res.struct_type3.name] = res.struct_type3
module.functions[res.struct_type3.name] = StructConstructor(res.struct_type3)
# Store that the definition was done in this module for the formatter
module.struct_definitions[res.struct_type3.name] = res
if isinstance(res, Function):
if res.name in module.functions:
raise StaticError(
f'{res.name} already defined on line {module.functions[res.name].lineno}'
)
module.functions[res.name] = res
# Second pass for the function bodies
for stmt in node.body:
self.visit_Module_stmt(module, stmt)
return module
def pre_visit_Module_stmt(self, module: Module[G], node: ast.stmt) -> Union[Function, StructDefinition, ModuleConstantDef]:
if isinstance(node, ast.FunctionDef):
return self.pre_visit_Module_FunctionDef(module, node)
if isinstance(node, ast.ClassDef):
return self.pre_visit_Module_ClassDef(module, node)
if isinstance(node, ast.AnnAssign):
return self.pre_visit_Module_AnnAssign(module, node)
raise NotImplementedError(f'{node} on Module')
def pre_visit_Module_FunctionDef(self, module: Module[G], node: ast.FunctionDef) -> Function:
function = Function(node.name, node.lineno, self.build.none_)
_not_implemented(not node.args.posonlyargs, 'FunctionDef.args.posonlyargs')
arg_type5_list = []
for arg in node.args.args:
if arg.annotation is None:
_raise_static_error(node, 'Must give a argument type')
arg_type = self.visit_type(module, arg.annotation)
arg_type5_list.append(self.visit_type5(module, arg.annotation))
# FIXME: Allow TypeVariable in the function signature
# This would also require FunctionParam to accept a placeholder
function.signature.args.append(arg_type)
function.posonlyargs.append(FunctionParam(
arg.arg,
arg_type,
))
_not_implemented(not node.args.vararg, 'FunctionDef.args.vararg')
_not_implemented(not node.args.kwonlyargs, 'FunctionDef.args.kwonlyargs')
_not_implemented(not node.args.kw_defaults, 'FunctionDef.args.kw_defaults')
_not_implemented(not node.args.kwarg, 'FunctionDef.args.kwarg')
_not_implemented(not node.args.defaults, 'FunctionDef.args.defaults')
# Do stmts at the end so we have the return value
for decorator in node.decorator_list:
if isinstance(decorator, ast.Call):
if not isinstance(decorator.func, ast.Name):
_raise_static_error(decorator, 'Function decorators must be string')
if not isinstance(decorator.func.ctx, ast.Load):
_raise_static_error(decorator, 'Must be load context')
_not_implemented(decorator.func.id == 'imported', 'Custom decorators')
if 1 != len(decorator.args):
_raise_static_error(decorator, 'One argument expected')
if not isinstance(decorator.args[0], ast.Constant):
_raise_static_error(decorator, 'Service name must be a constant')
if not isinstance(decorator.args[0].value, str):
_raise_static_error(decorator, 'Service name must be a stirng')
if 0 != len(decorator.keywords): # TODO: Allow for namespace keyword
_raise_static_error(decorator, 'No keyword arguments expected')
function.imported = decorator.args[0].value
else:
if not isinstance(decorator, ast.Name):
_raise_static_error(decorator, 'Function decorators must be string')
if not isinstance(decorator.ctx, ast.Load):
_raise_static_error(decorator, 'Must be load context')
_not_implemented(decorator.id in ('exported', 'imported'), 'Custom decorators')
if decorator.id == 'exported':
function.exported = True
else:
function.imported = 'imports'
if node.returns is None: # Note: `-> None` would be a ast.Constant
_raise_static_error(node, 'Must give a return type')
return_type = self.visit_type(module, node.returns)
arg_type5_list.append(self.visit_type5(module, node.returns))
function.signature.args.append(return_type)
function.returns_type3 = return_type
for arg_type5 in reversed(arg_type5_list):
if function.type5 is None:
function.type5 = arg_type5
continue
raise NotImplementedError('TODO: Applying the function type')
_not_implemented(not node.type_comment, 'FunctionDef.type_comment')
return function
def pre_visit_Module_ClassDef(self, module: Module[G], node: ast.ClassDef) -> StructDefinition:
_not_implemented(not node.bases, 'ClassDef.bases')
_not_implemented(not node.keywords, 'ClassDef.keywords')
_not_implemented(not node.decorator_list, 'ClassDef.decorator_list')
members: Dict[str, Type3] = {}
for stmt in node.body:
if not isinstance(stmt, ast.AnnAssign):
raise NotImplementedError(f'Class with {stmt} nodes')
if not isinstance(stmt.target, ast.Name):
raise NotImplementedError('Class with default values')
if stmt.value is not None:
raise NotImplementedError('Class with default values')
if stmt.simple != 1:
raise NotImplementedError('Class with non-simple arguments')
if stmt.target.id in members:
_raise_static_error(stmt, 'Struct members must have unique names')
members[stmt.target.id] = self.visit_type(module, stmt.annotation)
return StructDefinition(module.build.struct(node.name, tuple(members.items())), node.lineno)
def pre_visit_Module_AnnAssign(self, module: Module[G], node: ast.AnnAssign) -> ModuleConstantDef:
if not isinstance(node.target, ast.Name):
_raise_static_error(node.target, 'Must be name')
if not isinstance(node.target.ctx, ast.Store):
_raise_static_error(node.target, 'Must be store context')
if isinstance(node.value, ast.Constant):
value_data = self.visit_Module_Constant(module, node.value)
return ModuleConstantDef(
node.target.id,
node.lineno,
self.visit_type(module, node.annotation),
self.visit_type5(module, node.annotation),
value_data,
)
if isinstance(node.value, ast.Tuple):
value_data = self.visit_Module_Constant(module, node.value)
assert isinstance(value_data, ConstantTuple) # type hint
# Then return the constant as a pointer
return ModuleConstantDef(
node.target.id,
node.lineno,
self.visit_type(module, node.annotation),
self.visit_type5(module, node.annotation),
value_data,
)
if isinstance(node.value, ast.Call):
value_data = self.visit_Module_Constant(module, node.value)
assert isinstance(value_data, ConstantStruct) # type hint
# Then return the constant as a pointer
return ModuleConstantDef(
node.target.id,
node.lineno,
self.visit_type(module, node.annotation),
self.visit_type5(module, node.annotation),
value_data,
)
raise NotImplementedError(f'{node} on Module AnnAssign')
def visit_Module_stmt(self, module: Module[G], node: ast.stmt) -> None:
if isinstance(node, ast.FunctionDef):
self.visit_Module_FunctionDef(module, node)
return
if isinstance(node, ast.ClassDef):
return
if isinstance(node, ast.AnnAssign):
return
raise NotImplementedError(f'{node} on Module')
def visit_Module_FunctionDef(self, module: Module[G], node: ast.FunctionDef) -> None:
function = module.functions[node.name]
our_locals: OurLocals = {
x.name: x
for x in function.posonlyargs
}
for stmt in node.body:
function.statements.append(
self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt)
)
def visit_Module_FunctionDef_stmt(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.stmt) -> Statement:
if isinstance(node, ast.Return):
if node.value is None:
# TODO: Implement methods without return values
_raise_static_error(node, 'Return must have an argument')
return StatementReturn(
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value)
)
if isinstance(node, ast.If):
result = StatementIf(
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.test)
)
for stmt in node.body:
result.statements.append(
self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt)
)
for stmt in node.orelse:
result.else_statements.append(
self.visit_Module_FunctionDef_stmt(module, function, our_locals, stmt)
)
return result
if isinstance(node, ast.Pass):
return StatementPass()
raise NotImplementedError(f'{node} as stmt in FunctionDef')
def visit_Module_FunctionDef_expr(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.expr) -> Expression:
if isinstance(node, ast.BinOp):
operator: Union[str, Type3ClassMethod]
if isinstance(node.op, ast.Add):
operator = '+'
elif isinstance(node.op, ast.Sub):
operator = '-'
elif isinstance(node.op, ast.Mult):
operator = '*'
elif isinstance(node.op, ast.Div):
operator = '/'
elif isinstance(node.op, ast.FloorDiv):
operator = '//'
elif isinstance(node.op, ast.Mod):
operator = '%'
elif isinstance(node.op, ast.LShift):
operator = '<<'
elif isinstance(node.op, ast.RShift):
operator = '>>'
elif isinstance(node.op, ast.BitOr):
operator = '|'
elif isinstance(node.op, ast.BitXor):
operator = '^'
elif isinstance(node.op, ast.BitAnd):
operator = '&'
else:
raise NotImplementedError(f'Operator {node.op}')
if operator not in module.operators:
raise NotImplementedError(f'Operator {operator}')
return BinaryOp(
module.operators[operator],
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right),
)
if isinstance(node, ast.Compare):
if 1 < len(node.ops):
raise NotImplementedError('Multiple operators')
if isinstance(node.ops[0], ast.Gt):
operator = '>'
elif isinstance(node.ops[0], ast.GtE):
operator = '>='
elif isinstance(node.ops[0], ast.Eq):
operator = '=='
elif isinstance(node.ops[0], ast.NotEq):
operator = '!='
elif isinstance(node.ops[0], ast.Lt):
operator = '<'
elif isinstance(node.ops[0], ast.LtE):
operator = '<='
else:
raise NotImplementedError(f'Operator {node.ops}')
if operator not in module.operators:
raise NotImplementedError(f'Operator {operator}')
return BinaryOp(
module.operators[operator],
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]),
)
if isinstance(node, ast.Call):
return self.visit_Module_FunctionDef_Call(module, function, our_locals, node)
if isinstance(node, ast.Constant):
return self.visit_Module_Constant(
module, node,
)
if isinstance(node, ast.Attribute):
return self.visit_Module_FunctionDef_Attribute(
module, function, our_locals, node,
)
if isinstance(node, ast.Subscript):
return self.visit_Module_FunctionDef_Subscript(
module, function, our_locals, node,
)
if isinstance(node, ast.Name):
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if node.id in our_locals:
param = our_locals[node.id]
return VariableReference(param)
if node.id in module.constant_defs:
cdef = module.constant_defs[node.id]
return VariableReference(cdef)
if node.id in module.functions:
fun = module.functions[node.id]
return FunctionReference(fun)
_raise_static_error(node, f'Undefined variable {node.id}')
if isinstance(node, ast.Tuple):
arguments = [
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_node)
for arg_node in node.elts
if isinstance(arg_node, (ast.Constant, ast.Tuple, ast.Call, ))
]
if len(arguments) != len(node.elts):
raise NotImplementedError('Non-constant tuple members')
return TupleInstantiation(arguments)
raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
if not isinstance(node.func, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {node.func}')
if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
func: Union[Function, FunctionParam, Type3ClassMethod]
if node.func.id in module.methods:
func = module.methods[node.func.id]
elif node.func.id in our_locals:
func = our_locals[node.func.id]
else:
if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function')
func = module.functions[node.func.id]
result = FunctionCall(func)
result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
for arg_expr in node.args
)
return result
def visit_Module_FunctionDef_Attribute(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression:
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
varref = self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value)
if not isinstance(varref, VariableReference):
_raise_static_error(node.value, 'Must refer to variable')
return AccessStructMember(
varref,
varref.variable.type3,
node.attr,
)
def visit_Module_FunctionDef_Subscript(self, module: Module[G], function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression:
if not isinstance(node.value, ast.Name):
_raise_static_error(node, 'Must reference a name')
if isinstance(node.slice, ast.Slice):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
varref: VariableReference
if node.value.id in our_locals:
param = our_locals[node.value.id]
varref = VariableReference(param)
elif node.value.id in module.constant_defs:
constant_def = module.constant_defs[node.value.id]
varref = VariableReference(constant_def)
else:
_raise_static_error(node, f'Undefined variable {node.value.id}')
slice_expr = self.visit_Module_FunctionDef_expr(
module, function, our_locals, node.slice,
)
return Subscript(varref, slice_expr)
def visit_Module_Constant(self, module: Module[G], node: Union[ast.Constant, ast.Tuple, ast.Call]) -> Union[ConstantPrimitive, ConstantBytes, ConstantTuple, ConstantStruct]:
if isinstance(node, ast.Tuple):
tuple_data = [
self.visit_Module_Constant(module, arg_node)
for arg_node in node.elts
if isinstance(arg_node, (ast.Constant, ast.Tuple, ast.Call, ))
]
if len(node.elts) != len(tuple_data):
_raise_static_error(node, 'Tuple arguments must be constants')
# Allocate the data
data_block = ModuleDataBlock(tuple_data)
module.data.blocks.append(data_block)
return ConstantTuple(tuple_data, data_block)
if isinstance(node, ast.Call):
# Struct constant
# Stored in memory like a tuple, so much of the code is the same
if not isinstance(node.func, ast.Name):
_raise_static_error(node.func, 'Must be name')
if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node.func, 'Must be load context')
struct_def = module.struct_definitions.get(node.func.id)
if struct_def is None:
_raise_static_error(node.func, 'Undefined struct')
if node.keywords:
_raise_static_error(node.func, 'Cannot use keywords')
struct_data = [
self.visit_Module_Constant(module, arg_node)
for arg_node in node.args
if isinstance(arg_node, (ast.Constant, ast.Tuple, ast.Call, ))
]
if len(node.args) != len(struct_data):
_raise_static_error(node, 'Struct arguments must be constants')
data_block = ModuleDataBlock(struct_data)
module.data.blocks.append(data_block)
return ConstantStruct(struct_def.struct_type3, struct_data, data_block)
_not_implemented(node.kind is None, 'Constant.kind')
if isinstance(node.value, (int, float, )):
return ConstantPrimitive(node.value)
if isinstance(node.value, bytes):
data_block = ModuleDataBlock([])
module.data.blocks.append(data_block)
result = ConstantBytes(node.value, data_block)
data_block.data.append(result)
return result
raise NotImplementedError(f'{node.value} as constant')
def visit_type(self, module: Module[G], node: ast.expr) -> Type3:
if isinstance(node, ast.Constant):
if node.value is None:
return module.types['None']
_raise_static_error(node, f'Unrecognized type {node.value}')
if isinstance(node, ast.Name):
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if node.id in module.types:
return module.types[node.id]
_raise_static_error(node, f'Unrecognized type {node.id}')
if isinstance(node, ast.Subscript):
if isinstance(node.value, ast.Name) and node.value.id == 'Callable':
func_arg_types: list[ast.expr]
if isinstance(node.slice, ast.Name):
func_arg_types = [node.slice]
elif isinstance(node.slice, ast.Tuple):
func_arg_types = node.slice.elts
else:
_raise_static_error(node, 'Must subscript using a list of types')
# Function type
return module.build.function(*[
self.visit_type(module, e)
for e in func_arg_types
])
if isinstance(node.slice, ast.Slice):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.slice, ast.Constant):
_raise_static_error(node, 'Must subscript using a constant index')
if node.slice.value is Ellipsis:
return module.build.dynamic_array(
self.visit_type(module, node.value),
)
if not isinstance(node.slice.value, int):
_raise_static_error(node, 'Must subscript using a constant integer index')
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
return module.build.static_array(
self.visit_type(module, node.value),
IntType3(node.slice.value),
)
if isinstance(node, ast.Tuple):
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
return module.build.tuple_(
*(self.visit_type(module, elt) for elt in node.elts)
)
raise NotImplementedError(f'{node} as type')
def visit_type5(self, module: Module[G], node: ast.expr) -> type5typeexpr.TypeExpr:
if isinstance(node, ast.Name):
if not isinstance(node.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
if node.id in module.type5s:
return module.type5s[node.id]
_raise_static_error(node, f'Unrecognized type {node.id}')
raise NotImplementedError
def _not_implemented(check: Any, msg: str) -> None:
if not check:
raise NotImplementedError(msg)
def _raise_static_error(node: Union[ast.stmt, ast.expr], msg: str) -> NoReturn:
    """
    Abort parsing with a StaticError that mentions the line of `node`.
    """
    message = f'Static error on line {node.lineno}: {msg}'
    raise StaticError(message)