phasm/phasm/parser.py
2022-09-16 17:01:23 +02:00

851 lines
33 KiB
Python

"""
Parses the source code from the plain text into a syntax tree
"""
from typing import Any, Dict, NoReturn, Union
import ast
from .typing import (
TypeBase,
TypeUInt8,
TypeUInt32,
TypeUInt64,
TypeInt32,
TypeInt64,
TypeFloat32,
TypeFloat64,
TypeBytes,
TypeStruct,
TypeStructMember,
TypeTuple,
TypeTupleMember,
TypeStaticArray,
TypeStaticArrayMember,
)
from . import codestyle
from .exceptions import StaticError
from .ourlang import (
WEBASSEMBLY_BUILDIN_FLOAT_OPS,
Module, ModuleDataBlock,
Function,
Expression,
AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember,
BinaryOp,
Constant,
ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64,
ConstantUInt8, ConstantUInt32, ConstantUInt64,
ConstantTuple, ConstantStaticArray,
FunctionCall,
StructConstructor, TupleConstructor,
UnaryOp, VariableReference,
Fold, ModuleConstantReference,
Statement,
StatementIf, StatementPass, StatementReturn,
FunctionParam,
ModuleConstantDef,
)
def phasm_parse(source: str) -> Module:
    """
    Parse Phasm source text and return the resulting Phasm Module.

    The text is first handed to Python's own parser; the Python syntax
    tree that comes out is then converted by OurVisitor.
    """
    python_tree = ast.parse(source, '')
    return OurVisitor().visit_Module(python_tree)
# Maps a name that can be referenced inside a function body to its definition.
# Only function parameters are stored at the moment; the single-member Union
# leaves room for more kinds (local variables and module constants?).
OurLocals = Dict[str, Union[FunctionParam]]
class OurVisitor:
"""
Class to visit a Python syntax tree and create an ourlang syntax tree
We're (ab)using the Python AST parser to give us a leg up
At some point, we may deviate from Python syntax. If nothing else,
we probably won't keep up with the Python syntax changes.
"""
# pylint: disable=C0103,C0116,C0301,R0201,R0912
def __init__(self) -> None:
    """Nothing to set up: the visitor stores no attributes of its own."""
def visit_Module(self, node: ast.Module) -> Module:
    """
    Convert a Python module node into a Phasm Module.

    The top-level statements are walked twice. The first pass registers
    every module constant, struct and function signature; the second pass
    visits the function bodies, which can therefore refer to anything
    declared anywhere in the module.

    Raises StaticError when a constant, struct or function name is
    declared a second time.
    """
    module = Module()

    _not_implemented(not node.type_ignores, 'Module.type_ignores')

    # First pass: declarations (constants, structs, function signatures)
    for top_stmt in node.body:
        declared = self.pre_visit_Module_stmt(module, top_stmt)

        if isinstance(declared, ModuleConstantDef):
            if declared.name in module.constant_defs:
                earlier = module.constant_defs[declared.name]
                raise StaticError(
                    f'{declared.name} already defined on line {earlier.lineno}'
                )

            module.constant_defs[declared.name] = declared
            continue

        if isinstance(declared, TypeStruct):
            if declared.name in module.structs:
                earlier_struct = module.structs[declared.name]
                raise StaticError(
                    f'{declared.name} already defined on line {earlier_struct.lineno}'
                )

            module.structs[declared.name] = declared

            # Every struct gets a constructor registered as a function
            constructor = StructConstructor(declared)
            module.functions[constructor.name] = constructor
            continue

        if isinstance(declared, Function):
            if declared.name in module.functions:
                earlier_function = module.functions[declared.name]
                raise StaticError(
                    f'{declared.name} already defined on line {earlier_function.lineno}'
                )

            module.functions[declared.name] = declared

    # Second pass: function bodies
    for top_stmt in node.body:
        self.visit_Module_stmt(module, top_stmt)

    return module
