phasm/phasm/codestyle.py
Johan B.W. de Vries f8d107f4fa Replaces did_construct with a proper router
By annotating types with the constructor application
that was used to create them.

Later on we can use the router to replace compiler's
INSTANCES or for user defined types.
2025-05-10 16:49:10 +02:00

191 lines
4.9 KiB
Python

"""
This module generates Phasm source code from the parsed AST.
It's intended to be an "any color, as long as it's black" kind of renderer:
there are no formatting options, every construct has exactly one rendering.
"""
from typing import Generator
from . import ourlang, prelude
from .type3.types import Type3, TypeApplication_Struct
def phasm_render(inp: ourlang.Module) -> str:
    """
    Render a complete Phasm module back into Phasm source code.

    This is the public entry point; the actual rendering is delegated to
    `module`.
    """
    rendered = module(inp)
    return rendered
# Generator of rendered source lines. Each yielded string is one line without
# its trailing newline; an empty string stands for a blank line.
Statements = Generator[str, None, None]
def type3(inp: Type3) -> str:
    """
    Render: the source-level name of a type.

    The prelude's `none` type is spelled `None` in source; any other type
    is rendered using its own name.
    """
    return 'None' if inp is prelude.none else inp.name
def struct_definition(inp: ourlang.StructDefinition) -> str:
    """
    Render: a struct definition as a `class` block.

    Emits a `class <Name>:` header followed by one `<member>: <type>` line
    per struct member, in the order the struct application lists them.
    """
    struct_type = inp.struct_type3
    assert isinstance(struct_type.application, TypeApplication_Struct)

    member_lines = ''.join(
        f' {member_name}: {type3(member_type)}\n'
        for member_name, member_type in struct_type.application.arguments
    )
    return f'class {struct_type.name}:\n' + member_lines
def constant_definition(inp: ourlang.ModuleConstantDef) -> str:
    """
    Render: a module-level constant as `<name>: <type> = <value>`,
    terminated by a newline.
    """
    type_name = type3(inp.type3)
    value_source = expression(inp.constant)
    return f'{inp.name}: {type_name} = {value_source}\n'
def expression(inp: ourlang.Expression) -> str:
    """
    Render: A Phasm expression

    Operands of a binary operation that are themselves binary operations are
    wrapped in parentheses, so the rendered source parses back with the same
    grouping. Empty tuples are rendered as `()`.

    Raises NotImplementedError for expression kinds this renderer does not
    handle.
    """
    if isinstance(inp, ourlang.ConstantPrimitive):
        # Floats might not round trip if the original constant
        # could not fit in the given float type
        return str(inp.value)

    if isinstance(inp, ourlang.ConstantBytes):
        return repr(inp.value)

    if isinstance(inp, ourlang.ConstantTuple):
        elements = [expression(x) for x in inp.value]
        if not elements:
            # `(, )` is a syntax error; the empty tuple is spelled `()`
            return '()'
        # The trailing comma keeps a one-element tuple a tuple
        return '(' + ', '.join(elements) + ', )'

    if isinstance(inp, ourlang.ConstantStruct):
        return inp.struct_name + '(' + ', '.join(
            expression(x)
            for x in inp.value
        ) + ')'

    if isinstance(inp, ourlang.VariableReference):
        return str(inp.variable.name)

    if isinstance(inp, ourlang.BinaryOp):
        # Without parentheses, e.g. `(a + b) * c` would render as
        # `a + b * c`, which parses back with a different meaning.
        left = _binop_operand(inp.left)
        right = _binop_operand(inp.right)
        return f'{left} {inp.operator.name} {right}'

    if isinstance(inp, ourlang.FunctionCall):
        args = ', '.join(
            expression(arg)
            for arg in inp.arguments
        )

        if isinstance(inp.function, ourlang.StructConstructor):
            return f'{inp.function.struct_type3.name}({args})'

        return f'{inp.function.name}({args})'

    if isinstance(inp, ourlang.TupleInstantiation):
        elements = [expression(arg) for arg in inp.elements]
        if not elements:
            # `(, )` is a syntax error; the empty tuple is spelled `()`
            return '()'
        args = ', '.join(elements)
        return f'({args}, )'

    if isinstance(inp, ourlang.Subscript):
        varref = expression(inp.varref)
        index = expression(inp.index)

        return f'{varref}[{index}]'

    if isinstance(inp, ourlang.AccessStructMember):
        return f'{expression(inp.varref)}.{inp.member}'

    if isinstance(inp, ourlang.Fold):
        fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
        return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'

    raise NotImplementedError(expression, inp)


def _binop_operand(operand: ourlang.Expression) -> str:
    """
    Render one operand of a binary operation, parenthesizing it when it is
    itself a binary operation so that the original grouping is preserved.
    """
    rendered = expression(operand)
    if isinstance(operand, ourlang.BinaryOp):
        return f'({rendered})'
    return rendered
def statement(inp: ourlang.Statement) -> Statements:
    """
    Render: a single Phasm statement as a sequence of source lines.

    Lines are yielded without trailing newlines. The body of an `if` is
    indented, blank lines stay empty, and an empty line is yielded after
    the `if` body. Raises NotImplementedError for unsupported statements.
    """
    if isinstance(inp, ourlang.StatementPass):
        yield 'pass'
    elif isinstance(inp, ourlang.StatementReturn):
        yield f'return {expression(inp.value)}'
    elif isinstance(inp, ourlang.StatementIf):
        yield f'if {expression(inp.test)}:'
        for body_stmt in inp.statements:
            # Indent non-empty lines only, so blank lines carry no whitespace
            yield from (
                f' {body_line}' if body_line else ''
                for body_line in statement(body_stmt)
            )
        yield ''
    else:
        raise NotImplementedError(statement, inp)
def function(inp: ourlang.Function) -> str:
    """
    Render: a function definition, including its decorators and body.

    Imported functions only have "pass" as a body. Later on we might replace
    this by the function documentation, if any.
    """
    decorators = [
        decorator
        for decorator, enabled in (
            ('@exported', inp.exported),
            ('@imported', inp.imported),
        )
        if enabled
    ]

    params = ', '.join(f'{p.name}: {type3(p.type3)}' for p in inp.posonlyargs)
    signature = f'def {inp.name}({params}) -> {type3(inp.returns_type3)}:'

    if inp.imported:
        body = [' pass']
    else:
        # Indent non-empty lines only, so blank lines carry no whitespace
        body = [
            f' {line}' if line else ''
            for stmt in inp.statements
            for line in statement(stmt)
        ]

    return ''.join(f'{line}\n' for line in decorators + [signature] + body)
def module(inp: ourlang.Module) -> str:
    """
    Render: a whole module.

    Struct definitions come first, then module constants, then functions.
    Consecutive items are separated by a single blank line. Functions with a
    negative line number (builtin or auto generated) are skipped.
    """
    sections = [
        struct_definition(struct)
        for struct in inp.struct_definitions.values()
    ]
    sections.extend(
        constant_definition(cdef)
        for cdef in inp.constant_defs.values()
    )
    sections.extend(
        function(func)
        for func in inp.functions.values()
        # Builtin (-2) or auto generated (-1) functions have negative linenos
        if func.lineno >= 0
    )
    # Every section ends in a newline, so joining with one more yields
    # exactly one blank line between items
    return '\n'.join(sections)