diff --git a/phasm/codestyle.py b/phasm/codestyle.py index e2e64f0..4d19bd6 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -73,6 +73,12 @@ def struct_definition(inp: typing.TypeStruct) -> str: return result +def constant_definition(inp: ourlang.ModuleConstantDef) -> str: + """ + Render: Module Constant's definition + """ + return f'{inp.name}: {type_(inp.type)} = {expression(inp.constant)}\n' + def expression(inp: ourlang.Expression) -> str: """ Render: A Phasm expression @@ -132,6 +138,9 @@ def expression(inp: ourlang.Expression) -> str: fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' + if isinstance(inp, ourlang.ModuleConstantReference): + return inp.definition.name + raise NotImplementedError(expression, inp) def statement(inp: ourlang.Statement) -> Statements: @@ -199,6 +208,11 @@ def module(inp: ourlang.Module) -> str: result += '\n' result += struct_definition(struct) + for cdef in inp.constant_defs.values(): + if result: + result += '\n' + result += constant_definition(cdef) + for func in inp.functions.values(): if func.lineno < 0: # Buildin (-2) or auto generated (-1) diff --git a/phasm/compiler.py b/phasm/compiler.py index 723da8a..9490ad4 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -275,6 +275,16 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: expression_fold(wgn, inp) return + if isinstance(inp, ourlang.ModuleConstantReference): + mtyp = LOAD_STORE_TYPE_MAP.get(inp.type.__class__) + if mtyp is None: + # In the future might extend this by having structs or tuples + # as members of struct or tuples + raise NotImplementedError(expression, inp, inp.type) + + expression(wgn, inp.definition.constant) + return + raise NotImplementedError(expression, inp) def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: @@ -293,9 +303,9 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: wgn.add_statement('nop', comment='acu :: u8') acu_var = wgn.temp_var_u8(f'fold_{codestyle.type_(inp.type)}_acu') wgn.add_statement('nop', comment='adr :: bytes*') - adr_var = wgn.temp_var_i32(f'fold_i32_adr') + adr_var = wgn.temp_var_i32('fold_i32_adr') wgn.add_statement('nop', comment='len :: i32') - len_var = wgn.temp_var_i32(f'fold_i32_len') + len_var = wgn.temp_var_i32('fold_i32_len') wgn.add_statement('nop', comment='acu = base') expression(wgn, inp.base) @@ -345,8 +355,6 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: # return acu wgn.local.get(acu_var) - return - def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None: """ Compile: Return statement diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 4d01b16..4e245a2 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -258,6 +258,18 @@ class Fold(Expression): self.base = base self.iter = iter_ +class ModuleConstantReference(Expression): + """ + An reference to a module constant expression within a statement + """ + __slots__ = ('definition', ) + + definition: 'ModuleConstantDef' + + def __init__(self, type_: TypeBase, definition: 'ModuleConstantDef') -> None: + super().__init__(type_) + self.definition = definition + class Statement: """ A statement within a function @@ -360,15 +372,33 @@ class TupleConstructor(Function): self.tuple = tuple_ +class ModuleConstantDef: + """ + A constant definition within a module + """ + __slots__ = ('name', 'lineno', 'type', 'constant', ) + + name: str + lineno: int + type: TypeBase + constant: Constant + + def __init__(self, name: str, lineno: int, type_: TypeBase, constant: Constant) -> None: + self.name = name + self.lineno = lineno + self.type = type_ + self.constant = constant + class Module: """ A module is a file and consists of functions """ - __slots__ = ('types', 'functions', 'structs', ) + __slots__ = ('types', 'structs', 'constant_defs', 'functions',) types: Dict[str, TypeBase] - functions: Dict[str, Function] structs: Dict[str, TypeStruct] + constant_defs: Dict[str, ModuleConstantDef] + functions: Dict[str, Function] def __init__(self) -> None: self.types = { @@ -382,5 +412,6 @@ class Module: 'f64': TypeFloat64(), 'bytes': TypeBytes(), } - self.functions = {} self.structs = {} + self.constant_defs = {} + self.functions = {} diff --git a/phasm/parser.py b/phasm/parser.py index f1de44f..30d2b26 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -32,16 +32,19 @@ from .ourlang import ( Expression, AccessBytesIndex, AccessStructMember, AccessTupleMember, BinaryOp, + Constant, ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64, ConstantUInt8, ConstantUInt32, ConstantUInt64, FunctionCall, StructConstructor, TupleConstructor, UnaryOp, VariableReference, - Fold, + Fold, ModuleConstantReference, Statement, StatementIf, StatementPass, StatementReturn, + + ModuleConstantDef, ) def phasm_parse(source: str) -> Module: @@ -80,13 +83,13 @@ class OurVisitor: for stmt in node.body: res = self.pre_visit_Module_stmt(module, stmt) - if isinstance(res, Function): - if res.name in module.functions: + if isinstance(res, ModuleConstantDef): + if res.name in module.constant_defs: raise StaticError( - f'{res.name} already defined on line {module.functions[res.name].lineno}' + f'{res.name} already defined on line {module.constant_defs[res.name].lineno}' ) - module.functions[res.name] = res + module.constant_defs[res.name] = res if isinstance(res, TypeStruct): if res.name in module.structs: @@ -98,6 +101,14 @@ class OurVisitor: constructor = StructConstructor(res) module.functions[constructor.name] = constructor + if isinstance(res, Function): + if res.name in module.functions: + raise StaticError( + f'{res.name} already defined on line {module.functions[res.name].lineno}' + ) + + module.functions[res.name] = res + # Second pass for the function bodies for stmt in node.body: @@ -105,13 +116,16 @@ class OurVisitor: return module - def pre_visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, TypeStruct]: + def pre_visit_Module_stmt(self, module: Module, node: ast.stmt) -> Union[Function, TypeStruct, ModuleConstantDef]: if isinstance(node, ast.FunctionDef): return self.pre_visit_Module_FunctionDef(module, node) if isinstance(node, ast.ClassDef): return self.pre_visit_Module_ClassDef(module, node) + if isinstance(node, ast.AnnAssign): + return self.pre_visit_Module_AnnAssign(module, node) + raise NotImplementedError(f'{node} on Module') def pre_visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> Function: @@ -184,6 +198,25 @@ class OurVisitor: return struct + def pre_visit_Module_AnnAssign(self, module: Module, node: ast.AnnAssign) -> ModuleConstantDef: + if not isinstance(node.target, ast.Name): + _raise_static_error(node, 'Must be name') + if not isinstance(node.target.ctx, ast.Store): + _raise_static_error(node, 'Must be load context') + if not isinstance(node.value, ast.Constant): + _raise_static_error(node, 'Must be constant') + + exp_type = self.visit_type(module, node.annotation) + + constant = ModuleConstantDef( + node.target.id, + node.lineno, + exp_type, + self.visit_Module_Constant(module, exp_type, node.value) + ) + + return constant + def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None: if isinstance(node, ast.FunctionDef): self.visit_Module_FunctionDef(module, node) @@ -192,6 +225,9 @@ class OurVisitor: if isinstance(node, ast.ClassDef): return + if isinstance(node, ast.AnnAssign): + return + raise NotImplementedError(f'{node} on Module') def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: @@ -308,8 +344,8 @@ class OurVisitor: return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node) if isinstance(node, ast.Constant): - return self.visit_Module_FunctionDef_Constant( - module, function, exp_type, node, + return self.visit_Module_Constant( + module, exp_type, node, ) if isinstance(node, ast.Attribute): @@ -326,14 +362,21 @@ class OurVisitor: if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.id not in our_locals: - _raise_static_error(node, 'Undefined variable') + if node.id in our_locals: + act_type = our_locals[node.id] + if exp_type != act_type: + _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') - act_type = our_locals[node.id] - if exp_type != act_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') + return VariableReference(act_type, node.id) - return VariableReference(act_type, node.id) + if node.id in module.constant_defs: + cdef = module.constant_defs[node.id] + if exp_type != cdef.type: + _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') + + return ModuleConstantReference(exp_type, cdef) + + _raise_static_error(node, f'Undefined variable {node.id}') if isinstance(node, ast.Tuple): if not isinstance(node.ctx, ast.Load): @@ -541,9 +584,8 @@ class OurVisitor: _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') - def visit_Module_FunctionDef_Constant(self, module: Module, function: Function, exp_type: TypeBase, node: ast.Constant) -> Expression: + def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant: del module - del function _not_implemented(node.kind is None, 'Constant.kind') diff --git a/tests/integration/test_constants.py b/tests/integration/test_constants.py new file mode 100644 index 0000000..1593915 --- /dev/null +++ b/tests/integration/test_constants.py @@ -0,0 +1,17 @@ +import pytest + +from .helpers import Suite + +@pytest.mark.integration_test +def test_return(): + code_py = """ +CONSTANT: i32 = 13 + +@exported +def testEntry() -> i32: + return CONSTANT * 5 +""" + + result = Suite(code_py).run_code() + + assert 65 == result.returned_value