def pre_visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, TypeStruct, ModuleConstantDef]:
    """
    Declaration pass for one top-level statement.

    Annotated assignments become module constants, class definitions
    become structs and function definitions become function signatures.
    Any other statement raises NotImplementedError.
    """
    if isinstance(node, ast.AnnAssign):
        declared_constant = self.pre_visit_Module_AnnAssign(module, node)
        return declared_constant

    if isinstance(node, ast.ClassDef):
        declared_struct = self.pre_visit_Module_ClassDef(module, node)
        return declared_struct

    if isinstance(node, ast.FunctionDef):
        declared_function = self.pre_visit_Module_FunctionDef(module, node)
        return declared_function

    raise NotImplementedError(f'{node} on Module')
def pre_visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> Function:
    """
    Build the signature of a function: name, parameters, flags and
    return type. The statements in the body are not looked at here;
    visit_Module_FunctionDef handles those in the second pass.

    Raises StaticError for a parameter without a type annotation or a
    malformed decorator, and NotImplementedError for argument kinds or
    decorators that are not supported.
    """
    function = Function(node.name, node.lineno)
    arguments = node.args

    _not_implemented(not arguments.posonlyargs, 'FunctionDef.args.posonlyargs')

    for arg in arguments.args:
        if arg.annotation is None:
            _raise_static_error(node, 'Type is required')

        param_type = self.visit_type(module, arg.annotation)
        function.posonlyargs.append(FunctionParam(arg.arg, param_type))

    # None of the other ways of declaring arguments are supported
    for unsupported, label in (
        (arguments.vararg, 'FunctionDef.args.vararg'),
        (arguments.kwonlyargs, 'FunctionDef.args.kwonlyargs'),
        (arguments.kw_defaults, 'FunctionDef.args.kw_defaults'),
        (arguments.kwarg, 'FunctionDef.args.kwarg'),
        (arguments.defaults, 'FunctionDef.args.defaults'),
    ):
        _not_implemented(not unsupported, label)

    # The statements are done in a later pass, when all return types are known

    for decorator in node.decorator_list:
        if not isinstance(decorator, ast.Name):
            _raise_static_error(decorator, 'Function decorators must be string')

        if not isinstance(decorator.ctx, ast.Load):
            _raise_static_error(decorator, 'Must be load context')

        is_export = decorator.id == 'exported'
        _not_implemented(is_export or decorator.id == 'imported', 'Custom decorators')

        if is_export:
            function.exported = True
        else:
            function.imported = True

    if node.returns is not None:
        function.returns = self.visit_type(module, node.returns)

    _not_implemented(not node.type_comment, 'FunctionDef.type_comment')

    return function
def pre_visit_Module_ClassDef(self, module: Module, node: ast.ClassDef) -> TypeStruct:
    """
    Build a struct type from a class definition.

    The class body may only contain plain annotated names without a
    value (`name: type`). Each one becomes a struct member; a member's
    offset is the summed alloc_size() of the members before it.

    Raises NotImplementedError for base classes, keywords, decorators and
    for any body statement that is not a plain annotated name.
    """
    struct = TypeStruct(node.name, node.lineno)

    _not_implemented(not node.bases, 'ClassDef.bases')
    _not_implemented(not node.keywords, 'ClassDef.keywords')
    _not_implemented(not node.decorator_list, 'ClassDef.decorator_list')

    offset = 0

    for stmt in node.body:
        if not isinstance(stmt, ast.AnnAssign):
            raise NotImplementedError(f'Class with {stmt} nodes')

        if not isinstance(stmt.target, ast.Name):
            # e.g. `self.x: i32` or `x[0]: i32`; this has nothing to do
            # with default values, so it gets a message of its own
            raise NotImplementedError('Class with non-name members')

        if stmt.value is not None:
            raise NotImplementedError('Class with default values')

        if stmt.simple != 1:
            # e.g. a parenthesised target: `(x): i32`
            raise NotImplementedError('Class with non-simple arguments')

        member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset)

        struct.members.append(member)
        offset += member.type.alloc_size()

    return struct
