FunctionParam is a class, more framework stuff

This commit is contained in:
Johan B.W. de Vries 2022-09-16 16:42:24 +02:00
parent 7acb2bd8e6
commit 48e16c38b9
8 changed files with 87 additions and 43 deletions

View File

@ -104,7 +104,7 @@ def expression(inp: ourlang.Expression) -> str:
) + ', )' ) + ', )'
if isinstance(inp, ourlang.VariableReference): if isinstance(inp, ourlang.VariableReference):
return str(inp.name) return str(inp.variable.name)
if isinstance(inp, ourlang.UnaryOp): if isinstance(inp, ourlang.UnaryOp):
if ( if (
@ -193,8 +193,8 @@ def function(inp: ourlang.Function) -> str:
result += '@imported\n' result += '@imported\n'
args = ', '.join( args = ', '.join(
f'{x}: {type_(y)}' f'{p.name}: {type_(p.type)}'
for x, y in inp.posonlyargs for p in inp.posonlyargs
) )
result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n'

View File

@ -160,7 +160,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
return return
if isinstance(inp, ourlang.VariableReference): if isinstance(inp, ourlang.VariableReference):
wgn.add_statement('local.get', '${}'.format(inp.name)) wgn.add_statement('local.get', '${}'.format(inp.variable.name))
return return
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
@ -450,7 +450,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param:
""" """
Compile: function argument Compile: function argument
""" """
return (inp[0], type_(inp[1]), ) return (inp.name, type_(inp.type), )
def import_(inp: ourlang.Function) -> wasm.Import: def import_(inp: ourlang.Function) -> wasm.Import:
""" """

View File

@ -6,3 +6,6 @@ class StaticError(Exception):
""" """
An error found during static analysis An error found during static analysis
""" """
class TypingError(Exception):
pass

View File

@ -156,13 +156,13 @@ class VariableReference(Expression):
""" """
An variable reference expression within a statement An variable reference expression within a statement
""" """
__slots__ = ('name', ) __slots__ = ('variable', )
name: str variable: 'FunctionParam' # also possibly local
def __init__(self, type_: TypeBase, name: str) -> None: def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None:
super().__init__(type_) super().__init__(type_)
self.name = name self.variable = variable
class UnaryOp(Expression): class UnaryOp(Expression):
""" """
@ -352,7 +352,17 @@ class StatementIf(Statement):
self.statements = [] self.statements = []
self.else_statements = [] self.else_statements = []
FunctionParam = Tuple[str, TypeBase] class FunctionParam:
__slots__ = ('name', 'type', 'type_var', )
name: str
type: TypeBase
type_var: Optional[TypeVar]
def __init__(self, name: str, type_: TypeBase) -> None:
self.name = name
self.type = type_
self.type_var = None
class Function: class Function:
""" """
@ -394,7 +404,7 @@ class StructConstructor(Function):
self.returns = struct self.returns = struct
for mem in struct.members: for mem in struct.members:
self.posonlyargs.append((mem.name, mem.type, )) self.posonlyargs.append(FunctionParam(mem.name, mem.type, ))
self.struct = struct self.struct = struct
@ -414,7 +424,7 @@ class TupleConstructor(Function):
self.returns = tuple_ self.returns = tuple_
for mem in tuple_.members: for mem in tuple_.members:
self.posonlyargs.append((f'arg{mem.idx}', mem.type, )) self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, ))
self.tuple = tuple_ self.tuple = tuple_

View File

@ -48,6 +48,7 @@ from .ourlang import (
Statement, Statement,
StatementIf, StatementPass, StatementReturn, StatementIf, StatementPass, StatementReturn,
FunctionParam,
ModuleConstantDef, ModuleConstantDef,
) )
@ -60,7 +61,7 @@ def phasm_parse(source: str) -> Module:
our_visitor = OurVisitor() our_visitor = OurVisitor()
return our_visitor.visit_Module(res) return our_visitor.visit_Module(res)
OurLocals = Dict[str, TypeBase] OurLocals = Dict[str, Union[FunctionParam]] # Also local variable and module constants?
class OurVisitor: class OurVisitor:
""" """
@ -141,7 +142,7 @@ class OurVisitor:
if not arg.annotation: if not arg.annotation:
_raise_static_error(node, 'Type is required') _raise_static_error(node, 'Type is required')
function.posonlyargs.append(( function.posonlyargs.append(FunctionParam(
arg.arg, arg.arg,
self.visit_type(module, arg.annotation), self.visit_type(module, arg.annotation),
)) ))
@ -297,7 +298,10 @@ class OurVisitor:
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
function = module.functions[node.name] function = module.functions[node.name]
our_locals = dict(function.posonlyargs) our_locals: OurLocals = {
x.name: x
for x in function.posonlyargs
}
for stmt in node.body: for stmt in node.body:
function.statements.append( function.statements.append(
@ -427,11 +431,11 @@ class OurVisitor:
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
if node.id in our_locals: if node.id in our_locals:
act_type = our_locals[node.id] param = our_locals[node.id]
if exp_type != act_type: if exp_type != param.type:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}')
return VariableReference(act_type, node.id) return VariableReference(param.type, param)
if node.id in module.constant_defs: if node.id in module.constant_defs:
cdef = module.constant_defs[node.id] cdef = module.constant_defs[node.id]
@ -541,10 +545,10 @@ class OurVisitor:
if exp_type.__class__ != func.returns.__class__: if exp_type.__class__ != func.returns.__class__:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
if func.returns.__class__ != func.posonlyargs[0][1].__class__: if func.returns.__class__ != func.posonlyargs[0].type.__class__:
_raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}') _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}')
if module.types['u8'].__class__ != func.posonlyargs[1][1].__class__: if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__:
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
return Fold( return Fold(
@ -568,8 +572,8 @@ class OurVisitor:
result = FunctionCall(func) result = FunctionCall(func)
result.arguments.extend( result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr) self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr)
for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs) for arg_expr, param in zip(node.args, func.posonlyargs)
) )
return result return result
@ -586,7 +590,9 @@ class OurVisitor:
if not node.value.id in our_locals: if not node.value.id in our_locals:
_raise_static_error(node, f'Undefined variable {node.value.id}') _raise_static_error(node, f'Undefined variable {node.value.id}')
node_typ = our_locals[node.value.id] param = our_locals[node.value.id]
node_typ = param.type
if not isinstance(node_typ, TypeStruct): if not isinstance(node_typ, TypeStruct):
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
@ -598,7 +604,7 @@ class OurVisitor:
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}') _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}')
return AccessStructMember( return AccessStructMember(
VariableReference(node_typ, node.value.id), VariableReference(node_typ, param),
member, member,
) )
@ -614,8 +620,8 @@ class OurVisitor:
varref: Union[ModuleConstantReference, VariableReference] varref: Union[ModuleConstantReference, VariableReference]
if node.value.id in our_locals: if node.value.id in our_locals:
node_typ = our_locals[node.value.id] param = our_locals[node.value.id]
varref = VariableReference(node_typ, node.value.id) varref = VariableReference(param.type, param)
elif node.value.id in module.constant_defs: elif node.value.id in module.constant_defs:
constant_def = module.constant_defs[node.value.id] constant_def = module.constant_defs[node.value.id]
node_typ = constant_def.type node_typ = constant_def.type

