FunctionParam is a class, more framework stuff
This commit is contained in:
parent
7acb2bd8e6
commit
48e16c38b9
@ -104,7 +104,7 @@ def expression(inp: ourlang.Expression) -> str:
|
|||||||
) + ', )'
|
) + ', )'
|
||||||
|
|
||||||
if isinstance(inp, ourlang.VariableReference):
|
if isinstance(inp, ourlang.VariableReference):
|
||||||
return str(inp.name)
|
return str(inp.variable.name)
|
||||||
|
|
||||||
if isinstance(inp, ourlang.UnaryOp):
|
if isinstance(inp, ourlang.UnaryOp):
|
||||||
if (
|
if (
|
||||||
@ -193,8 +193,8 @@ def function(inp: ourlang.Function) -> str:
|
|||||||
result += '@imported\n'
|
result += '@imported\n'
|
||||||
|
|
||||||
args = ', '.join(
|
args = ', '.join(
|
||||||
f'{x}: {type_(y)}'
|
f'{p.name}: {type_(p.type)}'
|
||||||
for x, y in inp.posonlyargs
|
for p in inp.posonlyargs
|
||||||
)
|
)
|
||||||
|
|
||||||
result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n'
|
result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n'
|
||||||
|
|||||||
@ -160,7 +160,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(inp, ourlang.VariableReference):
|
if isinstance(inp, ourlang.VariableReference):
|
||||||
wgn.add_statement('local.get', '${}'.format(inp.name))
|
wgn.add_statement('local.get', '${}'.format(inp.variable.name))
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(inp, ourlang.BinaryOp):
|
if isinstance(inp, ourlang.BinaryOp):
|
||||||
@ -450,7 +450,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param:
|
|||||||
"""
|
"""
|
||||||
Compile: function argument
|
Compile: function argument
|
||||||
"""
|
"""
|
||||||
return (inp[0], type_(inp[1]), )
|
return (inp.name, type_(inp.type), )
|
||||||
|
|
||||||
def import_(inp: ourlang.Function) -> wasm.Import:
|
def import_(inp: ourlang.Function) -> wasm.Import:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -6,3 +6,6 @@ class StaticError(Exception):
|
|||||||
"""
|
"""
|
||||||
An error found during static analysis
|
An error found during static analysis
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class TypingError(Exception):
|
||||||
|
pass
|
||||||
|
|||||||
@ -156,13 +156,13 @@ class VariableReference(Expression):
|
|||||||
"""
|
"""
|
||||||
An variable reference expression within a statement
|
An variable reference expression within a statement
|
||||||
"""
|
"""
|
||||||
__slots__ = ('name', )
|
__slots__ = ('variable', )
|
||||||
|
|
||||||
name: str
|
variable: 'FunctionParam' # also possibly local
|
||||||
|
|
||||||
def __init__(self, type_: TypeBase, name: str) -> None:
|
def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None:
|
||||||
super().__init__(type_)
|
super().__init__(type_)
|
||||||
self.name = name
|
self.variable = variable
|
||||||
|
|
||||||
class UnaryOp(Expression):
|
class UnaryOp(Expression):
|
||||||
"""
|
"""
|
||||||
@ -352,7 +352,17 @@ class StatementIf(Statement):
|
|||||||
self.statements = []
|
self.statements = []
|
||||||
self.else_statements = []
|
self.else_statements = []
|
||||||
|
|
||||||
FunctionParam = Tuple[str, TypeBase]
|
class FunctionParam:
|
||||||
|
__slots__ = ('name', 'type', 'type_var', )
|
||||||
|
|
||||||
|
name: str
|
||||||
|
type: TypeBase
|
||||||
|
type_var: Optional[TypeVar]
|
||||||
|
|
||||||
|
def __init__(self, name: str, type_: TypeBase) -> None:
|
||||||
|
self.name = name
|
||||||
|
self.type = type_
|
||||||
|
self.type_var = None
|
||||||
|
|
||||||
class Function:
|
class Function:
|
||||||
"""
|
"""
|
||||||
@ -394,7 +404,7 @@ class StructConstructor(Function):
|
|||||||
self.returns = struct
|
self.returns = struct
|
||||||
|
|
||||||
for mem in struct.members:
|
for mem in struct.members:
|
||||||
self.posonlyargs.append((mem.name, mem.type, ))
|
self.posonlyargs.append(FunctionParam(mem.name, mem.type, ))
|
||||||
|
|
||||||
self.struct = struct
|
self.struct = struct
|
||||||
|
|
||||||
@ -414,7 +424,7 @@ class TupleConstructor(Function):
|
|||||||
self.returns = tuple_
|
self.returns = tuple_
|
||||||
|
|
||||||
for mem in tuple_.members:
|
for mem in tuple_.members:
|
||||||
self.posonlyargs.append((f'arg{mem.idx}', mem.type, ))
|
self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, ))
|
||||||
|
|
||||||
self.tuple = tuple_
|
self.tuple = tuple_
|
||||||
|
|
||||||
|
|||||||
@ -48,6 +48,7 @@ from .ourlang import (
|
|||||||
Statement,
|
Statement,
|
||||||
StatementIf, StatementPass, StatementReturn,
|
StatementIf, StatementPass, StatementReturn,
|
||||||
|
|
||||||
|
FunctionParam,
|
||||||
ModuleConstantDef,
|
ModuleConstantDef,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -60,7 +61,7 @@ def phasm_parse(source: str) -> Module:
|
|||||||
our_visitor = OurVisitor()
|
our_visitor = OurVisitor()
|
||||||
return our_visitor.visit_Module(res)
|
return our_visitor.visit_Module(res)
|
||||||
|
|
||||||
OurLocals = Dict[str, TypeBase]
|
OurLocals = Dict[str, Union[FunctionParam]] # Also local variable and module constants?
|
||||||
|
|
||||||
class OurVisitor:
|
class OurVisitor:
|
||||||
"""
|
"""
|
||||||
@ -141,7 +142,7 @@ class OurVisitor:
|
|||||||
if not arg.annotation:
|
if not arg.annotation:
|
||||||
_raise_static_error(node, 'Type is required')
|
_raise_static_error(node, 'Type is required')
|
||||||
|
|
||||||
function.posonlyargs.append((
|
function.posonlyargs.append(FunctionParam(
|
||||||
arg.arg,
|
arg.arg,
|
||||||
self.visit_type(module, arg.annotation),
|
self.visit_type(module, arg.annotation),
|
||||||
))
|
))
|
||||||
@ -297,7 +298,10 @@ class OurVisitor:
|
|||||||
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
|
def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None:
|
||||||
function = module.functions[node.name]
|
function = module.functions[node.name]
|
||||||
|
|
||||||
our_locals = dict(function.posonlyargs)
|
our_locals: OurLocals = {
|
||||||
|
x.name: x
|
||||||
|
for x in function.posonlyargs
|
||||||
|
}
|
||||||
|
|
||||||
for stmt in node.body:
|
for stmt in node.body:
|
||||||
function.statements.append(
|
function.statements.append(
|
||||||
@ -427,11 +431,11 @@ class OurVisitor:
|
|||||||
_raise_static_error(node, 'Must be load context')
|
_raise_static_error(node, 'Must be load context')
|
||||||
|
|
||||||
if node.id in our_locals:
|
if node.id in our_locals:
|
||||||
act_type = our_locals[node.id]
|
param = our_locals[node.id]
|
||||||
if exp_type != act_type:
|
if exp_type != param.type:
|
||||||
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}')
|
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}')
|
||||||
|
|
||||||
return VariableReference(act_type, node.id)
|
return VariableReference(param.type, param)
|
||||||
|
|
||||||
if node.id in module.constant_defs:
|
if node.id in module.constant_defs:
|
||||||
cdef = module.constant_defs[node.id]
|
cdef = module.constant_defs[node.id]
|
||||||
@ -541,10 +545,10 @@ class OurVisitor:
|
|||||||
if exp_type.__class__ != func.returns.__class__:
|
if exp_type.__class__ != func.returns.__class__:
|
||||||
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
|
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}')
|
||||||
|
|
||||||
if func.returns.__class__ != func.posonlyargs[0][1].__class__:
|
if func.returns.__class__ != func.posonlyargs[0].type.__class__:
|
||||||
_raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}')
|
_raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}')
|
||||||
|
|
||||||
if module.types['u8'].__class__ != func.posonlyargs[1][1].__class__:
|
if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__:
|
||||||
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
|
_raise_static_error(node, 'Only folding over bytes (u8) is supported at this time')
|
||||||
|
|
||||||
return Fold(
|
return Fold(
|
||||||
@ -568,8 +572,8 @@ class OurVisitor:
|
|||||||
|
|
||||||
result = FunctionCall(func)
|
result = FunctionCall(func)
|
||||||
result.arguments.extend(
|
result.arguments.extend(
|
||||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr)
|
self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr)
|
||||||
for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs)
|
for arg_expr, param in zip(node.args, func.posonlyargs)
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@ -586,7 +590,9 @@ class OurVisitor:
|
|||||||
if not node.value.id in our_locals:
|
if not node.value.id in our_locals:
|
||||||
_raise_static_error(node, f'Undefined variable {node.value.id}')
|
_raise_static_error(node, f'Undefined variable {node.value.id}')
|
||||||
|
|
||||||
node_typ = our_locals[node.value.id]
|
param = our_locals[node.value.id]
|
||||||
|
|
||||||
|
node_typ = param.type
|
||||||
if not isinstance(node_typ, TypeStruct):
|
if not isinstance(node_typ, TypeStruct):
|
||||||
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
|
_raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}')
|
||||||
|
|
||||||
@ -598,7 +604,7 @@ class OurVisitor:
|
|||||||
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}')
|
_raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}')
|
||||||
|
|
||||||
return AccessStructMember(
|
return AccessStructMember(
|
||||||
VariableReference(node_typ, node.value.id),
|
VariableReference(node_typ, param),
|
||||||
member,
|
member,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -614,8 +620,8 @@ class OurVisitor:
|
|||||||
|
|
||||||
varref: Union[ModuleConstantReference, VariableReference]
|
varref: Union[ModuleConstantReference, VariableReference]
|
||||||
if node.value.id in our_locals:
|
if node.value.id in our_locals:
|
||||||
node_typ = our_locals[node.value.id]
|
param = our_locals[node.value.id]
|
||||||
varref = VariableReference(node_typ, node.value.id)
|
varref = VariableReference(param.type, param)
|
||||||
elif node.value.id in module.constant_defs:
|
elif node.value.id in module.constant_defs:
|
||||||
constant_def = module.constant_defs[node.value.id]
|
constant_def = module.constant_defs[node.value.id]
|
||||||
node_typ = constant_def.type
|
node_typ = constant_def.type
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Type checks and enriches the given ast
|
Type checks and enriches the given ast
|
||||||
"""
|
"""
|
||||||
from math import ceil, log2
|
|
||||||
|
|
||||||
from . import ourlang
|
from . import ourlang
|
||||||
|
|
||||||
from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar
|
from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar
|
||||||
@ -32,7 +30,14 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
|
|||||||
if isinstance(inp, ourlang.Constant):
|
if isinstance(inp, ourlang.Constant):
|
||||||
return constant(ctx, inp)
|
return constant(ctx, inp)
|
||||||
|
|
||||||
|
if isinstance(inp, ourlang.VariableReference):
|
||||||
|
assert inp.variable.type_var is not None, inp
|
||||||
|
return inp.variable.type_var
|
||||||
|
|
||||||
if isinstance(inp, ourlang.BinaryOp):
|
if isinstance(inp, ourlang.BinaryOp):
|
||||||
|
if inp.operator not in ('+', '-', '|', '&', '^'):
|
||||||
|
raise NotImplementedError(expression, inp, inp.operator)
|
||||||
|
|
||||||
left = expression(ctx, inp.left)
|
left = expression(ctx, inp.left)
|
||||||
right = expression(ctx, inp.right)
|
right = expression(ctx, inp.right)
|
||||||
ctx.unify(left, right)
|
ctx.unify(left, right)
|
||||||
@ -41,11 +46,11 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar':
|
|||||||
raise NotImplementedError(expression, inp)
|
raise NotImplementedError(expression, inp)
|
||||||
|
|
||||||
def function(ctx: 'Context', inp: ourlang.Function) -> None:
|
def function(ctx: 'Context', inp: ourlang.Function) -> None:
|
||||||
bctx = ctx.clone() # Clone whenever we go into a block
|
for param in inp.posonlyargs:
|
||||||
|
param.type_var = _convert_old_type(ctx, param.type)
|
||||||
|
|
||||||
assert len(inp.statements) == 1 # TODO
|
if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn):
|
||||||
|
raise NotImplementedError('Functions with not just a return statement')
|
||||||
assert isinstance(inp.statements[0], ourlang.StatementReturn)
|
|
||||||
typ = expression(ctx, inp.statements[0].value)
|
typ = expression(ctx, inp.statements[0].value)
|
||||||
|
|
||||||
ctx.unify(_convert_old_type(ctx, inp.returns), typ)
|
ctx.unify(_convert_old_type(ctx, inp.returns), typ)
|
||||||
@ -62,10 +67,34 @@ from . import typing
|
|||||||
def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar:
|
def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar:
|
||||||
result = ctx.new_var()
|
result = ctx.new_var()
|
||||||
|
|
||||||
|
if isinstance(inp, typing.TypeUInt8):
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
|
||||||
|
result.add_constraint(TypeConstraintSigned(False))
|
||||||
|
result.add_location('u8')
|
||||||
|
return result
|
||||||
|
|
||||||
if isinstance(inp, typing.TypeUInt32):
|
if isinstance(inp, typing.TypeUInt32):
|
||||||
result.add_constraint(TypeConstraintBitWidth(maxb=32))
|
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
|
||||||
result.add_constraint(TypeConstraintSigned(False))
|
result.add_constraint(TypeConstraintSigned(False))
|
||||||
result.add_location('u32')
|
result.add_location('u32')
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
if isinstance(inp, typing.TypeUInt64):
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64))
|
||||||
|
result.add_constraint(TypeConstraintSigned(False))
|
||||||
|
result.add_location('u64')
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(inp, typing.TypeInt32):
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
|
||||||
|
result.add_constraint(TypeConstraintSigned(True))
|
||||||
|
result.add_location('i32')
|
||||||
|
return result
|
||||||
|
|
||||||
|
if isinstance(inp, typing.TypeInt64):
|
||||||
|
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64))
|
||||||
|
result.add_constraint(TypeConstraintSigned(True))
|
||||||
|
result.add_location('i64')
|
||||||
|
return result
|
||||||
|
|
||||||
raise NotImplementedError(_convert_old_type, inp)
|
raise NotImplementedError(_convert_old_type, inp)
|
||||||
|
|||||||
@ -3,6 +3,8 @@ The phasm type system
|
|||||||
"""
|
"""
|
||||||
from typing import Dict, Optional, List, Type
|
from typing import Dict, Optional, List, Type
|
||||||
|
|
||||||
|
from .exceptions import TypingError
|
||||||
|
|
||||||
class TypeBase:
|
class TypeBase:
|
||||||
"""
|
"""
|
||||||
TypeBase base class
|
TypeBase base class
|
||||||
@ -203,9 +205,6 @@ class TypeStruct(TypeBase):
|
|||||||
|
|
||||||
## NEW STUFF BELOW
|
## NEW STUFF BELOW
|
||||||
|
|
||||||
class TypingError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
class TypingNarrowProtoError(TypingError):
|
class TypingNarrowProtoError(TypingError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -237,7 +236,7 @@ class TypeConstraintSigned(TypeConstraintBase):
|
|||||||
return TypeConstraintSigned(other.signed)
|
return TypeConstraintSigned(other.signed)
|
||||||
|
|
||||||
if self.signed is not other.signed:
|
if self.signed is not other.signed:
|
||||||
raise TypeError()
|
raise TypingNarrowProtoError('Signed does not match')
|
||||||
|
|
||||||
return TypeConstraintSigned(self.signed)
|
return TypeConstraintSigned(self.signed)
|
||||||
|
|
||||||
@ -300,9 +299,6 @@ class TypeVar:
|
|||||||
)
|
)
|
||||||
|
|
||||||
class Context:
|
class Context:
|
||||||
def clone(self) -> 'Context':
|
|
||||||
return self # TODO: STUB
|
|
||||||
|
|
||||||
def new_var(self) -> TypeVar:
|
def new_var(self) -> TypeVar:
|
||||||
return TypeVar(self)
|
return TypeVar(self)
|
||||||
|
|
||||||
|
|||||||
@ -45,7 +45,7 @@ class RunnerBase:
|
|||||||
try:
|
try:
|
||||||
phasm_type(self.phasm_ast)
|
phasm_type(self.phasm_ast)
|
||||||
except NotImplementedError as exc:
|
except NotImplementedError as exc:
|
||||||
warnings.warn(f'phash_type throws an NotImplementedError on this test: {exc}')
|
warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}')
|
||||||
|
|
||||||
def compile_ast(self) -> None:
|
def compile_ast(self) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user