def pre_visit_Module_AnnAssign(self, module: Module, node: ast.AnnAssign) -> ModuleConstantDef:
    """
    Build a module constant from a top-level `NAME: type = value`.

    Three kinds of annotated type are handled:

    - TypeInt32: the value must be a single constant; no data block is
      allocated.
    - TypeTuple: the value must be a tuple of constants with one value
      per tuple member.
    - TypeStaticArray: the value must be a tuple of constants with one
      value per array member.

    For tuples and static arrays the values are also stored in a new
    ModuleDataBlock, which is appended to module.data.blocks.

    Raises StaticError when the target or value does not have the
    required shape and NotImplementedError for any other annotated type.
    """
    if not isinstance(node.target, ast.Name):
        _raise_static_error(node, 'Must be name')

    if not isinstance(node.target.ctx, ast.Store):
        # The target of an assignment is written to, so the check (and
        # hence the message) is about the store context
        _raise_static_error(node, 'Must be store context')

    exp_type = self.visit_type(module, node.annotation)

    if isinstance(exp_type, TypeInt32):
        if not isinstance(node.value, ast.Constant):
            _raise_static_error(node, 'Must be constant')

        constant = ModuleConstantDef(
            node.target.id,
            node.lineno,
            exp_type,
            self.visit_Module_Constant(module, exp_type, node.value),
            None,
        )
        return constant

    if isinstance(exp_type, TypeTuple):
        if not isinstance(node.value, ast.Tuple):
            _raise_static_error(node, 'Must be tuple')

        if len(exp_type.members) != len(node.value.elts):
            _raise_static_error(node, 'Invalid number of tuple values')

        # Non-constant elements are filtered out here, which makes the
        # length check below fail when there are any
        tuple_data = [
            self.visit_Module_Constant(module, mem.type, arg_node)
            for arg_node, mem in zip(node.value.elts, exp_type.members)
            if isinstance(arg_node, ast.Constant)
        ]
        if len(exp_type.members) != len(tuple_data):
            _raise_static_error(node, 'Tuple arguments must be constants')

        # Allocate the data
        data_block = ModuleDataBlock(tuple_data)
        module.data.blocks.append(data_block)

        # Then return the constant as a pointer
        return ModuleConstantDef(
            node.target.id,
            node.lineno,
            exp_type,
            ConstantTuple(exp_type, tuple_data),
            data_block,
        )

    if isinstance(exp_type, TypeStaticArray):
        if not isinstance(node.value, ast.Tuple):
            _raise_static_error(node, 'Must be static array')

        if len(exp_type.members) != len(node.value.elts):
            _raise_static_error(node, 'Invalid number of static array values')

        # Same filtering trick as for tuples
        static_array_data = [
            self.visit_Module_Constant(module, exp_type.member_type, arg_node)
            for arg_node in node.value.elts
            if isinstance(arg_node, ast.Constant)
        ]
        if len(exp_type.members) != len(static_array_data):
            _raise_static_error(node, 'Static array arguments must be constants')

        # Allocate the data
        data_block = ModuleDataBlock(static_array_data)
        module.data.blocks.append(data_block)

        # Then return the constant as a pointer
        return ModuleConstantDef(
            node.target.id,
            node.lineno,
            exp_type,
            ConstantStaticArray(exp_type, static_array_data),
            data_block,
        )

    raise NotImplementedError(f'{node} on Module AnnAssign')