View File

@ -1,8 +1,6 @@
""" """
Type checks and enriches the given ast Type checks and enriches the given ast
""" """
from math import ceil, log2
from . import ourlang from . import ourlang
from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar
@ -32,7 +30,14 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
if isinstance(inp, ourlang.Constant): if isinstance(inp, ourlang.Constant):
return constant(ctx, inp) return constant(ctx, inp)
if isinstance(inp, ourlang.VariableReference):
assert inp.variable.type_var is not None, inp
return inp.variable.type_var
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
if inp.operator not in ('+', '-', '|', '&', '^'):
raise NotImplementedError(expression, inp, inp.operator)
left = expression(ctx, inp.left) left = expression(ctx, inp.left)
right = expression(ctx, inp.right) right = expression(ctx, inp.right)
ctx.unify(left, right) ctx.unify(left, right)
@ -41,11 +46,11 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def function(ctx: 'Context', inp: ourlang.Function) -> None: def function(ctx: 'Context', inp: ourlang.Function) -> None:
bctx = ctx.clone() # Clone whenever we go into a block for param in inp.posonlyargs:
param.type_var = _convert_old_type(ctx, param.type)
assert len(inp.statements) == 1 # TODO if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn):
raise NotImplementedError('Functions with not just a return statement')
assert isinstance(inp.statements[0], ourlang.StatementReturn)
typ = expression(ctx, inp.statements[0].value) typ = expression(ctx, inp.statements[0].value)
ctx.unify(_convert_old_type(ctx, inp.returns), typ) ctx.unify(_convert_old_type(ctx, inp.returns), typ)
@ -62,10 +67,34 @@ from . import typing
def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar: def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar:
result = ctx.new_var() result = ctx.new_var()
if isinstance(inp, typing.TypeUInt8):
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u8')
return result
if isinstance(inp, typing.TypeUInt32): if isinstance(inp, typing.TypeUInt32):
result.add_constraint(TypeConstraintBitWidth(maxb=32)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location('u32') result.add_location('u32')
return result return result
if isinstance(inp, typing.TypeUInt64):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u64')
return result
if isinstance(inp, typing.TypeInt32):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location('i32')
return result
if isinstance(inp, typing.TypeInt64):
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location('i64')
return result
raise NotImplementedError(_convert_old_type, inp) raise NotImplementedError(_convert_old_type, inp)

View File

@ -3,6 +3,8 @@ The phasm type system
""" """
from typing import Dict, Optional, List, Type from typing import Dict, Optional, List, Type
from .exceptions import TypingError
class TypeBase: class TypeBase:
""" """
TypeBase base class TypeBase base class
@ -203,9 +205,6 @@ class TypeStruct(TypeBase):
## NEW STUFF BELOW ## NEW STUFF BELOW
class TypingError(Exception):
pass
class TypingNarrowProtoError(TypingError): class TypingNarrowProtoError(TypingError):
pass pass
@ -237,7 +236,7 @@ class TypeConstraintSigned(TypeConstraintBase):
return TypeConstraintSigned(other.signed) return TypeConstraintSigned(other.signed)
if self.signed is not other.signed: if self.signed is not other.signed:
raise TypeError() raise TypingNarrowProtoError('Signed does not match')
return TypeConstraintSigned(self.signed) return TypeConstraintSigned(self.signed)
@ -300,9 +299,6 @@ class TypeVar:
) )
class Context: class Context:
def clone(self) -> 'Context':
return self # TODO: STUB
def new_var(self) -> TypeVar: def new_var(self) -> TypeVar:
return TypeVar(self) return TypeVar(self)

View File

@ -45,7 +45,7 @@ class RunnerBase:
try: try:
phasm_type(self.phasm_ast) phasm_type(self.phasm_ast)
except NotImplementedError as exc: except NotImplementedError as exc:
warnings.warn(f'phash_type throws an NotImplementedError on this test: {exc}') warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}')
def compile_ast(self) -> None: def compile_ast(self) -> None:
""" """