From 48e16c38b9f89b23ce99bba6005b9bcfe864c36d Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Fri, 16 Sep 2022 16:42:24 +0200 Subject: [PATCH] FunctionParam is a class, more framework stuff --- phasm/codestyle.py | 6 ++--- phasm/compiler.py | 4 ++-- phasm/exceptions.py | 3 +++ phasm/ourlang.py | 24 ++++++++++++++------ phasm/parser.py | 38 +++++++++++++++++-------------- phasm/typer.py | 43 ++++++++++++++++++++++++++++++------ phasm/typing.py | 10 +++------ tests/integration/runners.py | 2 +- 8 files changed, 87 insertions(+), 43 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 4b38e32..d0e9c70 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -104,7 +104,7 @@ def expression(inp: ourlang.Expression) -> str: ) + ', )' if isinstance(inp, ourlang.VariableReference): - return str(inp.name) + return str(inp.variable.name) if isinstance(inp, ourlang.UnaryOp): if ( @@ -193,8 +193,8 @@ def function(inp: ourlang.Function) -> str: result += '@imported\n' args = ', '.join( - f'{x}: {type_(y)}' - for x, y in inp.posonlyargs + f'{p.name}: {type_(p.type)}' + for p in inp.posonlyargs ) result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' diff --git a/phasm/compiler.py b/phasm/compiler.py index 4ba5d5a..d36561d 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -160,7 +160,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: return if isinstance(inp, ourlang.VariableReference): - wgn.add_statement('local.get', '${}'.format(inp.name)) + wgn.add_statement('local.get', '${}'.format(inp.variable.name)) return if isinstance(inp, ourlang.BinaryOp): @@ -450,7 +450,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param: """ Compile: function argument """ - return (inp[0], type_(inp[1]), ) + return (inp.name, type_(inp.type), ) def import_(inp: ourlang.Function) -> wasm.Import: """ diff --git a/phasm/exceptions.py b/phasm/exceptions.py index b459c22..abbcdd2 100644 --- a/phasm/exceptions.py +++ b/phasm/exceptions.py @@ -6,3 +6,6 @@ class StaticError(Exception): """ An error found during static analysis """ + +class TypingError(Exception): + pass diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 3338ce4..6ae9518 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -156,13 +156,13 @@ class VariableReference(Expression): """ An variable reference expression within a statement """ - __slots__ = ('name', ) + __slots__ = ('variable', ) - name: str + variable: 'FunctionParam' # also possibly local - def __init__(self, type_: TypeBase, name: str) -> None: + def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None: super().__init__(type_) - self.name = name + self.variable = variable class UnaryOp(Expression): """ @@ -352,7 +352,17 @@ class StatementIf(Statement): self.statements = [] self.else_statements = [] -FunctionParam = Tuple[str, TypeBase] +class FunctionParam: + __slots__ = ('name', 'type', 'type_var', ) + + name: str + type: TypeBase + type_var: Optional[TypeVar] + + def __init__(self, name: str, type_: TypeBase) -> None: + self.name = name + self.type = type_ + self.type_var = None class Function: """ @@ -394,7 +404,7 @@ class StructConstructor(Function): self.returns = struct for mem in struct.members: - self.posonlyargs.append((mem.name, mem.type, )) + self.posonlyargs.append(FunctionParam(mem.name, mem.type, )) self.struct = struct @@ -414,7 +424,7 @@ class TupleConstructor(Function): self.returns = tuple_ for mem in tuple_.members: - self.posonlyargs.append((f'arg{mem.idx}', mem.type, )) + self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, )) self.tuple = tuple_ diff --git a/phasm/parser.py b/phasm/parser.py index d95bfce..2d7a158 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -48,6 +48,7 @@ from .ourlang import ( Statement, StatementIf, StatementPass, StatementReturn, + FunctionParam, ModuleConstantDef, ) @@ -60,7 +61,7 @@ def phasm_parse(source: str) -> Module: our_visitor = OurVisitor() return our_visitor.visit_Module(res) -OurLocals = Dict[str, TypeBase] +OurLocals = Dict[str, Union[FunctionParam]] # Also local variable and module constants? class OurVisitor: """ @@ -141,7 +142,7 @@ class OurVisitor: if not arg.annotation: _raise_static_error(node, 'Type is required') - function.posonlyargs.append(( + function.posonlyargs.append(FunctionParam( arg.arg, self.visit_type(module, arg.annotation), )) @@ -297,7 +298,10 @@ class OurVisitor: def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: function = module.functions[node.name] - our_locals = dict(function.posonlyargs) + our_locals: OurLocals = { + x.name: x + for x in function.posonlyargs + } for stmt in node.body: function.statements.append( @@ -427,11 +431,11 @@ class OurVisitor: _raise_static_error(node, 'Must be load context') if node.id in our_locals: - act_type = our_locals[node.id] - if exp_type != act_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') + param = our_locals[node.id] + if exp_type != param.type: + _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}') - return VariableReference(act_type, node.id) + return VariableReference(param.type, param) if node.id in module.constant_defs: cdef = module.constant_defs[node.id] @@ -541,10 +545,10 @@ class OurVisitor: if exp_type.__class__ != func.returns.__class__: _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') - if func.returns.__class__ != func.posonlyargs[0][1].__class__: - _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}') + if func.returns.__class__ != func.posonlyargs[0].type.__class__: + _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}') - if module.types['u8'].__class__ != func.posonlyargs[1][1].__class__: + if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__: _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') return Fold( @@ -568,8 +572,8 @@ class OurVisitor: result = FunctionCall(func) result.arguments.extend( - self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr) - for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs) + self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr) + for arg_expr, param in zip(node.args, func.posonlyargs) ) return result @@ -586,7 +590,9 @@ class OurVisitor: if not node.value.id in our_locals: _raise_static_error(node, f'Undefined variable {node.value.id}') - node_typ = our_locals[node.value.id] + param = our_locals[node.value.id] + + node_typ = param.type if not isinstance(node_typ, TypeStruct): _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') @@ -598,7 +604,7 @@ class OurVisitor: _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}') return AccessStructMember( - VariableReference(node_typ, node.value.id), + VariableReference(node_typ, param), member, ) @@ -614,8 +620,8 @@ class OurVisitor: varref: Union[ModuleConstantReference, VariableReference] if node.value.id in our_locals: - node_typ = our_locals[node.value.id] - varref = VariableReference(node_typ, node.value.id) + param = our_locals[node.value.id] + varref = VariableReference(param.type, param) elif node.value.id in module.constant_defs: constant_def = module.constant_defs[node.value.id] node_typ = constant_def.type diff --git a/phasm/typer.py b/phasm/typer.py index d21f823..c0677fb 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -1,8 +1,6 @@ """ Type checks and enriches the given ast """ -from math import ceil, log2 - from . import ourlang from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar @@ -32,7 +30,14 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.Constant): return constant(ctx, inp) + if isinstance(inp, ourlang.VariableReference): + assert inp.variable.type_var is not None, inp + return inp.variable.type_var + if isinstance(inp, ourlang.BinaryOp): + if inp.operator not in ('+', '-', '|', '&', '^'): + raise NotImplementedError(expression, inp, inp.operator) + left = expression(ctx, inp.left) right = expression(ctx, inp.right) ctx.unify(left, right) @@ -41,11 +46,11 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': raise NotImplementedError(expression, inp) def function(ctx: 'Context', inp: ourlang.Function) -> None: - bctx = ctx.clone() # Clone whenever we go into a block + for param in inp.posonlyargs: + param.type_var = _convert_old_type(ctx, param.type) - assert len(inp.statements) == 1 # TODO - - assert isinstance(inp.statements[0], ourlang.StatementReturn) + if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn): + raise NotImplementedError('Functions with not just a return statement') typ = expression(ctx, inp.statements[0].value) ctx.unify(_convert_old_type(ctx, inp.returns), typ) @@ -62,10 +67,34 @@ from . import typing def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar: result = ctx.new_var() + if isinstance(inp, typing.TypeUInt8): + result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location('u8') + return result + if isinstance(inp, typing.TypeUInt32): - result.add_constraint(TypeConstraintBitWidth(maxb=32)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(False)) result.add_location('u32') return result + if isinstance(inp, typing.TypeUInt64): + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location('u64') + return result + + if isinstance(inp, typing.TypeInt32): + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location('i32') + return result + + if isinstance(inp, typing.TypeInt64): + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location('i64') + return result + raise NotImplementedError(_convert_old_type, inp) diff --git a/phasm/typing.py b/phasm/typing.py index 0cb213d..65bb095 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -3,6 +3,8 @@ The phasm type system """ from typing import Dict, Optional, List, Type +from .exceptions import TypingError + class TypeBase: """ TypeBase base class @@ -203,9 +205,6 @@ class TypeStruct(TypeBase): ## NEW STUFF BELOW -class TypingError(Exception): - pass - class TypingNarrowProtoError(TypingError): pass @@ -237,7 +236,7 @@ class TypeConstraintSigned(TypeConstraintBase): return TypeConstraintSigned(other.signed) if self.signed is not other.signed: - raise TypeError() + raise TypingNarrowProtoError('Signed does not match') return TypeConstraintSigned(self.signed) @@ -300,9 +299,6 @@ class TypeVar: ) class Context: - def clone(self) -> 'Context': - return self # TODO: STUB - def new_var(self) -> TypeVar: return TypeVar(self) diff --git a/tests/integration/runners.py b/tests/integration/runners.py index 57d2adb..005d44e 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -45,7 +45,7 @@ class RunnerBase: try: phasm_type(self.phasm_ast) except NotImplementedError as exc: - warnings.warn(f'phash_type throws an NotImplementedError on this test: {exc}') + warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}') def compile_ast(self) -> None: """