raise NotImplementedError(f'{node} on Module AnnAssign')
def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None:
    """
    Body pass for one top-level statement.

    Only function definitions have work left to do; class definitions
    and annotated assignments are skipped. Any other statement raises
    NotImplementedError.
    """
    if isinstance(node, (ast.ClassDef, ast.AnnAssign)):
        # Nothing to do for these in this pass
        return

    if not isinstance(node, ast.FunctionDef):
        raise NotImplementedError(f'{node} on Module')

    self.visit_Module_FunctionDef(module, node)
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
    """
    Convert the body of a function and append the resulting statements
    to the Function that was registered under the same name.

    The function's parameters are the only names placed in the locals.
    """
    function = module.functions[node.name]

    our_locals: OurLocals = {}
    for param in function.posonlyargs:
        our_locals[param.name] = param

    for body_stmt in node.body:
        statement = self.visit_Module_FunctionDef_stmt(module, function, our_locals, body_stmt)
        function.statements.append(statement)
def visit_Module_FunctionDef_stmt(self, module: Module, function: Function, our_locals: OurLocals, node: ast.stmt) -> Statement:
    """
    Convert one statement inside a function body.

    Supported are `pass`, `return <expr>` and `if` / `else`. Both the
    returned expression and the `if` test are parsed with the function's
    return type as the expected type.

    Raises StaticError for a `return` without a value and
    NotImplementedError for any other kind of statement.
    """
    if isinstance(node, ast.Pass):
        return StatementPass()

    if isinstance(node, ast.Return):
        if node.value is None:
            # TODO: Implement methods without return values
            _raise_static_error(node, 'Return must have an argument')

        returned = self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value)
        return StatementReturn(returned)

    if isinstance(node, ast.If):
        test = self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test)
        result = StatementIf(test)

        for then_stmt in node.body:
            then_statement = self.visit_Module_FunctionDef_stmt(module, function, our_locals, then_stmt)
            result.statements.append(then_statement)

        for else_stmt in node.orelse:
            else_statement = self.visit_Module_FunctionDef_stmt(module, function, our_locals, else_stmt)
            result.else_statements.append(else_statement)

        return result

    raise NotImplementedError(f'{node} as stmt in FunctionDef')
def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.expr) -> Expression:
    """
    Convert a Python expression node into a Phasm Expression.

    exp_type is the type the surrounding code expects. It is handed down
    unchanged to the operands of binary, unary and comparison operators,
    and names, attributes, subscripts and tuples are checked against it.

    Raises StaticError for type mismatches and undefined names, and
    NotImplementedError for operators or node kinds that are not
    supported.
    """
    if isinstance(node, ast.BinOp):
        for op_class, symbol in (
            (ast.Add, '+'),
            (ast.Sub, '-'),
            (ast.Mult, '*'),
            (ast.LShift, '<<'),
            (ast.RShift, '>>'),
            (ast.BitOr, '|'),
            (ast.BitXor, '^'),
            (ast.BitAnd, '&'),
        ):
            if isinstance(node.op, op_class):
                operator = symbol
                break
        else:
            raise NotImplementedError(f'Operator {node.op}')

        # Both operands are expected to have the same type as the result,
        # so something like `"hello" * 3` cannot be expressed (yet)
        left = self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left)
        right = self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right)
        return BinaryOp(exp_type, operator, left, right)

    if isinstance(node, ast.UnaryOp):
        if isinstance(node.op, ast.UAdd):
            operator = '+'
        elif isinstance(node.op, ast.USub):
            operator = '-'
        else:
            raise NotImplementedError(f'Operator {node.op}')

        operand = self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand)
        return UnaryOp(exp_type, operator, operand)

    if isinstance(node, ast.Compare):
        if len(node.ops) > 1:
            raise NotImplementedError('Multiple operators')

        comparison = node.ops[0]
        if isinstance(comparison, ast.Gt):
            operator = '>'
        elif isinstance(comparison, ast.Eq):
            operator = '=='
        elif isinstance(comparison, ast.Lt):
            operator = '<'
        else:
            raise NotImplementedError(f'Operator {node.ops}')

        # Same assumption as for the binary operators: the operands are
        # parsed with the expected type of the comparison itself
        left = self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left)
        right = self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0])
        return BinaryOp(exp_type, operator, left, right)

    if isinstance(node, ast.Call):
        return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node)

    if isinstance(node, ast.Constant):
        return self.visit_Module_Constant(module, exp_type, node)

    if isinstance(node, ast.Attribute):
        return self.visit_Module_FunctionDef_Attribute(module, function, our_locals, exp_type, node)

    if isinstance(node, ast.Subscript):
        return self.visit_Module_FunctionDef_Subscript(module, function, our_locals, exp_type, node)

    if isinstance(node, ast.Name):
        if not isinstance(node.ctx, ast.Load):
            _raise_static_error(node, 'Must be load context')

        name = node.id

        # Locals take precedence over module constants
        if name in our_locals:
            param = our_locals[name]
            if exp_type != param.type:
                _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {name} is actually {codestyle.type_(param.type)}')

            return VariableReference(param.type, param)

        if name in module.constant_defs:
            cdef = module.constant_defs[name]
            if exp_type != cdef.type:
                _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {name} is actually {codestyle.type_(cdef.type)}')

            return ModuleConstantReference(exp_type, cdef)

        _raise_static_error(node, f'Undefined variable {name}')

    if isinstance(node, ast.Tuple):
        if not isinstance(node.ctx, ast.Load):
            _raise_static_error(node, 'Must be load context')

        if not isinstance(exp_type, TypeTuple):
            _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple')

        expected_size = len(exp_type.members)
        given_size = len(node.elts)
        if expected_size != given_size:
            _raise_static_error(node, f'Expression is expecting a tuple of size {expected_size}, but {given_size} are given')

        # A tuple expression is a call to the constructor of its type
        tuple_constructor = TupleConstructor(exp_type)
        func = module.functions[tuple_constructor.name]

        result = FunctionCall(func)
        result.arguments = [
            self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, elt_node)
            for elt_node, mem in zip(node.elts, exp_type.members)
        ]
        return result

    raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]:
    """
    Convert a call expression.

    The callee is resolved in this order: struct constructors, the
    built-in float operations named in WEBASSEMBLY_BUILDIN_FLOAT_OPS, the
    built-ins `u32`, `len` and `foldl`, and finally functions defined in
    the module.

    Struct constructors and module functions yield a FunctionCall whose
    arguments are parsed with the parameter types of the callee; the
    return type of the callee is not compared against exp_type here.
    The built-ins yield a UnaryOp or, for `foldl`, a Fold.

    Raises StaticError for keyword arguments, undefined functions, a
    wrong number of arguments or a built-in that cannot produce exp_type.
    Raises NotImplementedError when the callee is not a plain name.
    """
    if node.keywords:
        _raise_static_error(node, 'Keyword calling not supported')  # Yet?

    if not isinstance(node.func, ast.Name):
        raise NotImplementedError(f'Calling methods that are not a name {node.func}')
    if not isinstance(node.func.ctx, ast.Load):
        _raise_static_error(node, 'Must be load context')

    if node.func.id in module.structs:
        struct = module.structs[node.func.id]
        struct_constructor = StructConstructor(struct)

        func = module.functions[struct_constructor.name]
    elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS:
        if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )):
            _raise_static_error(node, f'Cannot make {node.func.id} result in {codestyle.type_(exp_type)}')

        if 1 != len(node.args):
            _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')

        # Use the name that was actually called as the operator; a fixed
        # 'sqrt' here would turn every float built-in into a square root
        return UnaryOp(
            exp_type,
            node.func.id,
            self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]),
        )
    elif node.func.id == 'u32':
        if not isinstance(exp_type, TypeUInt32):
            _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}')

        if 1 != len(node.args):
            _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')

        # FIXME: This is a stub, proper casting is todo
        # The argument is always parsed as a u8

        return UnaryOp(
            exp_type,
            'cast',
            self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]),
        )
    elif node.func.id == 'len':
        if not isinstance(exp_type, TypeInt32):
            _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}')

        if 1 != len(node.args):
            _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given')

        # The argument is always parsed as bytes
        return UnaryOp(
            exp_type,
            'len',
            self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]),
        )
    elif node.func.id == 'foldl':
        # TODO: This should be a much more generic function!
        # For development purposes, we're assuming you're doing a foldl(Callable[[u8, u8], u8], u8, bytes)
        # In the future, we should probably infer the type of the second argument,
        # and use it as expected types for the other u8s and the Iterable[u8] (i.e. bytes)

        if 3 != len(node.args):
            _raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')

        # TODO: This is not generic
        subnode = node.args[0]
        if not isinstance(subnode, ast.Name):
            raise NotImplementedError(f'Calling methods that are not a name {subnode}')
        if not isinstance(subnode.ctx, ast.Load):
            _raise_static_error(subnode, 'Must be load context')
        if subnode.id not in module.functions:
            _raise_static_error(subnode, 'Reference to undefined function')
        func = module.functions[subnode.id]
        if 2 != len(func.posonlyargs):
            _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')

        # The folded function must map (accumulator, u8) -> accumulator,
        # and the accumulator must be of the expected type
        if exp_type.__class__ != func.returns.__class__:
            _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')

        if func.returns.__class__ != func.posonlyargs[0].type.__class__:
            _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}')

        if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__:
            _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')

        return Fold(
            exp_type,
            Fold.Direction.LEFT,
            func,
            self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]),
            self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[2]),
        )
    else:
        if node.func.id not in module.functions:
            _raise_static_error(node, 'Call to undefined function')

        func = module.functions[node.func.id]

    # Reached for struct constructors and module functions only
    if len(func.posonlyargs) != len(node.args):
        _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given')

    result = FunctionCall(func)
    result.arguments.extend(
        self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr)
        for arg_expr, param in zip(node.args, func.posonlyargs)
    )
    return result
def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression:
    """
    Convert `name.attr` into an access of a struct member.

    `name` has to be a local whose type is a struct, and the member has
    to be of the expected type. Anything else raises StaticError.
    """
    del module, function

    owner = node.value
    if not isinstance(owner, ast.Name):
        _raise_static_error(node, 'Must reference a name')

    if not isinstance(node.ctx, ast.Load):
        _raise_static_error(node, 'Must be load context')

    if owner.id not in our_locals:
        _raise_static_error(node, f'Undefined variable {owner.id}')

    param = our_locals[owner.id]
    struct_type = param.type

    if not isinstance(struct_type, TypeStruct):
        _raise_static_error(node, f'Cannot take attribute of non-struct {owner.id}')

    member = struct_type.get_member(node.attr)
    if member is None:
        _raise_static_error(node, f'{struct_type.name} has no attribute {node.attr}')

    if exp_type != member.type:
        _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {owner.id}.{member.name} is actually {codestyle.type_(member.type)}')

    struct_reference = VariableReference(struct_type, param)
    return AccessStructMember(struct_reference, member)
def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression:
    """
    Convert `name[index]` into an access of bytes, a tuple or a static array.

    `name` is looked up in the locals first and in the module constants
    second. The index is parsed as a u32 expression.

    - bytes: yields an AccessBytesIndex; exp_type must be u8.
    - tuple: the index must be a constant; it is bounds checked here.
    - static array: a constant index is bounds checked here and resolved
      to the array member, any other index expression is passed on as-is.

    Raises StaticError for undefined names, slices, type mismatches and
    constant indexes that are out of bounds. Raises NotImplementedError
    when indexing bytes or a tuple that is a module constant.
    """
    if not isinstance(node.value, ast.Name):
        _raise_static_error(node, 'Must reference a name')

    # Python < 3.9 wraps a plain index in ast.Index. Python >= 3.9 stores
    # the index expression directly and only uses ast.Slice for `a[x:y]`.
    index_node: Any = node.slice
    legacy_index = getattr(ast, 'Index', None)
    if legacy_index is not None and isinstance(index_node, legacy_index):
        index_node = index_node.value
    elif isinstance(index_node, ast.Slice) or not isinstance(index_node, ast.expr):
        _raise_static_error(node, 'Must subscript using an index')

    if not isinstance(node.ctx, ast.Load):
        _raise_static_error(node, 'Must be load context')

    varref: Union[ModuleConstantReference, VariableReference]
    node_typ: TypeBase
    if node.value.id in our_locals:
        param = our_locals[node.value.id]
        # The type has to be set for locals too; the checks below rely on it
        node_typ = param.type
        varref = VariableReference(node_typ, param)
    elif node.value.id in module.constant_defs:
        constant_def = module.constant_defs[node.value.id]
        node_typ = constant_def.type
        varref = ModuleConstantReference(node_typ, constant_def)
    else:
        _raise_static_error(node, f'Undefined variable {node.value.id}')

    slice_expr = self.visit_Module_FunctionDef_expr(
        module, function, our_locals, module.types['u32'], index_node,
    )

    if isinstance(node_typ, TypeBytes):
        t_u8 = module.types['u8']
        if exp_type != t_u8:
            _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{codestyle.expression(slice_expr)}] is actually {codestyle.type_(t_u8)}')

        if isinstance(varref, ModuleConstantReference):
            raise NotImplementedError(f'{node} from module constant')

        return AccessBytesIndex(
            t_u8,
            varref,
            slice_expr,
        )

    if isinstance(node_typ, TypeTuple):
        if not isinstance(slice_expr, ConstantUInt32):
            _raise_static_error(node, 'Must subscript using a constant index')

        idx = slice_expr.value

        if len(node_typ.members) <= idx:
            _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}')

        tuple_member = node_typ.members[idx]
        if exp_type != tuple_member.type:
            _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(tuple_member.type)}')

        if isinstance(varref, ModuleConstantReference):
            raise NotImplementedError(f'{node} from module constant')

        return AccessTupleMember(
            varref,
            tuple_member,
        )

    if isinstance(node_typ, TypeStaticArray):
        if exp_type != node_typ.member_type:
            # The index may not be a constant at this point, so render
            # the index expression instead of a number
            _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{codestyle.expression(slice_expr)}] is actually {codestyle.type_(node_typ.member_type)}')

        # The index was parsed as a u32, so a constant index is a
        # ConstantUInt32; anything else is resolved at runtime
        if not isinstance(slice_expr, ConstantUInt32):
            return AccessStaticArrayMember(
                varref,
                node_typ,
                slice_expr,
            )

        idx = slice_expr.value

        if len(node_typ.members) <= idx:
            _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}')

        static_array_member = node_typ.members[idx]

        return AccessStaticArrayMember(
            varref,
            node_typ,
            static_array_member,
        )

    _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}')
def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant:
    """
    Wrap a literal in the Constant class that belongs to exp_type.

    For the integer types the value is passed on as-is: neither its
    Python type nor its range is verified here. For the float types the
    value has to be a float or an int; the range is not checked.

    Raises StaticError for a non-numeric value where a float is expected
    and NotImplementedError for literals with a kind (e.g. u'...') and
    for expected types that have no constant form.
    """
    del module

    _not_implemented(node.kind is None, 'Constant.kind')

    literal = node.value

    # TODO: Validate that integer literals are ints and fit the type
    for integer_type, integer_constant in (
        (TypeUInt8, ConstantUInt8),
        (TypeUInt32, ConstantUInt32),
        (TypeUInt64, ConstantUInt64),
        (TypeInt32, ConstantInt32),
        (TypeInt64, ConstantInt64),
    ):
        if isinstance(exp_type, integer_type):
            return integer_constant(exp_type, literal)

    for float_type, float_constant in (
        (TypeFloat32, ConstantFloat32),
        (TypeFloat64, ConstantFloat64),
    ):
        if isinstance(exp_type, float_type):
            if not isinstance(literal, (float, int, )):
                _raise_static_error(node, 'Expected float value')

            # FIXME: Range check

            return float_constant(exp_type, literal)

    raise NotImplementedError(f'{node} as const for type {exp_type}')
def visit_type(self, module: Module, node: ast.expr) -> TypeBase:
    """
    Resolve a type annotation to a Phasm type.

    - `None` resolves to module.types['None'].
    - A name is looked up in module.types first and module.structs second.
    - `name[N]` is a static array of N members of the type `name`, which
      has to be present in module.types.
    - A tuple of types is a tuple type.

    Static array and tuple types are registered in module.types the first
    time they are seen; later uses get the registered instance. A new
    tuple type also gets its constructor registered in module.functions.

    Raises StaticError for unknown types and malformed subscripts, and
    NotImplementedError for any other kind of annotation.
    """
    if isinstance(node, ast.Constant):
        if node.value is None:
            return module.types['None']

        _raise_static_error(node, f'Unrecognized type {node.value}')

    if isinstance(node, ast.Name):
        if not isinstance(node.ctx, ast.Load):
            _raise_static_error(node, 'Must be load context')

        if node.id in module.types:
            return module.types[node.id]

        if node.id in module.structs:
            return module.structs[node.id]

        _raise_static_error(node, f'Unrecognized type {node.id}')

    if isinstance(node, ast.Subscript):
        if not isinstance(node.value, ast.Name):
            _raise_static_error(node, 'Must be name')

        # Python < 3.9 wraps a plain index in ast.Index. Python >= 3.9
        # stores the index expression directly and only uses ast.Slice
        # for `a[x:y]`.
        size_node: Any = node.slice
        legacy_index = getattr(ast, 'Index', None)
        if legacy_index is not None and isinstance(size_node, legacy_index):
            size_node = size_node.value
        elif isinstance(size_node, ast.Slice) or not isinstance(size_node, ast.expr):
            _raise_static_error(node, 'Must subscript using an index')

        if not isinstance(size_node, ast.Constant):
            _raise_static_error(node, 'Must subscript using a constant index')
        if not isinstance(size_node.value, int):
            _raise_static_error(node, 'Must subscript using a constant integer index')
        if not isinstance(node.ctx, ast.Load):
            _raise_static_error(node, 'Must be load context')

        if node.value.id not in module.types:
            _raise_static_error(node, f'Unrecognized type {node.value.id}')

        member_type = module.types[node.value.id]
        size = size_node.value

        key = f'{node.value.id}[{size}]'

        # Only build the type when it has not been registered before
        if key not in module.types:
            type_static_array = TypeStaticArray(member_type)

            offset = 0

            for idx in range(size):
                static_array_member = TypeStaticArrayMember(idx, offset)

                type_static_array.members.append(static_array_member)
                offset += member_type.alloc_size()

            module.types[key] = type_static_array

        return module.types[key]

    if isinstance(node, ast.Tuple):
        if not isinstance(node.ctx, ast.Load):
            _raise_static_error(node, 'Must be load context')

        type_tuple = TypeTuple()

        offset = 0

        for idx, elt in enumerate(node.elts):
            tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset)

            type_tuple.members.append(tuple_member)
            offset += tuple_member.type.alloc_size()

        # The internal name depends on the members, so it can only be
        # rendered once the tuple type has been filled
        key = type_tuple.render_internal_name()

        if key not in module.types:
            module.types[key] = type_tuple
            constructor = TupleConstructor(type_tuple)
            module.functions[constructor.name] = constructor

        return module.types[key]

    raise NotImplementedError(f'{node} as type')
def _not_implemented(check: Any, msg: str) -> None:
    """
    Guard for unsupported syntax: raise NotImplementedError(msg) unless
    check is truthy.
    """
    if check:
        return

    raise NotImplementedError(msg)
def _raise_static_error(node: Union[ast.mod, ast.stmt, ast.expr], msg: str) -> NoReturn:
    """
    Raise a StaticError for msg, prefixed with the line number of node.

    Never returns.
    """
    message = f'Static error on line {node.lineno}: {msg}'
    raise StaticError(message)