From 7acb2bd8e6cfc23eaf9b3a9dfee959745d6562da Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Fri, 16 Sep 2022 14:42:40 +0200 Subject: [PATCH 01/18] Framework sketch --- phasm/ourlang.py | 6 +- phasm/typer.py | 71 +++++++++++++++++++ phasm/typing.py | 132 ++++++++++++++++++++++++++++++++++- tests/integration/runners.py | 6 ++ 4 files changed, 213 insertions(+), 2 deletions(-) create mode 100644 phasm/typer.py diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 5f19d2e..3338ce4 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -21,18 +21,22 @@ from .typing import ( TypeTuple, TypeTupleMember, TypeStaticArray, TypeStaticArrayMember, TypeStruct, TypeStructMember, + + TypeVar, ) class Expression: """ An expression within a statement """ - __slots__ = ('type', ) + __slots__ = ('type', 'type_var', ) type: TypeBase + type_var: Optional[TypeVar] def __init__(self, type_: TypeBase) -> None: self.type = type_ + self.type_var = None class Constant(Expression): """ diff --git a/phasm/typer.py b/phasm/typer.py new file mode 100644 index 0000000..d21f823 --- /dev/null +++ b/phasm/typer.py @@ -0,0 +1,71 @@ +""" +Type checks and enriches the given ast +""" +from math import ceil, log2 + +from . import ourlang + +from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar + +def phasm_type(inp: ourlang.Module) -> None: + module(inp) + +def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': + if getattr(inp, 'value', int): + result = ctx.new_var() + + # Need at least this many bits to store this constant value + result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) # type: ignore + # Don't dictate anything about signedness - you can use a signed + # constant in an unsigned variable if the bits fit + result.add_constraint(TypeConstraintSigned(None)) + + result.add_location(str(inp.value)) # type: ignore + + inp.type_var = result + + return result + + raise NotImplementedError(constant, inp) + +def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': + if isinstance(inp, ourlang.Constant): + return constant(ctx, inp) + + if isinstance(inp, ourlang.BinaryOp): + left = expression(ctx, inp.left) + right = expression(ctx, inp.right) + ctx.unify(left, right) + return left + + raise NotImplementedError(expression, inp) + +def function(ctx: 'Context', inp: ourlang.Function) -> None: + bctx = ctx.clone() # Clone whenever we go into a block + + assert len(inp.statements) == 1 # TODO + + assert isinstance(inp.statements[0], ourlang.StatementReturn) + typ = expression(ctx, inp.statements[0].value) + + ctx.unify(_convert_old_type(ctx, inp.returns), typ) + return + +def module(inp: ourlang.Module) -> None: + ctx = Context() + + for func in inp.functions.values(): + function(ctx, func) + +from . import typing + +def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar: + result = ctx.new_var() + + if isinstance(inp, typing.TypeUInt32): + result.add_constraint(TypeConstraintBitWidth(maxb=32)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location('u32') + return result + + raise NotImplementedError(_convert_old_type, inp) diff --git a/phasm/typing.py b/phasm/typing.py index e56f7a9..0cb213d 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -1,7 +1,7 @@ """ The phasm type system """ -from typing import Optional, List +from typing import Dict, Optional, List, Type class TypeBase: """ @@ -200,3 +200,133 @@ class TypeStruct(TypeBase): x.type.alloc_size() for x in self.members ) + +## NEW STUFF BELOW + +class TypingError(Exception): + pass + +class TypingNarrowProtoError(TypingError): + pass + +class TypingNarrowError(TypingError): + def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None: + super().__init__( + f'Cannot narrow types {l} and {r}: {msg}' + ) + +class TypeConstraintBase: + def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': + raise NotImplementedError('narrow', self, other) + +class TypeConstraintSigned(TypeConstraintBase): + __slots__ = ('signed', ) + + signed: Optional[bool] + + def __init__(self, signed: Optional[bool]) -> None: + self.signed = signed + + def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintSigned': + if not isinstance(other, TypeConstraintSigned): + raise Exception('Invalid comparison') + + if other.signed is None: + return TypeConstraintSigned(self.signed) + if self.signed is None: + return TypeConstraintSigned(other.signed) + + if self.signed is not other.signed: + raise TypeError() + + return TypeConstraintSigned(self.signed) + + def __repr__(self) -> str: + return f'Signed={self.signed}' + +class TypeConstraintBitWidth(TypeConstraintBase): + __slots__ = ('minb', 'maxb', ) + + minb: int + maxb: int + + def __init__(self, *, minb: int = 1, maxb: int = 64) -> None: + assert minb is not None or maxb is not None + assert maxb <= 64 # For now, support up to 64 bits values + + self.minb = minb + self.maxb = maxb + + def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth': + if not isinstance(other, TypeConstraintBitWidth): + raise Exception('Invalid comparison') + + if self.minb > other.maxb: + raise TypingNarrowProtoError('Min bitwidth exceeds other max bitwidth') + + if other.minb > self.maxb: + raise TypingNarrowProtoError('Other min bitwidth exceeds max bitwidth') + + return TypeConstraintBitWidth( + minb=max(self.minb, other.minb), + maxb=min(self.maxb, other.maxb), + ) + + def __repr__(self) -> str: + return f'BitWidth={self.minb}..{self.maxb}' + +class TypeVar: + def __init__(self, ctx: 'Context') -> None: + self.context = ctx + self.constraints: Dict[Type[TypeConstraintBase], TypeConstraintBase] = {} + self.locations: List[str] = [] + + def add_constraint(self, newconst: TypeConstraintBase) -> None: + if newconst.__class__ in self.constraints: + self.constraints[newconst.__class__] = self.constraints[newconst.__class__].narrow(newconst) + else: + self.constraints[newconst.__class__] = newconst + + def add_location(self, ref: str) -> None: + self.locations.append(ref) + + def __repr__(self) -> str: + return ( + 'TypeVar<' + + '; '.join(map(repr, self.constraints.values())) + + '; locations: ' + + ', '.join(self.locations) + + '>' + ) + +class Context: + def clone(self) -> 'Context': + return self # TODO: STUB + + def new_var(self) -> TypeVar: + return TypeVar(self) + + def unify(self, l: 'TypeVar', r: 'TypeVar') -> None: + newtypevar = self.new_var() + + try: + for const in l.constraints.values(): + newtypevar.add_constraint(const) + for const in r.constraints.values(): + newtypevar.add_constraint(const) + except TypingNarrowProtoError as ex: + raise TypingNarrowError(l, r, str(ex)) from None + + newtypevar.locations.extend(l.locations) + newtypevar.locations.extend(r.locations) + + # Make pointer locations to the constraints and locations + # so they get linked together throughout the unification + + l.constraints = newtypevar.constraints + l.locations = newtypevar.locations + + r.constraints = newtypevar.constraints + r.locations = newtypevar.locations + + return diff --git a/tests/integration/runners.py b/tests/integration/runners.py index fd3a53e..57d2adb 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, TextIO import ctypes import io +import warnings import pywasm.binary import wasm3 @@ -13,6 +14,7 @@ import wasmtime from phasm.compiler import phasm_compile from phasm.parser import phasm_parse +from phasm.typer import phasm_type from phasm import ourlang from phasm import wasm @@ -40,6 +42,10 @@ class RunnerBase: Parses the Phasm code into an AST """ self.phasm_ast = phasm_parse(self.phasm_code) + try: + phasm_type(self.phasm_ast) + except NotImplementedError as exc: + warnings.warn(f'phash_type throws an NotImplementedError on this test: {exc}') def compile_ast(self) -> None: """ -- 2.49.0 From 48e16c38b9f89b23ce99bba6005b9bcfe864c36d Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Fri, 16 Sep 2022 16:42:24 +0200 Subject: [PATCH 02/18] FunctionParam is a class, more framework stuff --- phasm/codestyle.py | 6 ++--- phasm/compiler.py | 4 ++-- phasm/exceptions.py | 3 +++ phasm/ourlang.py | 24 ++++++++++++++------ phasm/parser.py | 38 +++++++++++++++++-------------- phasm/typer.py | 43 ++++++++++++++++++++++++++++++------ phasm/typing.py | 10 +++------ tests/integration/runners.py | 2 +- 8 files changed, 87 insertions(+), 43 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 4b38e32..d0e9c70 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -104,7 +104,7 @@ def expression(inp: ourlang.Expression) -> str: ) + ', )' if isinstance(inp, ourlang.VariableReference): - return str(inp.name) + return str(inp.variable.name) if isinstance(inp, ourlang.UnaryOp): if ( @@ -193,8 +193,8 @@ def function(inp: ourlang.Function) -> str: result += '@imported\n' args = ', '.join( - f'{x}: {type_(y)}' - for x, y in inp.posonlyargs + f'{p.name}: {type_(p.type)}' + for p in inp.posonlyargs ) result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' diff --git a/phasm/compiler.py b/phasm/compiler.py index 4ba5d5a..d36561d 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -160,7 +160,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: return if isinstance(inp, ourlang.VariableReference): - wgn.add_statement('local.get', '${}'.format(inp.name)) + wgn.add_statement('local.get', '${}'.format(inp.variable.name)) return if isinstance(inp, ourlang.BinaryOp): @@ -450,7 +450,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param: """ Compile: function argument """ - return (inp[0], type_(inp[1]), ) + return (inp.name, type_(inp.type), ) def import_(inp: ourlang.Function) -> wasm.Import: """ diff --git a/phasm/exceptions.py b/phasm/exceptions.py index b459c22..abbcdd2 100644 --- a/phasm/exceptions.py +++ b/phasm/exceptions.py @@ -6,3 +6,6 @@ class StaticError(Exception): """ An error found during static analysis """ + +class TypingError(Exception): + pass diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 3338ce4..6ae9518 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -156,13 +156,13 @@ class VariableReference(Expression): """ An variable reference expression within a statement """ - __slots__ = ('name', ) + __slots__ = ('variable', ) - name: str + variable: 'FunctionParam' # also possibly local - def __init__(self, type_: TypeBase, name: str) -> None: + def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None: super().__init__(type_) - self.name = name + self.variable = variable class UnaryOp(Expression): """ @@ -352,7 +352,17 @@ class StatementIf(Statement): self.statements = [] self.else_statements = [] -FunctionParam = Tuple[str, TypeBase] +class FunctionParam: + __slots__ = ('name', 'type', 'type_var', ) + + name: str + type: TypeBase + type_var: Optional[TypeVar] + + def __init__(self, name: str, type_: TypeBase) -> None: + self.name = name + self.type = type_ + self.type_var = None class Function: """ @@ -394,7 +404,7 @@ class StructConstructor(Function): self.returns = struct for mem in struct.members: - self.posonlyargs.append((mem.name, mem.type, )) + self.posonlyargs.append(FunctionParam(mem.name, mem.type, )) self.struct = struct @@ -414,7 +424,7 @@ class TupleConstructor(Function): self.returns = tuple_ for mem in tuple_.members: - self.posonlyargs.append((f'arg{mem.idx}', mem.type, )) + self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, )) self.tuple = tuple_ diff --git a/phasm/parser.py b/phasm/parser.py index d95bfce..2d7a158 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -48,6 +48,7 @@ from .ourlang import ( Statement, StatementIf, StatementPass, StatementReturn, + FunctionParam, ModuleConstantDef, ) @@ -60,7 +61,7 @@ def phasm_parse(source: str) -> Module: our_visitor = OurVisitor() return our_visitor.visit_Module(res) -OurLocals = Dict[str, TypeBase] +OurLocals = Dict[str, Union[FunctionParam]] # Also local variable and module constants? class OurVisitor: """ @@ -141,7 +142,7 @@ class OurVisitor: if not arg.annotation: _raise_static_error(node, 'Type is required') - function.posonlyargs.append(( + function.posonlyargs.append(FunctionParam( arg.arg, self.visit_type(module, arg.annotation), )) @@ -297,7 +298,10 @@ class OurVisitor: def visit_Module_FunctionDef(self, module: Module, node: ast.FunctionDef) -> None: function = module.functions[node.name] - our_locals = dict(function.posonlyargs) + our_locals: OurLocals = { + x.name: x + for x in function.posonlyargs + } for stmt in node.body: function.statements.append( @@ -427,11 +431,11 @@ class OurVisitor: _raise_static_error(node, 'Must be load context') if node.id in our_locals: - act_type = our_locals[node.id] - if exp_type != act_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(act_type)}') + param = our_locals[node.id] + if exp_type != param.type: + _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}') - return VariableReference(act_type, node.id) + return VariableReference(param.type, param) if node.id in module.constant_defs: cdef = module.constant_defs[node.id] @@ -541,10 +545,10 @@ class OurVisitor: if exp_type.__class__ != func.returns.__class__: _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') - if func.returns.__class__ != func.posonlyargs[0][1].__class__: - _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0][1])}') + if func.returns.__class__ != func.posonlyargs[0].type.__class__: + _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}') - if module.types['u8'].__class__ != func.posonlyargs[1][1].__class__: + if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__: _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') return Fold( @@ -568,8 +572,8 @@ class OurVisitor: result = FunctionCall(func) result.arguments.extend( - self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_type, arg_expr) - for arg_expr, (_, arg_type) in zip(node.args, func.posonlyargs) + self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr) + for arg_expr, param in zip(node.args, func.posonlyargs) ) return result @@ -586,7 +590,9 @@ class OurVisitor: if not node.value.id in our_locals: _raise_static_error(node, f'Undefined variable {node.value.id}') - node_typ = our_locals[node.value.id] + param = our_locals[node.value.id] + + node_typ = param.type if not isinstance(node_typ, TypeStruct): _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') @@ -598,7 +604,7 @@ class OurVisitor: _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}') return AccessStructMember( - VariableReference(node_typ, node.value.id), + VariableReference(node_typ, param), member, ) @@ -614,8 +620,8 @@ class OurVisitor: varref: Union[ModuleConstantReference, VariableReference] if node.value.id in our_locals: - node_typ = our_locals[node.value.id] - varref = VariableReference(node_typ, node.value.id) + param = our_locals[node.value.id] + varref = VariableReference(param.type, param) elif node.value.id in module.constant_defs: constant_def = module.constant_defs[node.value.id] node_typ = constant_def.type diff --git a/phasm/typer.py b/phasm/typer.py index d21f823..c0677fb 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -1,8 +1,6 @@ """ Type checks and enriches the given ast """ -from math import ceil, log2 - from . import ourlang from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar @@ -32,7 +30,14 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.Constant): return constant(ctx, inp) + if isinstance(inp, ourlang.VariableReference): + assert inp.variable.type_var is not None, inp + return inp.variable.type_var + if isinstance(inp, ourlang.BinaryOp): + if inp.operator not in ('+', '-', '|', '&', '^'): + raise NotImplementedError(expression, inp, inp.operator) + left = expression(ctx, inp.left) right = expression(ctx, inp.right) ctx.unify(left, right) @@ -41,11 +46,11 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': raise NotImplementedError(expression, inp) def function(ctx: 'Context', inp: ourlang.Function) -> None: - bctx = ctx.clone() # Clone whenever we go into a block + for param in inp.posonlyargs: + param.type_var = _convert_old_type(ctx, param.type) - assert len(inp.statements) == 1 # TODO - - assert isinstance(inp.statements[0], ourlang.StatementReturn) + if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn): + raise NotImplementedError('Functions with not just a return statement') typ = expression(ctx, inp.statements[0].value) ctx.unify(_convert_old_type(ctx, inp.returns), typ) @@ -62,10 +67,34 @@ from . import typing def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar: result = ctx.new_var() + if isinstance(inp, typing.TypeUInt8): + result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location('u8') + return result + if isinstance(inp, typing.TypeUInt32): - result.add_constraint(TypeConstraintBitWidth(maxb=32)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(False)) result.add_location('u32') return result + if isinstance(inp, typing.TypeUInt64): + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location('u64') + return result + + if isinstance(inp, typing.TypeInt32): + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location('i32') + return result + + if isinstance(inp, typing.TypeInt64): + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location('i64') + return result + raise NotImplementedError(_convert_old_type, inp) diff --git a/phasm/typing.py b/phasm/typing.py index 0cb213d..65bb095 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -3,6 +3,8 @@ The phasm type system """ from typing import Dict, Optional, List, Type +from .exceptions import TypingError + class TypeBase: """ TypeBase base class @@ -203,9 +205,6 @@ class TypeStruct(TypeBase): ## NEW STUFF BELOW -class TypingError(Exception): - pass - class TypingNarrowProtoError(TypingError): pass @@ -237,7 +236,7 @@ class TypeConstraintSigned(TypeConstraintBase): return TypeConstraintSigned(other.signed) if self.signed is not other.signed: - raise TypeError() + raise TypingNarrowProtoError('Signed does not match') return TypeConstraintSigned(self.signed) @@ -300,9 +299,6 @@ class TypeVar: ) class Context: - def clone(self) -> 'Context': - return self # TODO: STUB - def new_var(self) -> TypeVar: return TypeVar(self) diff --git a/tests/integration/runners.py b/tests/integration/runners.py index 57d2adb..005d44e 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -45,7 +45,7 @@ class RunnerBase: try: phasm_type(self.phasm_ast) except NotImplementedError as exc: - warnings.warn(f'phash_type throws an NotImplementedError on this test: {exc}') + warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}') def compile_ast(self) -> None: """ -- 2.49.0 From 7669f3cbca3c67ca124e777c9fbb6486ea91d66b Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Fri, 16 Sep 2022 17:01:23 +0200 Subject: [PATCH 03/18] More framework stuff --- phasm/ourlang.py | 4 +- phasm/parser.py | 54 +++++++++++++-------------- phasm/typer.py | 43 +++++++++++++-------- tests/integration/test_simple.py | 15 ++++++++ tests/integration/test_type_checks.py | 31 +++++++++++++++ 5 files changed, 103 insertions(+), 44 deletions(-) create mode 100644 tests/integration/test_type_checks.py diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 6ae9518..efc1a66 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -368,7 +368,7 @@ class Function: """ A function processes input and produces output """ - __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'posonlyargs', ) + __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'returns_type_var', 'posonlyargs', ) name: str lineno: int @@ -376,6 +376,7 @@ class Function: imported: bool statements: List[Statement] returns: TypeBase + returns_type_var: Optional[TypeVar] posonlyargs: List[FunctionParam] def __init__(self, name: str, lineno: int) -> None: @@ -385,6 +386,7 @@ class Function: self.imported = False self.statements = [] self.returns = TypeNone() + self.returns_type_var = None self.posonlyargs = [] class StructConstructor(Function): diff --git a/phasm/parser.py b/phasm/parser.py index 2d7a158..78486ef 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -564,8 +564,8 @@ class OurVisitor: func = module.functions[node.func.id] - if func.returns != exp_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') + # if func.returns != exp_type: + # _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') if len(func.posonlyargs) != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') @@ -700,47 +700,47 @@ class OurVisitor: _not_implemented(node.kind is None, 'Constant.kind') if isinstance(exp_type, TypeUInt8): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 255: - _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < 0 or node.value > 255: + # _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}') return ConstantUInt8(exp_type, node.value) if isinstance(exp_type, TypeUInt32): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 4294967295: - _raise_static_error(node, 'Integer value out of range') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < 0 or node.value > 4294967295: + # _raise_static_error(node, 'Integer value out of range') return ConstantUInt32(exp_type, node.value) if isinstance(exp_type, TypeUInt64): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < 0 or node.value > 18446744073709551615: - _raise_static_error(node, 'Integer value out of range') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < 0 or node.value > 18446744073709551615: + # _raise_static_error(node, 'Integer value out of range') return ConstantUInt64(exp_type, node.value) if isinstance(exp_type, TypeInt32): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < -2147483648 or node.value > 2147483647: - _raise_static_error(node, 'Integer value out of range') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < -2147483648 or node.value > 2147483647: + # _raise_static_error(node, 'Integer value out of range') return ConstantInt32(exp_type, node.value) if isinstance(exp_type, TypeInt64): - if not isinstance(node.value, int): - _raise_static_error(node, 'Expected integer value') - - if node.value < -9223372036854775808 or node.value > 9223372036854775807: - _raise_static_error(node, 'Integer value out of range') + # if not isinstance(node.value, int): + # _raise_static_error(node, 'Expected integer value') + # + # if node.value < -9223372036854775808 or node.value > 9223372036854775807: + # _raise_static_error(node, 'Integer value out of range') return ConstantInt64(exp_type, node.value) diff --git a/phasm/typer.py b/phasm/typer.py index c0677fb..97c5e44 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -9,22 +9,23 @@ def phasm_type(inp: ourlang.Module) -> None: module(inp) def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': - if getattr(inp, 'value', int): + value = getattr(inp, 'value', None) + if isinstance(value, int): result = ctx.new_var() # Need at least this many bits to store this constant value - result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) # type: ignore + result.add_constraint(TypeConstraintBitWidth(minb=len(bin(value)) - 2)) # Don't dictate anything about signedness - you can use a signed # constant in an unsigned variable if the bits fit result.add_constraint(TypeConstraintSigned(None)) - result.add_location(str(inp.value)) # type: ignore + result.add_location(str(value)) inp.type_var = result return result - raise NotImplementedError(constant, inp) + raise NotImplementedError(constant, inp, value) def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.Constant): @@ -43,58 +44,68 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': ctx.unify(left, right) return left + if isinstance(inp, ourlang.FunctionCall): + assert inp.function.returns_type_var is not None + if inp.function.posonlyargs: + raise NotImplementedError + + return inp.function.returns_type_var + raise NotImplementedError(expression, inp) def function(ctx: 'Context', inp: ourlang.Function) -> None: - for param in inp.posonlyargs: - param.type_var = _convert_old_type(ctx, param.type) - if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn): raise NotImplementedError('Functions with not just a return statement') typ = expression(ctx, inp.statements[0].value) - ctx.unify(_convert_old_type(ctx, inp.returns), typ) + assert inp.returns_type_var is not None + ctx.unify(inp.returns_type_var, typ) return def module(inp: ourlang.Module) -> None: ctx = Context() + for func in inp.functions.values(): + func.returns_type_var = _convert_old_type(ctx, func.returns, f'{func.name}.(returns)') + for param in func.posonlyargs: + param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}') + for func in inp.functions.values(): function(ctx, func) from . import typing -def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar: +def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> TypeVar: result = ctx.new_var() if isinstance(inp, typing.TypeUInt8): result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) result.add_constraint(TypeConstraintSigned(False)) - result.add_location('u8') + result.add_location(location) return result if isinstance(inp, typing.TypeUInt32): result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(False)) - result.add_location('u32') + result.add_location(location) return result if isinstance(inp, typing.TypeUInt64): - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(False)) - result.add_location('u64') + result.add_location(location) return result if isinstance(inp, typing.TypeInt32): result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(True)) - result.add_location('i32') + result.add_location(location) return result if isinstance(inp, typing.TypeInt64): - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=64)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(True)) - result.add_location('i64') + result.add_location(location) return result raise NotImplementedError(_convert_old_type, inp) diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index f0c2993..449f34c 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -304,6 +304,21 @@ def testEntry(a: i32, b: i32) -> i32: assert 1 == suite.run_code(10, 20).returned_value assert 0 == suite.run_code(10, 10).returned_value +@pytest.mark.integration_test +def test_call_no_args(): + code_py = """ +def helper() -> i32: + return 19 + +@exported +def testEntry() -> i32: + return helper() +""" + + result = Suite(code_py).run_code() + + assert 19 == result.returned_value + @pytest.mark.integration_test def test_call_pre_defined(): code_py = """ diff --git a/tests/integration/test_type_checks.py b/tests/integration/test_type_checks.py new file mode 100644 index 0000000..afc1d1e --- /dev/null +++ b/tests/integration/test_type_checks.py @@ -0,0 +1,31 @@ +import pytest + +from phasm.parser import phasm_parse +from phasm.typer import phasm_type +from phasm.exceptions import TypingError + +@pytest.mark.integration_test +def test_constant_too_wide(): + code_py = """ +def func_const() -> u8: + return 0xFFF +""" + + ast = phasm_parse(code_py) + with pytest.raises(TypingError, match='Other min bitwidth exceeds max bitwidth'): + phasm_type(ast) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', [32, 64]) +def test_signed_mismatch(type_): + code_py = f""" +def func_const() -> u{type_}: + return 0 + +def func_call() -> i{type_}: + return func_const() +""" + + ast = phasm_parse(code_py) + with pytest.raises(TypingError, match='Signed does not matchq'): + phasm_type(ast) -- 2.49.0 From 2d0daf4b905fcbcb0608bc5e92c20d1bc53cb6da Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Fri, 16 Sep 2022 17:04:13 +0200 Subject: [PATCH 04/18] Fixes --- phasm/parser.py | 1 + tests/integration/test_type_checks.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/phasm/parser.py b/phasm/parser.py index 78486ef..2aa25fc 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -621,6 +621,7 @@ class OurVisitor: varref: Union[ModuleConstantReference, VariableReference] if node.value.id in our_locals: param = our_locals[node.value.id] + node_typ = param.type varref = VariableReference(param.type, param) elif node.value.id in module.constant_defs: constant_def = module.constant_defs[node.value.id] diff --git a/tests/integration/test_type_checks.py b/tests/integration/test_type_checks.py index afc1d1e..1389c2f 100644 --- a/tests/integration/test_type_checks.py +++ b/tests/integration/test_type_checks.py @@ -27,5 +27,5 @@ def func_call() -> i{type_}: """ ast = phasm_parse(code_py) - with pytest.raises(TypingError, match='Signed does not matchq'): + with pytest.raises(TypingError, match='Signed does not match'): phasm_type(ast) -- 2.49.0 From 6f3d9a5bcc6ada961b87334f492d65cc63da5508 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Fri, 16 Sep 2022 17:39:46 +0200 Subject: [PATCH 05/18] First attempt at ripping out old system This breaks test_addition[u32], which is a good thing to chase next. --- phasm/codestyle.py | 10 +---- phasm/compiler.py | 38 +++++++---------- phasm/ourlang.py | 93 +++++------------------------------------ phasm/parser.py | 100 ++++++++++----------------------------------- phasm/typer.py | 16 +++++--- phasm/typing.py | 51 +++++++++++++++++++++++ 6 files changed, 109 insertions(+), 199 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index d0e9c70..bbb2eff 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -86,14 +86,8 @@ def expression(inp: ourlang.Expression) -> str: """ Render: A Phasm expression """ - if isinstance(inp, ( - ourlang.ConstantUInt8, ourlang.ConstantUInt32, ourlang.ConstantUInt64, - ourlang.ConstantInt32, ourlang.ConstantInt64, - )): - return str(inp.value) - - if isinstance(inp, (ourlang.ConstantFloat32, ourlang.ConstantFloat64, )): - # These might not round trip if the original constant + if isinstance(inp, ourlang.ConstantPrimitive): + # Floats might not round trip if the original constant # could not fit in the given float type return str(inp.value) diff --git a/phasm/compiler.py b/phasm/compiler.py index d36561d..6505982 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -131,33 +131,25 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: """ Compile: Any expression """ - if isinstance(inp, ourlang.ConstantUInt8): - wgn.i32.const(inp.value) - return + if isinstance(inp, ourlang.ConstantPrimitive): + stp = typing.simplify(inp.type_var) + if stp is None: + raise NotImplementedError(f'Constants with type {inp.type_var}') - if isinstance(inp, ourlang.ConstantUInt32): - wgn.i32.const(inp.value) - return + if stp == 'u8': + # No native u8 type - treat as i32, with caution + wgn.i32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantUInt64): - wgn.i64.const(inp.value) - return + if stp in ('i32', 'u32'): + wgn.i32.const(inp.value) + return - if isinstance(inp, ourlang.ConstantInt32): - wgn.i32.const(inp.value) - return + if stp in ('i64', 'u64'): + wgn.i64.const(inp.value) + return - if isinstance(inp, ourlang.ConstantInt64): - wgn.i64.const(inp.value) - return - - if isinstance(inp, ourlang.ConstantFloat32): - wgn.f32.const(inp.value) - return - - if isinstance(inp, ourlang.ConstantFloat64): - wgn.f64.const(inp.value) - return + raise NotImplementedError(f'Constants with type {stp}') if isinstance(inp, ourlang.VariableReference): wgn.add_statement('local.get', '${}'.format(inp.variable.name)) diff --git a/phasm/ourlang.py b/phasm/ourlang.py index efc1a66..6496a2e 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -44,88 +44,15 @@ class Constant(Expression): """ __slots__ = () -class ConstantUInt8(Constant): +class ConstantPrimitive(Constant): """ - An UInt8 constant value expression within a statement + An primitive constant value expression within a statement """ __slots__ = ('value', ) - value: int + value: Union[int, float] - def __init__(self, type_: TypeUInt8, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantUInt32(Constant): - """ - An UInt32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeUInt32, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantUInt64(Constant): - """ - An UInt64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeUInt64, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantInt32(Constant): - """ - An Int32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeInt32, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantInt64(Constant): - """ - An Int64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: int - - def __init__(self, type_: TypeInt64, value: int) -> None: - super().__init__(type_) - self.value = value - -class ConstantFloat32(Constant): - """ - An Float32 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: float - - def __init__(self, type_: TypeFloat32, value: float) -> None: - super().__init__(type_) - self.value = value - -class ConstantFloat64(Constant): - """ - An Float64 constant value expression within a statement - """ - __slots__ = ('value', ) - - value: float - - def __init__(self, type_: TypeFloat64, value: float) -> None: - super().__init__(type_) + def __init__(self, value: Union[int, float]) -> None: self.value = value class ConstantTuple(Constant): @@ -134,9 +61,9 @@ class ConstantTuple(Constant): """ __slots__ = ('value', ) - value: List[Constant] + value: List[ConstantPrimitive] - def __init__(self, type_: TypeTuple, value: List[Constant]) -> None: + def __init__(self, type_: TypeTuple, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples? super().__init__(type_) self.value = value @@ -146,9 +73,9 @@ class ConstantStaticArray(Constant): """ __slots__ = ('value', ) - value: List[Constant] + value: List[ConstantPrimitive] - def __init__(self, type_: TypeStaticArray, value: List[Constant]) -> None: + def __init__(self, type_: TypeStaticArray, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays? super().__init__(type_) self.value = value @@ -455,10 +382,10 @@ class ModuleDataBlock: """ __slots__ = ('data', 'address', ) - data: List[Constant] + data: List[ConstantPrimitive] address: Optional[int] - def __init__(self, data: List[Constant]) -> None: + def __init__(self, data: List[ConstantPrimitive]) -> None: self.data = data self.address = None diff --git a/phasm/parser.py b/phasm/parser.py index 2aa25fc..8b1b3a2 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -35,9 +35,7 @@ from .ourlang import ( AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, BinaryOp, Constant, - ConstantFloat32, ConstantFloat64, ConstantInt32, ConstantInt64, - ConstantUInt8, ConstantUInt32, ConstantUInt64, - ConstantTuple, ConstantStaticArray, + ConstantPrimitive, ConstantTuple, ConstantStaticArray, FunctionCall, StructConstructor, TupleConstructor, @@ -211,18 +209,14 @@ class OurVisitor: exp_type = self.visit_type(module, node.annotation) - if isinstance(exp_type, TypeInt32): - if not isinstance(node.value, ast.Constant): - _raise_static_error(node, 'Must be constant') - - constant = ModuleConstantDef( + if isinstance(node.value, ast.Constant): + return ModuleConstantDef( node.target.id, node.lineno, exp_type, - self.visit_Module_Constant(module, exp_type, node.value), + self.visit_Module_Constant(module, node.value), None, ) - return constant if isinstance(exp_type, TypeTuple): if not isinstance(node.value, ast.Tuple): @@ -232,7 +226,7 @@ class OurVisitor: _raise_static_error(node, 'Invalid number of tuple values') tuple_data = [ - self.visit_Module_Constant(module, mem.type, arg_node) + self.visit_Module_Constant(module, arg_node) for arg_node, mem in zip(node.value.elts, exp_type.members) if isinstance(arg_node, ast.Constant) ] @@ -260,7 +254,7 @@ class OurVisitor: _raise_static_error(node, 'Invalid number of static array values') static_array_data = [ - self.visit_Module_Constant(module, exp_type.member_type, arg_node) + self.visit_Module_Constant(module, arg_node) for arg_node in node.value.elts if isinstance(arg_node, ast.Constant) ] @@ -413,7 +407,7 @@ class OurVisitor: if isinstance(node, ast.Constant): return self.visit_Module_Constant( - module, exp_type, node, + module, node, ) if isinstance(node, ast.Attribute): @@ -649,12 +643,15 @@ class OurVisitor: ) if isinstance(node_typ, TypeTuple): - if not isinstance(slice_expr, ConstantUInt32): + if not isinstance(slice_expr, ConstantPrimitive): _raise_static_error(node, 'Must subscript using a constant index') idx = slice_expr.value - if len(node_typ.members) <= idx: + if not isinstance(idx, int): + _raise_static_error(node, 'Must subscript using a constant integer index') + + if not (0 <= idx < len(node_typ.members)): _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') tuple_member = node_typ.members[idx] @@ -673,7 +670,7 @@ class OurVisitor: if exp_type != node_typ.member_type: _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}') - if not isinstance(slice_expr, ConstantInt32): + if not isinstance(slice_expr, ConstantPrimitive): return AccessStaticArrayMember( varref, node_typ, @@ -682,7 +679,10 @@ class OurVisitor: idx = slice_expr.value - if len(node_typ.members) <= idx: + if not isinstance(idx, int): + _raise_static_error(node, 'Must subscript using an integer index') + + if not (0 <= idx < len(node_typ.members)): _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') static_array_member = node_typ.members[idx] @@ -695,73 +695,15 @@ class OurVisitor: _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') - def visit_Module_Constant(self, module: Module, exp_type: TypeBase, node: ast.Constant) -> Constant: + def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive: del module _not_implemented(node.kind is None, 'Constant.kind') - if isinstance(exp_type, TypeUInt8): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < 0 or node.value > 255: - # _raise_static_error(node, f'Integer value out of range; expected 0..255, actual {node.value}') + if isinstance(node.value, (int, float, )): + return ConstantPrimitive(node.value) - return ConstantUInt8(exp_type, node.value) - - if isinstance(exp_type, TypeUInt32): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < 0 or node.value > 4294967295: - # _raise_static_error(node, 'Integer value out of range') - - return ConstantUInt32(exp_type, node.value) - - if isinstance(exp_type, TypeUInt64): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < 0 or node.value > 18446744073709551615: - # _raise_static_error(node, 'Integer value out of range') - - return ConstantUInt64(exp_type, node.value) - - if isinstance(exp_type, TypeInt32): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < -2147483648 or node.value > 2147483647: - # _raise_static_error(node, 'Integer value out of range') - - return ConstantInt32(exp_type, node.value) - - if isinstance(exp_type, TypeInt64): - # if not isinstance(node.value, int): - # _raise_static_error(node, 'Expected integer value') - # - # if node.value < -9223372036854775808 or node.value > 9223372036854775807: - # _raise_static_error(node, 'Integer value out of range') - - return ConstantInt64(exp_type, node.value) - - if isinstance(exp_type, TypeFloat32): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat32(exp_type, node.value) - - if isinstance(exp_type, TypeFloat64): - if not isinstance(node.value, (float, int, )): - _raise_static_error(node, 'Expected float value') - - # FIXME: Range check - - return ConstantFloat64(exp_type, node.value) - - raise NotImplementedError(f'{node} as const for type {exp_type}') + raise NotImplementedError(f'{node.value} as constant') def visit_type(self, module: Module, node: ast.expr) -> TypeBase: if isinstance(node, ast.Constant): diff --git a/phasm/typer.py b/phasm/typer.py index 97c5e44..e835c96 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -3,29 +3,33 @@ Type checks and enriches the given ast """ from . import ourlang -from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar +from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar def phasm_type(inp: ourlang.Module) -> None: module(inp) def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': - value = getattr(inp, 'value', None) - if isinstance(value, int): + if isinstance(inp, ourlang.ConstantPrimitive): result = ctx.new_var() + if not isinstance(inp.value, int): + raise NotImplementedError('Float constants in new type system') + + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + # Need at least this many bits to store this constant value - result.add_constraint(TypeConstraintBitWidth(minb=len(bin(value)) - 2)) + result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) # Don't dictate anything about signedness - you can use a signed # constant in an unsigned variable if the bits fit result.add_constraint(TypeConstraintSigned(None)) - result.add_location(str(value)) + result.add_location(str(inp.value)) inp.type_var = result return result - raise NotImplementedError(constant, inp, value) + raise NotImplementedError(constant, inp) def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.Constant): diff --git a/phasm/typing.py b/phasm/typing.py index 65bb095..72a5827 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -3,6 +3,8 @@ The phasm type system """ from typing import Dict, Optional, List, Type +import enum + from .exceptions import TypingError class TypeBase: @@ -218,6 +220,30 @@ class TypeConstraintBase: def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': raise NotImplementedError('narrow', self, other) +class TypeConstraintPrimitive(TypeConstraintBase): + __slots__ = ('primitive', ) + + class Primitive(enum.Enum): + INT = 0 + FLOAT = 1 + + primitive: Primitive + + def __init__(self, primitive: Primitive) -> None: + self.primitive = primitive + + def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive': + if not isinstance(other, TypeConstraintPrimitive): + raise Exception('Invalid comparison') + + if self.primitive != other.primitive: + raise TypingNarrowProtoError('Primitive does not match') + + return TypeConstraintPrimitive(self.primitive) + + def __repr__(self) -> str: + return f'Primitive={self.primitive.name}' + class TypeConstraintSigned(TypeConstraintBase): __slots__ = ('signed', ) @@ -326,3 +352,28 @@ class Context: r.locations = newtypevar.locations return + +def simplify(inp: TypeVar) -> Optional[str]: + tc_prim = inp.constraints.get(TypeConstraintPrimitive) + tc_bits = inp.constraints.get(TypeConstraintBitWidth) + tc_sign = inp.constraints.get(TypeConstraintSigned) + + if tc_prim is None: + return None + + assert isinstance(tc_prim, TypeConstraintPrimitive) # type hint + primitive = tc_prim.primitive + if primitive is TypeConstraintPrimitive.Primitive.INT: + if tc_bits is None or tc_sign is None: + return None + + assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint + assert isinstance(tc_sign, TypeConstraintSigned) # type hint + + if tc_sign.signed is None or tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (8, 32, 64): + return None + + base = 'i' if tc_sign.signed else 'u' + return f'{base}{tc_bits.minb}' + + return None -- 2.49.0 From b2816164f98b3be5ba1157a1971823d98968b328 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 17 Sep 2022 17:14:17 +0200 Subject: [PATCH 06/18] Improved unification --- phasm/compiler.py | 8 ++- phasm/ourlang.py | 4 +- phasm/typer.py | 15 +++- phasm/typing.py | 105 ++++++++++++++++++++-------- pylintrc | 2 +- tests/integration/test_constants.py | 16 ++++- 6 files changed, 114 insertions(+), 36 deletions(-) diff --git a/phasm/compiler.py b/phasm/compiler.py index 6505982..e3282bc 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -132,20 +132,25 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: Compile: Any expression """ if isinstance(inp, ourlang.ConstantPrimitive): + assert inp.type_var is not None + stp = typing.simplify(inp.type_var) if stp is None: raise NotImplementedError(f'Constants with type {inp.type_var}') if stp == 'u8': # No native u8 type - treat as i32, with caution + assert isinstance(inp.value, int) wgn.i32.const(inp.value) return if stp in ('i32', 'u32'): + assert isinstance(inp.value, int) wgn.i32.const(inp.value) return if stp in ('i64', 'u64'): + assert isinstance(inp.value, int) wgn.i64.const(inp.value) return @@ -321,7 +326,8 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: """ Compile: Fold expression """ - mtyp = LOAD_STORE_TYPE_MAP.get(inp.base.type.__class__) + assert inp.base.type_var is not None + mtyp = typing.simplify(inp.base.type_var) if mtyp is None: # In the future might extend this by having structs or tuples # as members of struct or tuples diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 6496a2e..30a2527 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -361,11 +361,12 @@ class ModuleConstantDef: """ A constant definition within a module """ - __slots__ = ('name', 'lineno', 'type', 'constant', 'data_block', ) + __slots__ = ('name', 'lineno', 'type', 'type_var', 'constant', 'data_block', ) name: str lineno: int type: TypeBase + type_var: Optional[TypeVar] constant: Constant data_block: Optional['ModuleDataBlock'] @@ -373,6 +374,7 @@ class ModuleConstantDef: self.name = name self.lineno = lineno self.type = type_ + self.type_var = None self.constant = constant self.data_block = data_block diff --git a/phasm/typer.py b/phasm/typer.py index e835c96..169a774 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -40,7 +40,7 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': return inp.variable.type_var if isinstance(inp, ourlang.BinaryOp): - if inp.operator not in ('+', '-', '|', '&', '^'): + if inp.operator not in ('+', '-', '*', '|', '&', '^'): raise NotImplementedError(expression, inp, inp.operator) left = expression(ctx, inp.left) @@ -55,6 +55,10 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': return inp.function.returns_type_var + if isinstance(inp, ourlang.ModuleConstantReference): + assert inp.definition.type_var is not None + return inp.definition.type_var + raise NotImplementedError(expression, inp) def function(ctx: 'Context', inp: ourlang.Function) -> None: @@ -64,7 +68,11 @@ def function(ctx: 'Context', inp: ourlang.Function) -> None: assert inp.returns_type_var is not None ctx.unify(inp.returns_type_var, typ) - return + +def module_constant_def(ctx: 'Context', inp: ourlang.ModuleConstantDef) -> None: + inp.type_var = _convert_old_type(ctx, inp.type, inp.name) + constant(ctx, inp.constant) + ctx.unify(inp.type_var, inp.constant.type_var) def module(inp: ourlang.Module) -> None: ctx = Context() @@ -74,6 +82,9 @@ def module(inp: ourlang.Module) -> None: for param in func.posonlyargs: param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}') + for cdef in inp.constant_defs.values(): + module_constant_def(ctx, cdef) + for func in inp.functions.values(): function(ctx, func) diff --git a/phasm/typing.py b/phasm/typing.py index 72a5827..ffd8e16 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -301,62 +301,107 @@ class TypeConstraintBitWidth(TypeConstraintBase): return f'BitWidth={self.minb}..{self.maxb}' class TypeVar: - def __init__(self, ctx: 'Context') -> None: - self.context = ctx - self.constraints: Dict[Type[TypeConstraintBase], TypeConstraintBase] = {} - self.locations: List[str] = [] + __slots__ = ('ctx', 'ctx_id', ) + + ctx: 'Context' + ctx_id: int + + def __init__(self, ctx: 'Context', ctx_id: int) -> None: + self.ctx = ctx + self.ctx_id = ctx_id def add_constraint(self, newconst: TypeConstraintBase) -> None: - if newconst.__class__ in self.constraints: - self.constraints[newconst.__class__] = self.constraints[newconst.__class__].narrow(newconst) + csts = self.ctx.var_constraints[self.ctx_id] + + if newconst.__class__ in csts: + csts[newconst.__class__] = csts[newconst.__class__].narrow(newconst) else: - self.constraints[newconst.__class__] = newconst + csts[newconst.__class__] = newconst def add_location(self, ref: str) -> None: - self.locations.append(ref) + self.ctx.var_locations[self.ctx_id].append(ref) def __repr__(self) -> str: return ( 'TypeVar<' - + '; '.join(map(repr, self.constraints.values())) + + '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values())) + '; locations: ' - + ', '.join(self.locations) + + ', '.join(self.ctx.var_locations[self.ctx_id]) + '>' ) class Context: + def __init__(self) -> None: + # Variables are unified (or entangled, if you will) + # that means that each TypeVar within a context has an ID, + # and all TypeVars with the same ID are the same TypeVar, + # even if they are a different instance + self.next_ctx_id = 1 + self.vars_by_id: Dict[int, List[TypeVar]] = {} + + # Store the TypeVar properties as a lookup + # so we can update these when unifying + self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {} + self.var_locations: Dict[int, List[str]] = {} + def new_var(self) -> TypeVar: - return TypeVar(self) + ctx_id = self.next_ctx_id + self.next_ctx_id += 1 + + result = TypeVar(self, ctx_id) + + self.vars_by_id[ctx_id] = [result] + self.var_constraints[ctx_id] = {} + self.var_locations[ctx_id] = [] + + return result def unify(self, l: 'TypeVar', r: 'TypeVar') -> None: - newtypevar = self.new_var() + assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return + + # Backup some values that we'll overwrite + l_ctx_id = l.ctx_id + r_ctx_id = r.ctx_id + l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id] + + # Create a new TypeVar, with the combined contraints + # and locations of the old ones + n = self.new_var() try: - for const in l.constraints.values(): - newtypevar.add_constraint(const) - for const in r.constraints.values(): - newtypevar.add_constraint(const) - except TypingNarrowProtoError as ex: - raise TypingNarrowError(l, r, str(ex)) from None + for const in self.var_constraints[l_ctx_id].values(): + n.add_constraint(const) + for const in self.var_constraints[r_ctx_id].values(): + n.add_constraint(const) + except TypingNarrowProtoError as exc: + raise TypingNarrowError(l, r, str(exc)) from None - newtypevar.locations.extend(l.locations) - newtypevar.locations.extend(r.locations) + self.var_locations[n.ctx_id].extend(self.var_locations[l_ctx_id]) + self.var_locations[n.ctx_id].extend(self.var_locations[r_ctx_id]) - # Make pointer locations to the constraints and locations - # so they get linked together throughout the unification + # ## + # And unify (or entangle) the old ones - l.constraints = newtypevar.constraints - l.locations = newtypevar.locations + # First update the IDs, so they all point to the new list + for type_var in l_r_var_list: + type_var.ctx_id = n.ctx_id - r.constraints = newtypevar.constraints - r.locations = newtypevar.locations + # Update our registry of TypeVars by ID, so we can find them + # on the next unify + self.vars_by_id[n.ctx_id].extend(l_r_var_list) - return + # Then delete the old values for the now gone variables + # Do this last, so exceptions thrown in the code above + # still have a valid context + del self.var_constraints[l_ctx_id] + del self.var_constraints[r_ctx_id] + del self.var_locations[l_ctx_id] + del self.var_locations[r_ctx_id] def simplify(inp: TypeVar) -> Optional[str]: - tc_prim = inp.constraints.get(TypeConstraintPrimitive) - tc_bits = inp.constraints.get(TypeConstraintBitWidth) - tc_sign = inp.constraints.get(TypeConstraintSigned) + tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive) + tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth) + tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned) if tc_prim is None: return None diff --git a/pylintrc b/pylintrc index 0591be3..3759e6b 100644 --- a/pylintrc +++ b/pylintrc @@ -1,5 +1,5 @@ [MASTER] -disable=C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 +disable=C0103,C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 max-line-length=180 diff --git a/tests/integration/test_constants.py b/tests/integration/test_constants.py index 19f0203..accf9e2 100644 --- a/tests/integration/test_constants.py +++ b/tests/integration/test_constants.py @@ -3,7 +3,21 @@ import pytest from .helpers import Suite @pytest.mark.integration_test -def test_i32(): +def test_i32_asis(): + code_py = """ +CONSTANT: i32 = 13 + +@exported +def testEntry() -> i32: + return CONSTANT +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + +@pytest.mark.integration_test +def test_i32_binop(): code_py = """ CONSTANT: i32 = 13 -- 2.49.0 From 4b46483895470904db56e02d5063a5dbd56f5e30 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 17 Sep 2022 19:21:56 +0200 Subject: [PATCH 07/18] Worked on floats --- phasm/compiler.py | 10 +++++ phasm/typer.py | 69 ++++++++++++++++++++++++++------ phasm/typing.py | 11 +++++ tests/integration/test_simple.py | 68 ++++++++++++++++++++++++------- 4 files changed, 132 insertions(+), 26 deletions(-) diff --git a/phasm/compiler.py b/phasm/compiler.py index e3282bc..af6ae58 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -154,6 +154,16 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.i64.const(inp.value) return + if stp == 'f32': + assert isinstance(inp.value, float) + wgn.f32.const(inp.value) + return + + if stp == 'f64': + assert isinstance(inp.value, float) + wgn.f64.const(inp.value) + return + raise NotImplementedError(f'Constants with type {stp}') if isinstance(inp, ourlang.VariableReference): diff --git a/phasm/typer.py b/phasm/typer.py index 169a774..07f7f33 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -12,22 +12,48 @@ def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': if isinstance(inp, ourlang.ConstantPrimitive): result = ctx.new_var() - if not isinstance(inp.value, int): - raise NotImplementedError('Float constants in new type system') + if isinstance(inp.value, int): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + # Need at least this many bits to store this constant value + result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) + # Don't dictate anything about signedness - you can use a signed + # constant in an unsigned variable if the bits fit + result.add_constraint(TypeConstraintSigned(None)) - # Need at least this many bits to store this constant value - result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) - # Don't dictate anything about signedness - you can use a signed - # constant in an unsigned variable if the bits fit - result.add_constraint(TypeConstraintSigned(None)) + result.add_location(str(inp.value)) - result.add_location(str(inp.value)) + inp.type_var = result - inp.type_var = result + return result - return result + if isinstance(inp.value, float): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + + # We don't have fancy logic here to detect if the float constant + # fits in the given type. There a number of edge cases to consider, + # before implementing this. + + # 1) It may fit anyhow + # e.g., if the user has 3.14 as a float constant, neither a + # f32 nor a f64 can really fit this value. But does that mean + # we should throw an error? + + # If we'd implement it, we'd want to convert it to hex using + # inp.value.hex(), which would give us the mantissa and exponent. + # We can use those to determine what bit size the value should be in. + + # If that doesn't work out, we'd need another way to calculate the + # difference between what was written and what actually gets stored + # in memory, and warn if the difference is beyond a treshold. + + result.add_location(str(inp.value)) + + inp.type_var = result + + return result + + raise NotImplementedError(constant, inp, inp.value) raise NotImplementedError(constant, inp) @@ -39,6 +65,13 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': assert inp.variable.type_var is not None, inp return inp.variable.type_var + if isinstance(inp, ourlang.UnaryOp): + if inp.operator not in ('sqrt', ): + raise NotImplementedError(expression, inp, inp.operator) + + right = expression(ctx, inp.right) + return right + if isinstance(inp, ourlang.BinaryOp): if inp.operator not in ('+', '-', '*', '|', '&', '^'): raise NotImplementedError(expression, inp, inp.operator) @@ -51,7 +84,7 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.FunctionCall): assert inp.function.returns_type_var is not None if inp.function.posonlyargs: - raise NotImplementedError + raise NotImplementedError('TODO: Functions with arguments') return inp.function.returns_type_var @@ -123,4 +156,16 @@ def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> Type result.add_location(location) return result + if isinstance(inp, typing.TypeFloat32): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_location(location) + return result + + if isinstance(inp, typing.TypeFloat64): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_location(location) + return result + raise NotImplementedError(_convert_old_type, inp) diff --git a/phasm/typing.py b/phasm/typing.py index ffd8e16..0a8982b 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -421,4 +421,15 @@ def simplify(inp: TypeVar) -> Optional[str]: base = 'i' if tc_sign.signed else 'u' return f'{base}{tc_bits.minb}' + if primitive is TypeConstraintPrimitive.Primitive.FLOAT: + if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint + return None + + assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint + + if tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (32, 64): + return None + + return f'f{tc_bits.minb}' + return None diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index 449f34c..aace8a7 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -2,14 +2,12 @@ import pytest from .helpers import Suite +ALL_INT_TYPES = ['u8', 'u32', 'u64', 'i32', 'i64'] +ALL_FLOAT_TYPES = ['f32', 'f64'] + TYPE_MAP = { - 'u8': int, - 'u32': int, - 'u64': int, - 'i32': int, - 'i64': int, - 'f32': float, - 'f64': float, + **{x: int for x in ALL_INT_TYPES}, + **{x: float for x in ALL_FLOAT_TYPES}, } COMPLETE_SIMPLE_TYPES = [ @@ -19,8 +17,8 @@ COMPLETE_SIMPLE_TYPES = [ ] @pytest.mark.integration_test -@pytest.mark.parametrize('type_', TYPE_MAP.keys()) -def test_return(type_): +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_return_int(type_): code_py = f""" @exported def testEntry() -> {type_}: @@ -33,8 +31,22 @@ def testEntry() -> {type_}: assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_addition(type_): +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_return_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 32.125 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_addition_int(type_): code_py = f""" @exported def testEntry() -> {type_}: @@ -47,8 +59,22 @@ def testEntry() -> {type_}: assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_subtraction(type_): +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_addition_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 32.0 + 0.125 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_subtraction_int(type_): code_py = f""" @exported def testEntry() -> {type_}: @@ -60,6 +86,20 @@ def testEntry() -> {type_}: assert 7 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_subtraction(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 100.0 - 67.875 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation def test_logical_left_shift(type_): @@ -136,7 +176,7 @@ def test_buildins_sqrt(type_): code_py = f""" @exported def testEntry() -> {type_}: - return sqrt(25) + return sqrt(25.0) """ result = Suite(code_py).run_code() -- 2.49.0 From 58f74d3e1d9c7dd41742242d0b0911abb4df3c74 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 17 Sep 2022 19:31:43 +0200 Subject: [PATCH 08/18] Restored function calling --- phasm/typer.py | 10 ++++++++-- tests/integration/test_simple.py | 26 ++++++++++++++++++++++---- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/phasm/typer.py b/phasm/typer.py index 07f7f33..de55ef3 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -83,8 +83,12 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.FunctionCall): assert inp.function.returns_type_var is not None - if inp.function.posonlyargs: - raise NotImplementedError('TODO: Functions with arguments') + + for param, expr in zip(inp.function.posonlyargs, inp.arguments): + assert param.type_var is not None + + arg = expression(ctx, expr) + ctx.unify(param.type_var, arg) return inp.function.returns_type_var @@ -105,6 +109,8 @@ def function(ctx: 'Context', inp: ourlang.Function) -> None: def module_constant_def(ctx: 'Context', inp: ourlang.ModuleConstantDef) -> None: inp.type_var = _convert_old_type(ctx, inp.type, inp.name) constant(ctx, inp.constant) + + assert inp.constant.type_var is not None ctx.unify(inp.type_var, inp.constant.type_var) def module(inp: ourlang.Module) -> None: diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py index aace8a7..fdad1e4 100644 --- a/tests/integration/test_simple.py +++ b/tests/integration/test_simple.py @@ -3,6 +3,7 @@ import pytest from .helpers import Suite ALL_INT_TYPES = ['u8', 'u32', 'u64', 'i32', 'i64'] +COMLETE_INT_TYPES = ['u32', 'u64', 'i32', 'i64'] ALL_FLOAT_TYPES = ['f32', 'f64'] TYPE_MAP = { @@ -45,7 +46,7 @@ def testEntry() -> {type_}: assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ALL_INT_TYPES) +@pytest.mark.parametrize('type_', COMLETE_INT_TYPES) def test_addition_int(type_): code_py = f""" @exported @@ -73,7 +74,7 @@ def testEntry() -> {type_}: assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ALL_INT_TYPES) +@pytest.mark.parametrize('type_', COMLETE_INT_TYPES) def test_subtraction_int(type_): code_py = f""" @exported @@ -390,8 +391,8 @@ def helper(left: i32, right: i32) -> i32: assert 7 == result.returned_value @pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_call_with_expression(type_): +@pytest.mark.parametrize('type_', COMLETE_INT_TYPES) +def test_call_with_expression_int(type_): code_py = f""" @exported def testEntry() -> {type_}: @@ -406,6 +407,23 @@ def helper(left: {type_}, right: {type_}) -> {type_}: assert 22 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_call_with_expression_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return helper(10.078125 + 90.046875, 63.0 + 5.0) + +def helper(left: {type_}, right: {type_}) -> {type_}: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + @pytest.mark.integration_test @pytest.mark.skip('Not yet implemented') def test_assign(): -- 2.49.0 From 564f00a419c05a5b05e42375dca6067652877adb Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 17 Sep 2022 20:13:16 +0200 Subject: [PATCH 09/18] Work on ripping out old type system --- phasm/codestyle.py | 69 +++++++-------------------- phasm/compiler.py | 108 ++++++++++++++++++++++++------------------ phasm/exceptions.py | 4 +- phasm/ourlang.py | 52 ++++++++++---------- phasm/parser.py | 69 ++++++--------------------- phasm/stdlib/alloc.py | 2 +- phasm/typer.py | 16 +++++++ phasm/typing.py | 42 +++++++++++++++- phasm/wasmeasy.py | 2 +- pylintrc | 2 +- 10 files changed, 184 insertions(+), 182 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index bbb2eff..334b3eb 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -3,7 +3,7 @@ This module generates source code based on the parsed AST It's intented to be a "any color, as long as it's black" kind of renderer """ -from typing import Generator +from typing import Generator, Optional from . import ourlang from . import typing @@ -16,55 +16,17 @@ def phasm_render(inp: ourlang.Module) -> str: Statements = Generator[str, None, None] -def type_(inp: typing.TypeBase) -> str: +def type_var(inp: Optional[typing.TypeVar]) -> str: """ - Render: Type (name) + Render: type's name """ - if isinstance(inp, typing.TypeNone): - return 'None' + assert inp is not None, typing.ASSERTION_ERROR - if isinstance(inp, typing.TypeBool): - return 'bool' + mtyp = typing.simplify(inp) + if mtyp is None: + raise NotImplementedError(f'Rendering type {inp}') - if isinstance(inp, typing.TypeUInt8): - return 'u8' - - if isinstance(inp, typing.TypeUInt32): - return 'u32' - - if isinstance(inp, typing.TypeUInt64): - return 'u64' - - if isinstance(inp, typing.TypeInt32): - return 'i32' - - if isinstance(inp, typing.TypeInt64): - return 'i64' - - if isinstance(inp, typing.TypeFloat32): - return 'f32' - - if isinstance(inp, typing.TypeFloat64): - return 'f64' - - if isinstance(inp, typing.TypeBytes): - return 'bytes' - - if isinstance(inp, typing.TypeTuple): - mems = ', '.join( - type_(x.type) - for x in inp.members - ) - - return f'({mems}, )' - - if isinstance(inp, typing.TypeStaticArray): - return f'{type_(inp.member_type)}[{len(inp.members)}]' - - if isinstance(inp, typing.TypeStruct): - return inp.name - - raise NotImplementedError(type_, inp) + return mtyp def struct_definition(inp: typing.TypeStruct) -> str: """ @@ -72,7 +34,8 @@ def struct_definition(inp: typing.TypeStruct) -> str: """ result = f'class {inp.name}:\n' for mem in inp.members: - result += f' {mem.name}: {type_(mem.type)}\n' + raise NotImplementedError('Structs broken after new type system') + # result += f' {mem.name}: {type_(mem.type)}\n' return result @@ -80,7 +43,7 @@ def constant_definition(inp: ourlang.ModuleConstantDef) -> str: """ Render: Module Constant's definition """ - return f'{inp.name}: {type_(inp.type)} = {expression(inp.constant)}\n' + return f'{inp.name}: {type_var(inp.type_var)} = {expression(inp.constant)}\n' def expression(inp: ourlang.Expression) -> str: """ @@ -107,7 +70,11 @@ def expression(inp: ourlang.Expression) -> str: return f'{inp.operator}({expression(inp.right)})' if inp.operator == 'cast': - return f'{type_(inp.type)}({expression(inp.right)})' + mtyp = type_var(inp.type_var) + if mtyp is None: + raise NotImplementedError(f'Casting to type {inp.type_var}') + + return f'{mtyp}({expression(inp.right)})' return f'{inp.operator}{expression(inp.right)}' @@ -187,11 +154,11 @@ def function(inp: ourlang.Function) -> str: result += '@imported\n' args = ', '.join( - f'{p.name}: {type_(p.type)}' + f'{p.name}: {type_var(p.type_var)}' for p in inp.posonlyargs ) - result += f'def {inp.name}({args}) -> {type_(inp.returns)}:\n' + result += f'def {inp.name}({args}) -> {type_var(inp.returns_type_var)}:\n' if inp.imported: result += ' pass\n' diff --git a/phasm/compiler.py b/phasm/compiler.py index af6ae58..edd9066 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -1,6 +1,8 @@ """ This module contains the code to convert parsed Ourlang into WebAssembly code """ +from typing import List + import struct from . import codestyle @@ -132,7 +134,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: Compile: Any expression """ if isinstance(inp, ourlang.ConstantPrimitive): - assert inp.type_var is not None + assert inp.type_var is not None, typing.ASSERTION_ERROR stp = typing.simplify(inp.type_var) if stp is None: @@ -174,73 +176,80 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: expression(wgn, inp.left) expression(wgn, inp.right) - if isinstance(inp.type, typing.TypeUInt8): + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) + + if mtyp == 'u8': if operator := U8_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeUInt32): + if mtyp == 'u32': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := U32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeUInt64): + if mtyp == 'u64': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := U64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if isinstance(inp.type, typing.TypeInt32): + if mtyp == 'i32': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return if operator := I32_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i32.{operator}') return - if isinstance(inp.type, typing.TypeInt64): + if mtyp == 'i64': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return if operator := I64_OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'i64.{operator}') return - if isinstance(inp.type, typing.TypeFloat32): + if mtyp == 'f32': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f32.{operator}') return - if isinstance(inp.type, typing.TypeFloat64): + if mtyp == 'f64': if operator := OPERATOR_MAP.get(inp.operator, None): wgn.add_statement(f'f64.{operator}') return - raise NotImplementedError(expression, inp.type, inp.operator) + raise NotImplementedError(expression, inp.type_var, inp.operator) if isinstance(inp, ourlang.UnaryOp): expression(wgn, inp.right) - if isinstance(inp.type, typing.TypeFloat32): + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) + + if mtyp == 'f32': if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: wgn.add_statement(f'f32.{inp.operator}') return - if isinstance(inp.type, typing.TypeFloat64): + if mtyp == 'f64': if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: wgn.add_statement(f'f64.{inp.operator}') return - if isinstance(inp.type, typing.TypeInt32): - if inp.operator == 'len': - if isinstance(inp.right.type, typing.TypeBytes): - wgn.i32.load() - return + # TODO: Broken after new type system + # if isinstance(inp.type, typing.TypeInt32): + # if inp.operator == 'len': + # if isinstance(inp.right.type, typing.TypeBytes): + # wgn.i32.load() + # return - if inp.operator == 'cast': - if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): - # Nothing to do, you can use an u8 value as a u32 no problem - return + # if inp.operator == 'cast': + # if isinstance(inp.type, typing.TypeUInt32) and isinstance(inp.right.type, typing.TypeUInt8): + # # Nothing to do, you can use an u8 value as a u32 no problem + # return - raise NotImplementedError(expression, inp.type, inp.operator) + raise NotImplementedError(expression, inp.type_var, inp.operator) if isinstance(inp, ourlang.FunctionCall): for arg in inp.arguments: @@ -249,14 +258,15 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.add_statement('call', '${}'.format(inp.function.name)) return - if isinstance(inp, ourlang.AccessBytesIndex): - if not isinstance(inp.type, typing.TypeUInt8): - raise NotImplementedError(inp, inp.type) - - expression(wgn, inp.varref) - expression(wgn, inp.index) - wgn.call(stdlib_types.__subscript_bytes__) - return + # TODO: Broken after new type system + # if isinstance(inp, ourlang.AccessBytesIndex): + # if not isinstance(inp.type, typing.TypeUInt8): + # raise NotImplementedError(inp, inp.type) + # + # expression(wgn, inp.varref) + # expression(wgn, inp.index) + # wgn.call(stdlib_types.__subscript_bytes__) + # return if isinstance(inp, ourlang.AccessStructMember): mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) @@ -305,27 +315,29 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: return if isinstance(inp, ourlang.ModuleConstantReference): - if isinstance(inp.type, typing.TypeTuple): - assert isinstance(inp.definition.constant, ourlang.ConstantTuple) - assert inp.definition.data_block is not None, 'Combined values are memory stored' - assert inp.definition.data_block.address is not None, 'Value not allocated' - wgn.i32.const(inp.definition.data_block.address) - return - - if isinstance(inp.type, typing.TypeStaticArray): - assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) - assert inp.definition.data_block is not None, 'Combined values are memory stored' - assert inp.definition.data_block.address is not None, 'Value not allocated' - wgn.i32.const(inp.definition.data_block.address) - return + # FIXME: Tuple / Static Array broken after new type system + # if isinstance(inp.type, typing.TypeTuple): + # assert isinstance(inp.definition.constant, ourlang.ConstantTuple) + # assert inp.definition.data_block is not None, 'Combined values are memory stored' + # assert inp.definition.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.definition.data_block.address) + # return + # + # if isinstance(inp.type, typing.TypeStaticArray): + # assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) + # assert inp.definition.data_block is not None, 'Combined values are memory stored' + # assert inp.definition.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.definition.data_block.address) + # return assert inp.definition.data_block is None, 'Primitives are not memory stored' - mtyp = LOAD_STORE_TYPE_MAP.get(inp.type.__class__) + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) if mtyp is None: # In the future might extend this by having structs or tuples # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.type) + raise NotImplementedError(expression, inp, inp.type_var) expression(wgn, inp.definition.constant) return @@ -336,13 +348,15 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: """ Compile: Fold expression """ - assert inp.base.type_var is not None + assert inp.base.type_var is not None, typing.ASSERTION_ERROR mtyp = typing.simplify(inp.base.type_var) if mtyp is None: # In the future might extend this by having structs or tuples # as members of struct or tuples raise NotImplementedError(expression, inp, inp.base) + raise NotImplementedError('TODO: Broken after new type system') + if inp.iter.type.__class__.__name__ != 'TypeBytes': raise NotImplementedError(expression, inp, inp.iter.type) @@ -563,7 +577,9 @@ def module_data(inp: ourlang.ModuleData) -> bytes: for block in inp.blocks: block.address = unalloc_ptr + 4 # 4 bytes for allocator header - data_list = [] + data_list: List[bytes] = [] + + raise NotImplementedError('Broken after new type system') for constant in block.data: if isinstance(constant, ourlang.ConstantUInt8): diff --git a/phasm/exceptions.py b/phasm/exceptions.py index abbcdd2..77c75e7 100644 --- a/phasm/exceptions.py +++ b/phasm/exceptions.py @@ -8,4 +8,6 @@ class StaticError(Exception): """ class TypingError(Exception): - pass + """ + An error found during the typing phase + """ diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 30a2527..ebc13d7 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -1,7 +1,7 @@ """ Contains the syntax tree for ourlang """ -from typing import Dict, List, Tuple, Optional, Union +from typing import Dict, List, Optional, Union import enum @@ -13,7 +13,6 @@ WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) from .typing import ( TypeBase, TypeNone, - TypeBool, TypeUInt8, TypeUInt32, TypeUInt64, TypeInt32, TypeInt64, TypeFloat32, TypeFloat64, @@ -29,13 +28,11 @@ class Expression: """ An expression within a statement """ - __slots__ = ('type', 'type_var', ) + __slots__ = ('type_var', ) - type: TypeBase type_var: Optional[TypeVar] - def __init__(self, type_: TypeBase) -> None: - self.type = type_ + def __init__(self) -> None: self.type_var = None class Constant(Expression): @@ -53,6 +50,7 @@ class ConstantPrimitive(Constant): value: Union[int, float] def __init__(self, value: Union[int, float]) -> None: + super().__init__() self.value = value class ConstantTuple(Constant): @@ -63,8 +61,8 @@ class ConstantTuple(Constant): value: List[ConstantPrimitive] - def __init__(self, type_: TypeTuple, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples? - super().__init__(type_) + def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Tuple of tuples? + super().__init__() self.value = value class ConstantStaticArray(Constant): @@ -75,8 +73,8 @@ class ConstantStaticArray(Constant): value: List[ConstantPrimitive] - def __init__(self, type_: TypeStaticArray, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays? - super().__init__(type_) + def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays? + super().__init__() self.value = value class VariableReference(Expression): @@ -87,8 +85,8 @@ class VariableReference(Expression): variable: 'FunctionParam' # also possibly local - def __init__(self, type_: TypeBase, variable: 'FunctionParam') -> None: - super().__init__(type_) + def __init__(self, variable: 'FunctionParam') -> None: + super().__init__() self.variable = variable class UnaryOp(Expression): @@ -100,8 +98,8 @@ class UnaryOp(Expression): operator: str right: Expression - def __init__(self, type_: TypeBase, operator: str, right: Expression) -> None: - super().__init__(type_) + def __init__(self, operator: str, right: Expression) -> None: + super().__init__() self.operator = operator self.right = right @@ -116,8 +114,8 @@ class BinaryOp(Expression): left: Expression right: Expression - def __init__(self, type_: TypeBase, operator: str, left: Expression, right: Expression) -> None: - super().__init__(type_) + def __init__(self, operator: str, left: Expression, right: Expression) -> None: + super().__init__() self.operator = operator self.left = left @@ -133,7 +131,7 @@ class FunctionCall(Expression): arguments: List[Expression] def __init__(self, function: 'Function') -> None: - super().__init__(function.returns) + super().__init__() self.function = function self.arguments = [] @@ -147,8 +145,8 @@ class AccessBytesIndex(Expression): varref: VariableReference index: Expression - def __init__(self, type_: TypeBase, varref: VariableReference, index: Expression) -> None: - super().__init__(type_) + def __init__(self, varref: VariableReference, index: Expression) -> None: + super().__init__() self.varref = varref self.index = index @@ -163,7 +161,7 @@ class AccessStructMember(Expression): member: TypeStructMember def __init__(self, varref: VariableReference, member: TypeStructMember) -> None: - super().__init__(member.type) + super().__init__() self.varref = varref self.member = member @@ -178,7 +176,7 @@ class AccessTupleMember(Expression): member: TypeTupleMember def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None: - super().__init__(member.type) + super().__init__() self.varref = varref self.member = member @@ -194,7 +192,7 @@ class AccessStaticArrayMember(Expression): member: Union[Expression, TypeStaticArrayMember] def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None: - super().__init__(static_array.member_type) + super().__init__() self.varref = varref self.static_array = static_array @@ -218,13 +216,12 @@ class Fold(Expression): def __init__( self, - type_: TypeBase, dir_: Direction, func: 'Function', base: Expression, iter_: Expression, ) -> None: - super().__init__(type_) + super().__init__() self.dir = dir_ self.func = func @@ -239,8 +236,8 @@ class ModuleConstantReference(Expression): definition: 'ModuleConstantDef' - def __init__(self, type_: TypeBase, definition: 'ModuleConstantDef') -> None: - super().__init__(type_) + def __init__(self, definition: 'ModuleConstantDef') -> None: + super().__init__() self.definition = definition class Statement: @@ -280,6 +277,9 @@ class StatementIf(Statement): self.else_statements = [] class FunctionParam: + """ + A parameter for a Function + """ __slots__ = ('name', 'type', 'type_var', ) name: str diff --git a/phasm/parser.py b/phasm/parser.py index 8b1b3a2..d8e94af 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -7,13 +7,6 @@ import ast from .typing import ( TypeBase, - TypeUInt8, - TypeUInt32, - TypeUInt64, - TypeInt32, - TypeInt64, - TypeFloat32, - TypeFloat64, TypeBytes, TypeStruct, TypeStructMember, @@ -34,7 +27,6 @@ from .ourlang import ( Expression, AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, BinaryOp, - Constant, ConstantPrimitive, ConstantTuple, ConstantStaticArray, FunctionCall, @@ -242,7 +234,7 @@ class OurVisitor: node.target.id, node.lineno, exp_type, - ConstantTuple(exp_type, tuple_data), + ConstantTuple(tuple_data), data_block, ) @@ -270,7 +262,7 @@ class OurVisitor: node.target.id, node.lineno, exp_type, - ConstantStaticArray(exp_type, static_array_data), + ConstantStaticArray(static_array_data), data_block, ) @@ -359,7 +351,6 @@ class OurVisitor: # e.g. you can do `"hello" * 3` with the code below (yet) return BinaryOp( - exp_type, operator, self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), @@ -374,7 +365,6 @@ class OurVisitor: raise NotImplementedError(f'Operator {node.op}') return UnaryOp( - exp_type, operator, self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), ) @@ -396,7 +386,6 @@ class OurVisitor: # e.g. you can do `"hello" * 3` with the code below (yet) return BinaryOp( - exp_type, operator, self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), @@ -412,12 +401,12 @@ class OurVisitor: if isinstance(node, ast.Attribute): return self.visit_Module_FunctionDef_Attribute( - module, function, our_locals, exp_type, node, + module, function, our_locals, node, ) if isinstance(node, ast.Subscript): return self.visit_Module_FunctionDef_Subscript( - module, function, our_locals, exp_type, node, + module, function, our_locals, node, ) if isinstance(node, ast.Name): @@ -426,21 +415,17 @@ class OurVisitor: if node.id in our_locals: param = our_locals[node.id] - if exp_type != param.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(param.type)}') - - return VariableReference(param.type, param) + return VariableReference(param) if node.id in module.constant_defs: cdef = module.constant_defs[node.id] - if exp_type != cdef.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.id} is actually {codestyle.type_(cdef.type)}') - - return ModuleConstantReference(exp_type, cdef) + return ModuleConstantReference(cdef) _raise_static_error(node, f'Undefined variable {node.id}') if isinstance(node, ast.Tuple): + raise NotImplementedError('TODO: Broken after new type system') + if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') @@ -478,40 +463,28 @@ class OurVisitor: func = module.functions[struct_constructor.name] elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: - if not isinstance(exp_type, (TypeFloat32, TypeFloat64, )): - _raise_static_error(node, f'Cannot make {node.func.id} result in {codestyle.type_(exp_type)}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') return UnaryOp( - exp_type, 'sqrt', self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), ) elif node.func.id == 'u32': - if not isinstance(exp_type, TypeUInt32): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') # FIXME: This is a stub, proper casting is todo return UnaryOp( - exp_type, 'cast', self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]), ) elif node.func.id == 'len': - if not isinstance(exp_type, TypeInt32): - _raise_static_error(node, f'Cannot make {node.func.id} result in {exp_type}') - if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') return UnaryOp( - exp_type, 'len', self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), ) @@ -536,6 +509,8 @@ class OurVisitor: if 2 != len(func.posonlyargs): _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') + raise NotImplementedError('TODO: Broken after new type system') + if exp_type.__class__ != func.returns.__class__: _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') @@ -546,7 +521,6 @@ class OurVisitor: _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') return Fold( - exp_type, Fold.Direction.LEFT, func, self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]), @@ -571,7 +545,7 @@ class OurVisitor: ) return result - def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Attribute) -> Expression: + def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: del module del function @@ -594,15 +568,12 @@ class OurVisitor: if member is None: _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') - if exp_type != member.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}.{member.name} is actually {codestyle.type_(member.type)}') - return AccessStructMember( - VariableReference(node_typ, param), + VariableReference(param), member, ) - def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Subscript) -> Expression: + def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: if not isinstance(node.value, ast.Name): _raise_static_error(node, 'Must reference a name') @@ -616,11 +587,11 @@ class OurVisitor: if node.value.id in our_locals: param = our_locals[node.value.id] node_typ = param.type - varref = VariableReference(param.type, param) + varref = VariableReference(param) elif node.value.id in module.constant_defs: constant_def = module.constant_defs[node.value.id] node_typ = constant_def.type - varref = ModuleConstantReference(node_typ, constant_def) + varref = ModuleConstantReference(constant_def) else: _raise_static_error(node, f'Undefined variable {node.value.id}') @@ -629,15 +600,10 @@ class OurVisitor: ) if isinstance(node_typ, TypeBytes): - t_u8 = module.types['u8'] - if exp_type != t_u8: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{codestyle.expression(slice_expr)}] is actually {codestyle.type_(t_u8)}') - if isinstance(varref, ModuleConstantReference): raise NotImplementedError(f'{node} from module constant') return AccessBytesIndex( - t_u8, varref, slice_expr, ) @@ -655,8 +621,6 @@ class OurVisitor: _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') tuple_member = node_typ.members[idx] - if exp_type != tuple_member.type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(tuple_member.type)}') if isinstance(varref, ModuleConstantReference): raise NotImplementedError(f'{node} from module constant') @@ -667,9 +631,6 @@ class OurVisitor: ) if isinstance(node_typ, TypeStaticArray): - if exp_type != node_typ.member_type: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {node.value.id}[{idx}] is actually {codestyle.type_(node_typ.member_type)}') - if not isinstance(slice_expr, ConstantPrimitive): return AccessStaticArrayMember( varref, diff --git a/phasm/stdlib/alloc.py b/phasm/stdlib/alloc.py index 2761bfb..8c5742d 100644 --- a/phasm/stdlib/alloc.py +++ b/phasm/stdlib/alloc.py @@ -26,7 +26,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32: g.i32.const(0) g.return_() - del alloc_size # TODO + del alloc_size # TODO: Actual implement using a previously freed block g.unreachable() return i32('return') # To satisfy mypy diff --git a/phasm/typer.py b/phasm/typer.py index de55ef3..1230d01 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -66,19 +66,27 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': return inp.variable.type_var if isinstance(inp, ourlang.UnaryOp): + # TODO: Simplified version if inp.operator not in ('sqrt', ): raise NotImplementedError(expression, inp, inp.operator) right = expression(ctx, inp.right) + + inp.type_var = right + return right if isinstance(inp, ourlang.BinaryOp): + # TODO: Simplified version if inp.operator not in ('+', '-', '*', '|', '&', '^'): raise NotImplementedError(expression, inp, inp.operator) left = expression(ctx, inp.left) right = expression(ctx, inp.right) ctx.unify(left, right) + + inp.type_var = left + return left if isinstance(inp, ourlang.FunctionCall): @@ -94,6 +102,9 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.ModuleConstantReference): assert inp.definition.type_var is not None + + inp.type_var = inp.definition.type_var + return inp.definition.type_var raise NotImplementedError(expression, inp) @@ -133,30 +144,35 @@ def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> Type result = ctx.new_var() if isinstance(inp, typing.TypeUInt8): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) result.add_constraint(TypeConstraintSigned(False)) result.add_location(location) return result if isinstance(inp, typing.TypeUInt32): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(False)) result.add_location(location) return result if isinstance(inp, typing.TypeUInt64): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(False)) result.add_location(location) return result if isinstance(inp, typing.TypeInt32): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(True)) result.add_location(location) return result if isinstance(inp, typing.TypeInt64): + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(True)) result.add_location(location) diff --git a/phasm/typing.py b/phasm/typing.py index 0a8982b..94bb602 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -207,23 +207,45 @@ class TypeStruct(TypeBase): ## NEW STUFF BELOW +# This error can also mean that the type somewhere forgot to write a type +# back to the AST. If so, we need to fix the typer. +ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method' + + class TypingNarrowProtoError(TypingError): - pass + """ + A proto error when trying to narrow two types + + This gets turned into a TypingNarrowError by the unify method + """ + # FIXME: Use consistent naming for unify / narrow / entangle class TypingNarrowError(TypingError): + """ + An error when trying to unify two Type Variables + """ def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None: super().__init__( f'Cannot narrow types {l} and {r}: {msg}' ) class TypeConstraintBase: + """ + Base class for classes implementing a contraint on a type + """ def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': raise NotImplementedError('narrow', self, other) class TypeConstraintPrimitive(TypeConstraintBase): + """ + This contraint on a type defines its primitive shape + """ __slots__ = ('primitive', ) class Primitive(enum.Enum): + """ + The primitive ID + """ INT = 0 FLOAT = 1 @@ -245,6 +267,10 @@ class TypeConstraintPrimitive(TypeConstraintBase): return f'Primitive={self.primitive.name}' class TypeConstraintSigned(TypeConstraintBase): + """ + Contraint on whether a signed value can be used or not, or whether + a value can be used in a signed expression + """ __slots__ = ('signed', ) signed: Optional[bool] @@ -270,6 +296,9 @@ class TypeConstraintSigned(TypeConstraintBase): return f'Signed={self.signed}' class TypeConstraintBitWidth(TypeConstraintBase): + """ + Contraint on how many bits an expression has or can possibly have + """ __slots__ = ('minb', 'maxb', ) minb: int @@ -301,6 +330,10 @@ class TypeConstraintBitWidth(TypeConstraintBase): return f'BitWidth={self.minb}..{self.maxb}' class TypeVar: + """ + A type variable + """ + # FIXME: Explain the type system __slots__ = ('ctx', 'ctx_id', ) ctx: 'Context' @@ -331,6 +364,9 @@ class TypeVar: ) class Context: + """ + The context for a collection of type variables + """ def __init__(self) -> None: # Variables are unified (or entangled, if you will) # that means that each TypeVar within a context has an ID, @@ -399,6 +435,10 @@ class Context: del self.var_locations[r_ctx_id] def simplify(inp: TypeVar) -> Optional[str]: + """ + Simplifies a TypeVar into a string that wasm can work with + and users can recognize + """ tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive) tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth) tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned) diff --git a/phasm/wasmeasy.py b/phasm/wasmeasy.py index d0cf358..01c5d0c 100644 --- a/phasm/wasmeasy.py +++ b/phasm/wasmeasy.py @@ -1,7 +1,7 @@ """ Helper functions to quickly generate WASM code """ -from typing import Any, Dict, List, Optional, Type +from typing import List, Optional import functools diff --git a/pylintrc b/pylintrc index 3759e6b..82948bb 100644 --- a/pylintrc +++ b/pylintrc @@ -1,5 +1,5 @@ [MASTER] -disable=C0103,C0122,R0903,R0911,R0912,R0913,R0915,R1710,W0223 +disable=C0103,C0122,R0902,R0903,R0911,R0912,R0913,R0915,R1710,W0223 max-line-length=180 -- 2.49.0 From 07c0688d1b9aec0870500f611aec8a5ef8e87634 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 17 Sep 2022 20:50:06 +0200 Subject: [PATCH 10/18] Ripping out old type system. Will have to reimplement bytes, static array, tuple and struct. --- phasm/codestyle.py | 13 +- phasm/compiler.py | 234 ++++++++++----------- phasm/ourlang.py | 118 +++++------ phasm/parser.py | 513 ++++++++++++++++++++++----------------------- phasm/typer.py | 71 +------ phasm/typing.py | 156 +++++++------- 6 files changed, 504 insertions(+), 601 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 334b3eb..04ee043 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -33,7 +33,7 @@ def struct_definition(inp: typing.TypeStruct) -> str: Render: TypeStruct's definition """ result = f'class {inp.name}:\n' - for mem in inp.members: + for mem in inp.members: # TODO: Broken after new type system raise NotImplementedError('Structs broken after new type system') # result += f' {mem.name}: {type_(mem.type)}\n' @@ -87,11 +87,12 @@ def expression(inp: ourlang.Expression) -> str: for arg in inp.arguments ) - if isinstance(inp.function, ourlang.StructConstructor): - return f'{inp.function.struct.name}({args})' - - if isinstance(inp.function, ourlang.TupleConstructor): - return f'({args}, )' + # TODO: Broken after new type system + # if isinstance(inp.function, ourlang.StructConstructor): + # return f'{inp.function.struct.name}({args})' + # + # if isinstance(inp.function, ourlang.TupleConstructor): + # return f'({args}, )' return f'{inp.function.name}({args})' diff --git a/phasm/compiler.py b/phasm/compiler.py index edd9066..7f6d53f 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -1,7 +1,7 @@ """ This module contains the code to convert parsed Ourlang into WebAssembly code """ -from typing import List +from typing import List, Optional import struct @@ -14,19 +14,6 @@ from .stdlib import alloc as stdlib_alloc from .stdlib import types as stdlib_types from .wasmgenerator import Generator as WasmGenerator -LOAD_STORE_TYPE_MAP = { - typing.TypeUInt8: 'i32', - typing.TypeUInt32: 'i32', - typing.TypeUInt64: 'i64', - typing.TypeInt32: 'i32', - typing.TypeInt64: 'i64', - typing.TypeFloat32: 'f32', - typing.TypeFloat64: 'f64', -} -""" -When generating code, we sometimes need to load or store simple values -""" - def phasm_compile(inp: ourlang.Module) -> wasm.Module: """ Public method for compiling a parsed Phasm module into @@ -34,42 +21,44 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module: """ return module(inp) -def type_(inp: typing.TypeBase) -> wasm.WasmType: +def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType: """ Compile: type """ - if isinstance(inp, typing.TypeNone): - return wasm.WasmTypeNone() + assert inp is not None, typing.ASSERTION_ERROR - if isinstance(inp, typing.TypeUInt8): + mtyp = typing.simplify(inp) + + if mtyp == 'u8': # WebAssembly has only support for 32 and 64 bits # So we need to store more memory per byte return wasm.WasmTypeInt32() - if isinstance(inp, typing.TypeUInt32): + if mtyp == 'u32': return wasm.WasmTypeInt32() - if isinstance(inp, typing.TypeUInt64): + if mtyp == 'u64': return wasm.WasmTypeInt64() - if isinstance(inp, typing.TypeInt32): + if mtyp == 'i32': return wasm.WasmTypeInt32() - if isinstance(inp, typing.TypeInt64): + if mtyp == 'i64': return wasm.WasmTypeInt64() - if isinstance(inp, typing.TypeFloat32): + if mtyp == 'f32': return wasm.WasmTypeFloat32() - if isinstance(inp, typing.TypeFloat64): + if mtyp == 'f64': return wasm.WasmTypeFloat64() - if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): - # Structs and tuples are passed as pointer - # And pointers are i32 - return wasm.WasmTypeInt32() + # TODO: Broken after new type system + # if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): + # # Structs and tuples are passed as pointer + # # And pointers are i32 + # return wasm.WasmTypeInt32() - raise NotImplementedError(type_, inp) + raise NotImplementedError(inp, mtyp) # Operators that work for i32, i64, f32, f64 OPERATOR_MAP = { @@ -268,47 +257,47 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: # wgn.call(stdlib_types.__subscript_bytes__) # return - if isinstance(inp, ourlang.AccessStructMember): - mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.member) - - expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) - return - - if isinstance(inp, ourlang.AccessTupleMember): - mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.member) - - expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) - return - - if isinstance(inp, ourlang.AccessStaticArrayMember): - mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of static arrays - raise NotImplementedError(expression, inp, inp.member) - - if isinstance(inp.member, typing.TypeStaticArrayMember): - expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) - return - - expression(wgn, inp.varref) - expression(wgn, inp.member) - wgn.i32.const(inp.static_array.member_type.alloc_size()) - wgn.i32.mul() - wgn.i32.add() - wgn.add_statement(f'{mtyp}.load') - return + # if isinstance(inp, ourlang.AccessStructMember): + # mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) + # if mtyp is None: + # # In the future might extend this by having structs or tuples + # # as members of struct or tuples + # raise NotImplementedError(expression, inp, inp.member) + # + # expression(wgn, inp.varref) + # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) + # return + # + # if isinstance(inp, ourlang.AccessTupleMember): + # mtyp = LOAD_STORE_TYPE_MAP.get(inp.member.type.__class__) + # if mtyp is None: + # # In the future might extend this by having structs or tuples + # # as members of struct or tuples + # raise NotImplementedError(expression, inp, inp.member) + # + # expression(wgn, inp.varref) + # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) + # return + # + # if isinstance(inp, ourlang.AccessStaticArrayMember): + # mtyp = LOAD_STORE_TYPE_MAP.get(inp.static_array.member_type.__class__) + # if mtyp is None: + # # In the future might extend this by having structs or tuples + # # as members of static arrays + # raise NotImplementedError(expression, inp, inp.member) + # + # if isinstance(inp.member, typing.TypeStaticArrayMember): + # expression(wgn, inp.varref) + # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) + # return + # + # expression(wgn, inp.varref) + # expression(wgn, inp.member) + # wgn.i32.const(inp.static_array.member_type.alloc_size()) + # wgn.i32.mul() + # wgn.i32.add() + # wgn.add_statement(f'{mtyp}.load') + # return if isinstance(inp, ourlang.Fold): expression_fold(wgn, inp) @@ -472,7 +461,7 @@ def function_argument(inp: ourlang.FunctionParam) -> wasm.Param: """ Compile: function argument """ - return (inp.name, type_(inp.type), ) + return (inp.name, type_var(inp.type_var), ) def import_(inp: ourlang.Function) -> wasm.Import: """ @@ -488,7 +477,7 @@ def import_(inp: ourlang.Function) -> wasm.Import: function_argument(x) for x in inp.posonlyargs ], - type_(inp.returns) + type_var(inp.returns_type_var) ) def function(inp: ourlang.Function) -> wasm.Function: @@ -499,10 +488,10 @@ def function(inp: ourlang.Function) -> wasm.Function: wgn = WasmGenerator() - if isinstance(inp, ourlang.TupleConstructor): - _generate_tuple_constructor(wgn, inp) - elif isinstance(inp, ourlang.StructConstructor): - _generate_struct_constructor(wgn, inp) + if False: # TODO: isinstance(inp, ourlang.TupleConstructor): + pass # _generate_tuple_constructor(wgn, inp) + elif False: # TODO: isinstance(inp, ourlang.StructConstructor): + pass # _generate_struct_constructor(wgn, inp) else: for stat in inp.statements: statement(wgn, stat) @@ -518,7 +507,7 @@ def function(inp: ourlang.Function) -> wasm.Function: (k, v.wasm_type(), ) for k, v in wgn.locals.items() ], - type_(inp.returns), + type_var(inp.returns_type_var), wgn.statements ) @@ -660,48 +649,49 @@ def module(inp: ourlang.Module) -> wasm.Module: return result -def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None: - tmp_var = wgn.temp_var_i32('tuple_adr') - - # Allocated the required amounts of bytes in memory - wgn.i32.const(inp.tuple.alloc_size()) - wgn.call(stdlib_alloc.__alloc__) - wgn.local.set(tmp_var) - - # Store each member individually - for member in inp.tuple.members: - mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, member) - - wgn.local.get(tmp_var) - wgn.add_statement('local.get', f'$arg{member.idx}') - wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) - - # Return the allocated address - wgn.local.get(tmp_var) - -def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: - tmp_var = wgn.temp_var_i32('struct_adr') - - # Allocated the required amounts of bytes in memory - wgn.i32.const(inp.struct.alloc_size()) - wgn.call(stdlib_alloc.__alloc__) - wgn.local.set(tmp_var) - - # Store each member individually - for member in inp.struct.members: - mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, member) - - wgn.local.get(tmp_var) - wgn.add_statement('local.get', f'${member.name}') - wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) - - # Return the allocated address - wgn.local.get(tmp_var) +# TODO: Broken after new type system +# def _generate_tuple_constructor(wgn: WasmGenerator, inp: ourlang.TupleConstructor) -> None: +# tmp_var = wgn.temp_var_i32('tuple_adr') +# +# # Allocated the required amounts of bytes in memory +# wgn.i32.const(inp.tuple.alloc_size()) +# wgn.call(stdlib_alloc.__alloc__) +# wgn.local.set(tmp_var) +# +# # Store each member individually +# for member in inp.tuple.members: +# mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) +# if mtyp is None: +# # In the future might extend this by having structs or tuples +# # as members of struct or tuples +# raise NotImplementedError(expression, inp, member) +# +# wgn.local.get(tmp_var) +# wgn.add_statement('local.get', f'$arg{member.idx}') +# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) +# +# # Return the allocated address +# wgn.local.get(tmp_var) +# +# def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: +# tmp_var = wgn.temp_var_i32('struct_adr') +# +# # Allocated the required amounts of bytes in memory +# wgn.i32.const(inp.struct.alloc_size()) +# wgn.call(stdlib_alloc.__alloc__) +# wgn.local.set(tmp_var) +# +# # Store each member individually +# for member in inp.struct.members: +# mtyp = LOAD_STORE_TYPE_MAP.get(member.type.__class__) +# if mtyp is None: +# # In the future might extend this by having structs or tuples +# # as members of struct or tuples +# raise NotImplementedError(expression, inp, member) +# +# wgn.local.get(tmp_var) +# wgn.add_statement('local.get', f'${member.name}') +# wgn.add_statement(f'{mtyp}.store', 'offset=' + str(member.offset)) +# +# # Return the allocated address +# wgn.local.get(tmp_var) diff --git a/phasm/ourlang.py b/phasm/ourlang.py index ebc13d7..9e16e4b 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -11,11 +11,6 @@ WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) from .typing import ( - TypeBase, - TypeNone, - TypeUInt8, TypeUInt32, TypeUInt64, - TypeInt32, TypeInt64, - TypeFloat32, TypeFloat64, TypeBytes, TypeTuple, TypeTupleMember, TypeStaticArray, TypeStaticArrayMember, @@ -280,29 +275,29 @@ class FunctionParam: """ A parameter for a Function """ - __slots__ = ('name', 'type', 'type_var', ) + __slots__ = ('name', 'type_str', 'type_var', ) name: str - type: TypeBase + type_str: str type_var: Optional[TypeVar] - def __init__(self, name: str, type_: TypeBase) -> None: + def __init__(self, name: str, type_str: str) -> None: self.name = name - self.type = type_ + self.type_str = type_str self.type_var = None class Function: """ A function processes input and produces output """ - __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns', 'returns_type_var', 'posonlyargs', ) + __slots__ = ('name', 'lineno', 'exported', 'imported', 'statements', 'returns_str', 'returns_type_var', 'posonlyargs', ) name: str lineno: int exported: bool imported: bool statements: List[Statement] - returns: TypeBase + returns_str: str returns_type_var: Optional[TypeVar] posonlyargs: List[FunctionParam] @@ -312,68 +307,67 @@ class Function: self.exported = False self.imported = False self.statements = [] - self.returns = TypeNone() + self.returns_str = 'None' self.returns_type_var = None self.posonlyargs = [] -class StructConstructor(Function): - """ - The constructor method for a struct - - A function will generated to instantiate a struct. The arguments - will be the defaults - """ - __slots__ = ('struct', ) - - struct: TypeStruct - - def __init__(self, struct: TypeStruct) -> None: - super().__init__(f'@{struct.name}@__init___@', -1) - - self.returns = struct - - for mem in struct.members: - self.posonlyargs.append(FunctionParam(mem.name, mem.type, )) - - self.struct = struct - -class TupleConstructor(Function): - """ - The constructor method for a tuple - """ - __slots__ = ('tuple', ) - - tuple: TypeTuple - - def __init__(self, tuple_: TypeTuple) -> None: - name = tuple_.render_internal_name() - - super().__init__(f'@{name}@__init___@', -1) - - self.returns = tuple_ - - for mem in tuple_.members: - self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, )) - - self.tuple = tuple_ +# TODO: Broken after new type system +# class StructConstructor(Function): +# """ +# The constructor method for a struct +# +# A function will generated to instantiate a struct. The arguments +# will be the defaults +# """ +# __slots__ = ('struct', ) +# +# struct: TypeStruct +# +# def __init__(self, struct: TypeStruct) -> None: +# super().__init__(f'@{struct.name}@__init___@', -1) +# +# self.returns = struct +# +# for mem in struct.members: +# self.posonlyargs.append(FunctionParam(mem.name, mem.type, )) +# +# self.struct = struct +# +# class TupleConstructor(Function): +# """ +# The constructor method for a tuple +# """ +# __slots__ = ('tuple', ) +# +# tuple: TypeTuple +# +# def __init__(self, tuple_: TypeTuple) -> None: +# name = tuple_.render_internal_name() +# +# super().__init__(f'@{name}@__init___@', -1) +# +# self.returns = tuple_ +# +# for mem in tuple_.members: +# self.posonlyargs.append(FunctionParam(f'arg{mem.idx}', mem.type, )) +# +# self.tuple = tuple_ class ModuleConstantDef: """ A constant definition within a module """ - __slots__ = ('name', 'lineno', 'type', 'type_var', 'constant', 'data_block', ) + __slots__ = ('name', 'lineno', 'type_var', 'constant', 'data_block', ) name: str lineno: int - type: TypeBase type_var: Optional[TypeVar] constant: Constant data_block: Optional['ModuleDataBlock'] - def __init__(self, name: str, lineno: int, type_: TypeBase, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: + def __init__(self, name: str, lineno: int, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: self.name = name self.lineno = lineno - self.type = type_ self.type_var = None self.constant = constant self.data_block = data_block @@ -409,23 +403,11 @@ class Module: __slots__ = ('data', 'types', 'structs', 'constant_defs', 'functions',) data: ModuleData - types: Dict[str, TypeBase] structs: Dict[str, TypeStruct] constant_defs: Dict[str, ModuleConstantDef] functions: Dict[str, Function] def __init__(self) -> None: - self.types = { - 'None': TypeNone(), - 'u8': TypeUInt8(), - 'u32': TypeUInt32(), - 'u64': TypeUInt64(), - 'i32': TypeInt32(), - 'i64': TypeInt64(), - 'f32': TypeFloat32(), - 'f64': TypeFloat64(), - 'bytes': TypeBytes(), - } self.data = ModuleData() self.structs = {} self.constant_defs = {} diff --git a/phasm/parser.py b/phasm/parser.py index d8e94af..6fe96fc 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -6,8 +6,6 @@ from typing import Any, Dict, NoReturn, Union import ast from .typing import ( - TypeBase, - TypeBytes, TypeStruct, TypeStructMember, TypeTuple, @@ -16,7 +14,6 @@ from .typing import ( TypeStaticArrayMember, ) -from . import codestyle from .exceptions import StaticError from .ourlang import ( WEBASSEMBLY_BUILDIN_FLOAT_OPS, @@ -30,7 +27,7 @@ from .ourlang import ( ConstantPrimitive, ConstantTuple, ConstantStaticArray, FunctionCall, - StructConstructor, TupleConstructor, + # StructConstructor, TupleConstructor, UnaryOp, VariableReference, Fold, ModuleConstantReference, @@ -86,15 +83,16 @@ class OurVisitor: module.constant_defs[res.name] = res - if isinstance(res, TypeStruct): - if res.name in module.structs: - raise StaticError( - f'{res.name} already defined on line {module.structs[res.name].lineno}' - ) - - module.structs[res.name] = res - constructor = StructConstructor(res) - module.functions[constructor.name] = constructor + # TODO: Broken after type system + # if isinstance(res, TypeStruct): + # if res.name in module.structs: + # raise StaticError( + # f'{res.name} already defined on line {module.structs[res.name].lineno}' + # ) + # + # module.structs[res.name] = res + # constructor = StructConstructor(res) + # module.functions[constructor.name] = constructor if isinstance(res, Function): if res.name in module.functions: @@ -158,7 +156,7 @@ class OurVisitor: function.imported = True if node.returns: - function.returns = self.visit_type(module, node.returns) + function.returns_str = self.visit_type(module, node.returns) _not_implemented(not node.type_comment, 'FunctionDef.type_comment') @@ -186,6 +184,7 @@ class OurVisitor: if stmt.simple != 1: raise NotImplementedError('Class with non-simple arguments') + raise NotImplementedError('TODO: Broken after new type system') member = TypeStructMember(stmt.target.id, self.visit_type(module, stmt.annotation), offset) struct.members.append(member) @@ -199,74 +198,72 @@ class OurVisitor: if not isinstance(node.target.ctx, ast.Store): _raise_static_error(node, 'Must be load context') - exp_type = self.visit_type(module, node.annotation) - if isinstance(node.value, ast.Constant): return ModuleConstantDef( node.target.id, node.lineno, - exp_type, self.visit_Module_Constant(module, node.value), None, ) - if isinstance(exp_type, TypeTuple): - if not isinstance(node.value, ast.Tuple): - _raise_static_error(node, 'Must be tuple') + raise NotImplementedError('TODO: Broken after new typing system') - if len(exp_type.members) != len(node.value.elts): - _raise_static_error(node, 'Invalid number of tuple values') - - tuple_data = [ - self.visit_Module_Constant(module, arg_node) - for arg_node, mem in zip(node.value.elts, exp_type.members) - if isinstance(arg_node, ast.Constant) - ] - if len(exp_type.members) != len(tuple_data): - _raise_static_error(node, 'Tuple arguments must be constants') - - # Allocate the data - data_block = ModuleDataBlock(tuple_data) - module.data.blocks.append(data_block) - - # Then return the constant as a pointer - return ModuleConstantDef( - node.target.id, - node.lineno, - exp_type, - ConstantTuple(tuple_data), - data_block, - ) - - if isinstance(exp_type, TypeStaticArray): - if not isinstance(node.value, ast.Tuple): - _raise_static_error(node, 'Must be static array') - - if len(exp_type.members) != len(node.value.elts): - _raise_static_error(node, 'Invalid number of static array values') - - static_array_data = [ - self.visit_Module_Constant(module, arg_node) - for arg_node in node.value.elts - if isinstance(arg_node, ast.Constant) - ] - if len(exp_type.members) != len(static_array_data): - _raise_static_error(node, 'Static array arguments must be constants') - - # Allocate the data - data_block = ModuleDataBlock(static_array_data) - module.data.blocks.append(data_block) - - # Then return the constant as a pointer - return ModuleConstantDef( - node.target.id, - node.lineno, - exp_type, - ConstantStaticArray(static_array_data), - data_block, - ) - - raise NotImplementedError(f'{node} on Module AnnAssign') + # if isinstance(exp_type, TypeTuple): + # if not isinstance(node.value, ast.Tuple): + # _raise_static_error(node, 'Must be tuple') + # + # if len(exp_type.members) != len(node.value.elts): + # _raise_static_error(node, 'Invalid number of tuple values') + # + # tuple_data = [ + # self.visit_Module_Constant(module, arg_node) + # for arg_node, mem in zip(node.value.elts, exp_type.members) + # if isinstance(arg_node, ast.Constant) + # ] + # if len(exp_type.members) != len(tuple_data): + # _raise_static_error(node, 'Tuple arguments must be constants') + # + # # Allocate the data + # data_block = ModuleDataBlock(tuple_data) + # module.data.blocks.append(data_block) + # + # # Then return the constant as a pointer + # return ModuleConstantDef( + # node.target.id, + # node.lineno, + # exp_type, + # ConstantTuple(tuple_data), + # data_block, + # ) + # + # if isinstance(exp_type, TypeStaticArray): + # if not isinstance(node.value, ast.Tuple): + # _raise_static_error(node, 'Must be static array') + # + # if len(exp_type.members) != len(node.value.elts): + # _raise_static_error(node, 'Invalid number of static array values') + # + # static_array_data = [ + # self.visit_Module_Constant(module, arg_node) + # for arg_node in node.value.elts + # if isinstance(arg_node, ast.Constant) + # ] + # if len(exp_type.members) != len(static_array_data): + # _raise_static_error(node, 'Static array arguments must be constants') + # + # # Allocate the data + # data_block = ModuleDataBlock(static_array_data) + # module.data.blocks.append(data_block) + # + # # Then return the constant as a pointer + # return ModuleConstantDef( + # node.target.id, + # node.lineno, + # ConstantStaticArray(static_array_data), + # data_block, + # ) + # + # raise NotImplementedError(f'{node} on Module AnnAssign') def visit_Module_stmt(self, module: Module, node: ast.stmt) -> None: if isinstance(node, ast.FunctionDef): @@ -301,12 +298,12 @@ class OurVisitor: _raise_static_error(node, 'Return must have an argument') return StatementReturn( - self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.value) + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.value) ) if isinstance(node, ast.If): result = StatementIf( - self.visit_Module_FunctionDef_expr(module, function, our_locals, function.returns, node.test) + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.test) ) for stmt in node.body: @@ -326,7 +323,7 @@ class OurVisitor: raise NotImplementedError(f'{node} as stmt in FunctionDef') - def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.expr) -> Expression: + def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression: if isinstance(node, ast.BinOp): if isinstance(node.op, ast.Add): operator = '+' @@ -352,8 +349,8 @@ class OurVisitor: return BinaryOp( operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.right), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right), ) if isinstance(node, ast.UnaryOp): @@ -366,7 +363,7 @@ class OurVisitor: return UnaryOp( operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.operand), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.operand), ) if isinstance(node, ast.Compare): @@ -387,12 +384,12 @@ class OurVisitor: return BinaryOp( operator, - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.left), - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.comparators[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]), ) if isinstance(node, ast.Call): - return self.visit_Module_FunctionDef_Call(module, function, our_locals, exp_type, node) + return self.visit_Module_FunctionDef_Call(module, function, our_locals, node) if isinstance(node, ast.Constant): return self.visit_Module_Constant( @@ -426,29 +423,29 @@ class OurVisitor: if isinstance(node, ast.Tuple): raise NotImplementedError('TODO: Broken after new type system') - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if isinstance(exp_type, TypeTuple): - if len(exp_type.members) != len(node.elts): - _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given') - - tuple_constructor = TupleConstructor(exp_type) - - func = module.functions[tuple_constructor.name] - - result = FunctionCall(func) - result.arguments = [ - self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node) - for arg_node, mem in zip(node.elts, exp_type.members) - ] - return result - - _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple') + # if not isinstance(node.ctx, ast.Load): + # _raise_static_error(node, 'Must be load context') + # + # if isinstance(exp_type, TypeTuple): + # if len(exp_type.members) != len(node.elts): + # _raise_static_error(node, f'Expression is expecting a tuple of size {len(exp_type.members)}, but {len(node.elts)} are given') + # + # tuple_constructor = TupleConstructor(exp_type) + # + # func = module.functions[tuple_constructor.name] + # + # result = FunctionCall(func) + # result.arguments = [ + # self.visit_Module_FunctionDef_expr(module, function, our_locals, mem.type, arg_node) + # for arg_node, mem in zip(node.elts, exp_type.members) + # ] + # return result + # + # _raise_static_error(node, f'Expression is expecting a {codestyle.type_(exp_type)}, not a tuple') raise NotImplementedError(f'{node} as expr in FunctionDef') - def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, exp_type: TypeBase, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]: + def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]: if node.keywords: _raise_static_error(node, 'Keyword calling not supported') # Yet? @@ -458,17 +455,18 @@ class OurVisitor: _raise_static_error(node, 'Must be load context') if node.func.id in module.structs: - struct = module.structs[node.func.id] - struct_constructor = StructConstructor(struct) - - func = module.functions[struct_constructor.name] + raise NotImplementedError('TODO: Broken after new type system') + # struct = module.structs[node.func.id] + # struct_constructor = StructConstructor(struct) + # + # func = module.functions[struct_constructor.name] elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') return UnaryOp( 'sqrt', - self.visit_Module_FunctionDef_expr(module, function, our_locals, exp_type, node.args[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]), ) elif node.func.id == 'u32': if 1 != len(node.args): @@ -478,7 +476,7 @@ class OurVisitor: return UnaryOp( 'cast', - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['u8'], node.args[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]), ) elif node.func.id == 'len': if 1 != len(node.args): @@ -486,7 +484,7 @@ class OurVisitor: return UnaryOp( 'len', - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[0]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[0]), ) elif node.func.id == 'foldl': # TODO: This should a much more generic function! @@ -511,20 +509,11 @@ class OurVisitor: raise NotImplementedError('TODO: Broken after new type system') - if exp_type.__class__ != func.returns.__class__: - _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') - - if func.returns.__class__ != func.posonlyargs[0].type.__class__: - _raise_static_error(node, f'Expected a foldable function, {func.name} returns a {codestyle.type_(func.returns)} but expects a {codestyle.type_(func.posonlyargs[0].type)}') - - if module.types['u8'].__class__ != func.posonlyargs[1].type.__class__: - _raise_static_error(node, 'Only folding over bytes (u8) is supported at this time') - return Fold( Fold.Direction.LEFT, func, - self.visit_Module_FunctionDef_expr(module, function, our_locals, func.returns, node.args[1]), - self.visit_Module_FunctionDef_expr(module, function, our_locals, module.types['bytes'], node.args[2]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]), + self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]), ) else: if node.func.id not in module.functions: @@ -532,20 +521,18 @@ class OurVisitor: func = module.functions[node.func.id] - # if func.returns != exp_type: - # _raise_static_error(node, f'Expected {codestyle.type_(exp_type)}, {func.name} actually returns {codestyle.type_(func.returns)}') - if len(func.posonlyargs) != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires {len(func.posonlyargs)} arguments but {len(node.args)} are given') result = FunctionCall(func) result.arguments.extend( - self.visit_Module_FunctionDef_expr(module, function, our_locals, param.type, arg_expr) + self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) for arg_expr, param in zip(node.args, func.posonlyargs) ) return result def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: + raise NotImplementedError('Broken after new type system') del module del function @@ -574,87 +561,89 @@ class OurVisitor: ) def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: - if not isinstance(node.value, ast.Name): - _raise_static_error(node, 'Must reference a name') + raise NotImplementedError('TODO: Broken after new type system') - if not isinstance(node.slice, ast.Index): - _raise_static_error(node, 'Must subscript using an index') - - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - varref: Union[ModuleConstantReference, VariableReference] - if node.value.id in our_locals: - param = our_locals[node.value.id] - node_typ = param.type - varref = VariableReference(param) - elif node.value.id in module.constant_defs: - constant_def = module.constant_defs[node.value.id] - node_typ = constant_def.type - varref = ModuleConstantReference(constant_def) - else: - _raise_static_error(node, f'Undefined variable {node.value.id}') - - slice_expr = self.visit_Module_FunctionDef_expr( - module, function, our_locals, module.types['u32'], node.slice.value, - ) - - if isinstance(node_typ, TypeBytes): - if isinstance(varref, ModuleConstantReference): - raise NotImplementedError(f'{node} from module constant') - - return AccessBytesIndex( - varref, - slice_expr, - ) - - if isinstance(node_typ, TypeTuple): - if not isinstance(slice_expr, ConstantPrimitive): - _raise_static_error(node, 'Must subscript using a constant index') - - idx = slice_expr.value - - if not isinstance(idx, int): - _raise_static_error(node, 'Must subscript using a constant integer index') - - if not (0 <= idx < len(node_typ.members)): - _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') - - tuple_member = node_typ.members[idx] - - if isinstance(varref, ModuleConstantReference): - raise NotImplementedError(f'{node} from module constant') - - return AccessTupleMember( - varref, - tuple_member, - ) - - if isinstance(node_typ, TypeStaticArray): - if not isinstance(slice_expr, ConstantPrimitive): - return AccessStaticArrayMember( - varref, - node_typ, - slice_expr, - ) - - idx = slice_expr.value - - if not isinstance(idx, int): - _raise_static_error(node, 'Must subscript using an integer index') - - if not (0 <= idx < len(node_typ.members)): - _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') - - static_array_member = node_typ.members[idx] - - return AccessStaticArrayMember( - varref, - node_typ, - static_array_member, - ) - - _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') + # if not isinstance(node.value, ast.Name): + # _raise_static_error(node, 'Must reference a name') + # + # if not isinstance(node.slice, ast.Index): + # _raise_static_error(node, 'Must subscript using an index') + # + # if not isinstance(node.ctx, ast.Load): + # _raise_static_error(node, 'Must be load context') + # + # varref: Union[ModuleConstantReference, VariableReference] + # if node.value.id in our_locals: + # param = our_locals[node.value.id] + # node_typ = param.type + # varref = VariableReference(param) + # elif node.value.id in module.constant_defs: + # constant_def = module.constant_defs[node.value.id] + # node_typ = constant_def.type + # varref = ModuleConstantReference(constant_def) + # else: + # _raise_static_error(node, f'Undefined variable {node.value.id}') + # + # slice_expr = self.visit_Module_FunctionDef_expr( + # module, function, our_locals, node.slice.value, + # ) + # + # if isinstance(node_typ, TypeBytes): + # if isinstance(varref, ModuleConstantReference): + # raise NotImplementedError(f'{node} from module constant') + # + # return AccessBytesIndex( + # varref, + # slice_expr, + # ) + # + # if isinstance(node_typ, TypeTuple): + # if not isinstance(slice_expr, ConstantPrimitive): + # _raise_static_error(node, 'Must subscript using a constant index') + # + # idx = slice_expr.value + # + # if not isinstance(idx, int): + # _raise_static_error(node, 'Must subscript using a constant integer index') + # + # if not (0 <= idx < len(node_typ.members)): + # _raise_static_error(node, f'Index {idx} out of bounds for tuple {node.value.id}') + # + # tuple_member = node_typ.members[idx] + # + # if isinstance(varref, ModuleConstantReference): + # raise NotImplementedError(f'{node} from module constant') + # + # return AccessTupleMember( + # varref, + # tuple_member, + # ) + # + # if isinstance(node_typ, TypeStaticArray): + # if not isinstance(slice_expr, ConstantPrimitive): + # return AccessStaticArrayMember( + # varref, + # node_typ, + # slice_expr, + # ) + # + # idx = slice_expr.value + # + # if not isinstance(idx, int): + # _raise_static_error(node, 'Must subscript using an integer index') + # + # if not (0 <= idx < len(node_typ.members)): + # _raise_static_error(node, f'Index {idx} out of bounds for static array {node.value.id}') + # + # static_array_member = node_typ.members[idx] + # + # return AccessStaticArrayMember( + # varref, + # node_typ, + # static_array_member, + # ) + # + # _raise_static_error(node, f'Cannot take index of {node_typ} {node.value.id}') def visit_Module_Constant(self, module: Module, node: ast.Constant) -> ConstantPrimitive: del module @@ -666,10 +655,10 @@ class OurVisitor: raise NotImplementedError(f'{node.value} as constant') - def visit_type(self, module: Module, node: ast.expr) -> TypeBase: + def visit_type(self, module: Module, node: ast.expr) -> str: if isinstance(node, ast.Constant): if node.value is None: - return module.types['None'] + return 'None' _raise_static_error(node, f'Unrecognized type {node.value}') @@ -677,8 +666,10 @@ class OurVisitor: if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.id in module.types: - return module.types[node.id] + if node.id in ('u8', 'u32', 'u64', 'i32', 'i64', 'f32', 'f64'): # FIXME: Source this list somewhere + return node.id + + raise NotImplementedError('TODO: Broken after type system') if node.id in module.structs: return module.structs[node.id] @@ -686,61 +677,65 @@ class OurVisitor: _raise_static_error(node, f'Unrecognized type {node.id}') if isinstance(node, ast.Subscript): - if not isinstance(node.value, ast.Name): - _raise_static_error(node, 'Must be name') - if not isinstance(node.slice, ast.Index): - _raise_static_error(node, 'Must subscript using an index') - if not isinstance(node.slice.value, ast.Constant): - _raise_static_error(node, 'Must subscript using a constant index') - if not isinstance(node.slice.value.value, int): - _raise_static_error(node, 'Must subscript using a constant integer index') - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') + raise NotImplementedError('TODO: Broken after new type system') - if node.value.id in module.types: - member_type = module.types[node.value.id] - else: - _raise_static_error(node, f'Unrecognized type {node.value.id}') - - type_static_array = TypeStaticArray(member_type) - - offset = 0 - - for idx in range(node.slice.value.value): - static_array_member = TypeStaticArrayMember(idx, offset) - - type_static_array.members.append(static_array_member) - offset += member_type.alloc_size() - - key = f'{node.value.id}[{node.slice.value.value}]' - - if key not in module.types: - module.types[key] = type_static_array - - return module.types[key] + # if not isinstance(node.value, ast.Name): + # _raise_static_error(node, 'Must be name') + # if not isinstance(node.slice, ast.Index): + # _raise_static_error(node, 'Must subscript using an index') + # if not isinstance(node.slice.value, ast.Constant): + # _raise_static_error(node, 'Must subscript using a constant index') + # if not isinstance(node.slice.value.value, int): + # _raise_static_error(node, 'Must subscript using a constant integer index') + # if not isinstance(node.ctx, ast.Load): + # _raise_static_error(node, 'Must be load context') + # + # if node.value.id in module.types: + # member_type = module.types[node.value.id] + # else: + # _raise_static_error(node, f'Unrecognized type {node.value.id}') + # + # type_static_array = TypeStaticArray(member_type) + # + # offset = 0 + # + # for idx in range(node.slice.value.value): + # static_array_member = TypeStaticArrayMember(idx, offset) + # + # type_static_array.members.append(static_array_member) + # offset += member_type.alloc_size() + # + # key = f'{node.value.id}[{node.slice.value.value}]' + # + # if key not in module.types: + # module.types[key] = type_static_array + # + # return module.types[key] if isinstance(node, ast.Tuple): - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') + raise NotImplementedError('TODO: Broken after new type system') - type_tuple = TypeTuple() - - offset = 0 - - for idx, elt in enumerate(node.elts): - tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset) - - type_tuple.members.append(tuple_member) - offset += tuple_member.type.alloc_size() - - key = type_tuple.render_internal_name() - - if key not in module.types: - module.types[key] = type_tuple - constructor = TupleConstructor(type_tuple) - module.functions[constructor.name] = constructor - - return module.types[key] + # if not isinstance(node.ctx, ast.Load): + # _raise_static_error(node, 'Must be load context') + # + # type_tuple = TypeTuple() + # + # offset = 0 + # + # for idx, elt in enumerate(node.elts): + # tuple_member = TypeTupleMember(idx, self.visit_type(module, elt), offset) + # + # type_tuple.members.append(tuple_member) + # offset += tuple_member.type.alloc_size() + # + # key = type_tuple.render_internal_name() + # + # if key not in module.types: + # module.types[key] = type_tuple + # constructor = TupleConstructor(type_tuple) + # module.functions[constructor.name] = constructor + # + # return module.types[key] raise NotImplementedError(f'{node} as type') diff --git a/phasm/typer.py b/phasm/typer.py index 1230d01..3041ec3 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -3,12 +3,12 @@ Type checks and enriches the given ast """ from . import ourlang -from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar +from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar, from_str def phasm_type(inp: ourlang.Module) -> None: module(inp) -def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': +def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar: if isinstance(inp, ourlang.ConstantPrimitive): result = ctx.new_var() @@ -57,7 +57,7 @@ def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': raise NotImplementedError(constant, inp) -def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': +def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.Constant): return constant(ctx, inp) @@ -109,7 +109,7 @@ def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': raise NotImplementedError(expression, inp) -def function(ctx: 'Context', inp: ourlang.Function) -> None: +def function(ctx: Context, inp: ourlang.Function) -> None: if len(inp.statements) != 1 or not isinstance(inp.statements[0], ourlang.StatementReturn): raise NotImplementedError('Functions with not just a return statement') typ = expression(ctx, inp.statements[0].value) @@ -117,10 +117,11 @@ def function(ctx: 'Context', inp: ourlang.Function) -> None: assert inp.returns_type_var is not None ctx.unify(inp.returns_type_var, typ) -def module_constant_def(ctx: 'Context', inp: ourlang.ModuleConstantDef) -> None: - inp.type_var = _convert_old_type(ctx, inp.type, inp.name) +def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> None: constant(ctx, inp.constant) + inp.type_var = ctx.new_var() + assert inp.constant.type_var is not None ctx.unify(inp.type_var, inp.constant.type_var) @@ -128,66 +129,12 @@ def module(inp: ourlang.Module) -> None: ctx = Context() for func in inp.functions.values(): - func.returns_type_var = _convert_old_type(ctx, func.returns, f'{func.name}.(returns)') + func.returns_type_var = from_str(ctx, func.returns_str, f'{func.name}.(returns)') for param in func.posonlyargs: - param.type_var = _convert_old_type(ctx, param.type, f'{func.name}.{param.name}') + param.type_var = from_str(ctx, param.type_str, f'{func.name}.{param.name}') for cdef in inp.constant_defs.values(): module_constant_def(ctx, cdef) for func in inp.functions.values(): function(ctx, func) - -from . import typing - -def _convert_old_type(ctx: Context, inp: typing.TypeBase, location: str) -> TypeVar: - result = ctx.new_var() - - if isinstance(inp, typing.TypeUInt8): - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) - result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) - return result - - if isinstance(inp, typing.TypeUInt32): - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) - return result - - if isinstance(inp, typing.TypeUInt64): - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) - return result - - if isinstance(inp, typing.TypeInt32): - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_constraint(TypeConstraintSigned(True)) - result.add_location(location) - return result - - if isinstance(inp, typing.TypeInt64): - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_constraint(TypeConstraintSigned(True)) - result.add_location(location) - return result - - if isinstance(inp, typing.TypeFloat32): - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_location(location) - return result - - if isinstance(inp, typing.TypeFloat64): - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) - result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_location(location) - return result - - raise NotImplementedError(_convert_old_type, inp) diff --git a/phasm/typing.py b/phasm/typing.py index 94bb602..de92a7b 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -19,88 +19,6 @@ class TypeBase: """ raise NotImplementedError(self, 'alloc_size') -class TypeNone(TypeBase): - """ - The None (or Void) type - """ - __slots__ = () - -class TypeBool(TypeBase): - """ - The boolean type - """ - __slots__ = () - -class TypeUInt8(TypeBase): - """ - The Integer type, unsigned and 8 bits wide - - Note that under the hood we need to use i32 to represent - these values in expressions. So we need to add some operations - to make sure the math checks out. - - So while this does save bytes in memory, it may not actually - speed up or improve your code. - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 # Int32 under the hood - -class TypeUInt32(TypeBase): - """ - The Integer type, unsigned and 32 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 - -class TypeUInt64(TypeBase): - """ - The Integer type, unsigned and 64 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 8 - -class TypeInt32(TypeBase): - """ - The Integer type, signed and 32 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 - -class TypeInt64(TypeBase): - """ - The Integer type, signed and 64 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 8 - -class TypeFloat32(TypeBase): - """ - The Float type, 32 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 4 - -class TypeFloat64(TypeBase): - """ - The Float type, 64 bits wide - """ - __slots__ = () - - def alloc_size(self) -> int: - return 8 - class TypeBytes(TypeBase): """ The bytes type @@ -207,7 +125,7 @@ class TypeStruct(TypeBase): ## NEW STUFF BELOW -# This error can also mean that the type somewhere forgot to write a type +# This error can also mean that the typer somewhere forgot to write a type # back to the AST. If so, we need to fix the typer. ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method' @@ -392,7 +310,12 @@ class Context: return result - def unify(self, l: 'TypeVar', r: 'TypeVar') -> None: + def unify(self, l: Optional[TypeVar], r: Optional[TypeVar]) -> None: + # FIXME: Write method doc, find out why pylint doesn't error + + assert l is not None, ASSERTION_ERROR + assert r is not None, ASSERTION_ERROR + assert l.ctx_id != r.ctx_id # Dunno if this'll happen, if so, just return # Backup some values that we'll overwrite @@ -438,6 +361,8 @@ def simplify(inp: TypeVar) -> Optional[str]: """ Simplifies a TypeVar into a string that wasm can work with and users can recognize + + Should round trip with from_str """ tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive) tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth) @@ -473,3 +398,66 @@ def simplify(inp: TypeVar) -> Optional[str]: return f'f{tc_bits.minb}' return None + +def from_str(ctx: Context, inp: str, location: str) -> TypeVar: + """ + Creates a new TypeVar from the string + + Should round trip with simplify + + The location is a reference to where you found the string + in the source code. + + This could be conidered part of parsing. Though that would give trouble + with the context creation. + """ + result = ctx.new_var() + + if inp == 'u8': + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location(location) + return result + + if inp == 'u32': + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location(location) + return result + + if inp == 'u64': + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location(location) + return result + + if inp == 'i32': + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location(location) + return result + + if inp == 'i64': + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location(location) + return result + + if inp == 'f32': + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_location(location) + return result + + if inp == 'f64': + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_location(location) + return result + + raise NotImplementedError(from_str, inp) -- 2.49.0 From 906b15c93c58d9197b5b8671f1ee8949f25798b8 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 11:16:34 +0200 Subject: [PATCH 11/18] Large cleanup to the tests They are now better organized and easier to extend, I hope. --- tests/integration/constants.py | 16 + tests/integration/test_examples/__init__.py | 0 .../test_crc32.py} | 4 +- .../{ => test_examples}/test_fib.py | 2 +- tests/integration/test_helper.py | 70 -- tests/integration/test_lang/__init__.py | 0 .../{ => test_lang}/test_builtins.py | 4 +- tests/integration/test_lang/test_bytes.py | 53 ++ tests/integration/test_lang/test_if.py | 71 ++ tests/integration/test_lang/test_interface.py | 27 + .../integration/test_lang/test_primitives.py | 324 +++++++++ .../test_lang/test_static_array.py | 87 +++ .../test_struct.py} | 105 +-- .../test_tuple.py} | 73 +- .../{ => test_lang}/test_type_checks.py | 0 tests/integration/test_runtime_checks.py | 31 - tests/integration/test_simple.py | 644 ------------------ tests/integration/test_stdlib/__init__.py | 0 .../test_alloc.py} | 4 +- 19 files changed, 668 insertions(+), 847 deletions(-) create mode 100644 tests/integration/constants.py create mode 100644 tests/integration/test_examples/__init__.py rename tests/integration/{test_examples.py => test_examples/test_crc32.py} (93%) rename tests/integration/{ => test_examples}/test_fib.py (94%) delete mode 100644 tests/integration/test_helper.py create mode 100644 tests/integration/test_lang/__init__.py rename tests/integration/{ => test_lang}/test_builtins.py (95%) create mode 100644 tests/integration/test_lang/test_bytes.py create mode 100644 tests/integration/test_lang/test_if.py create mode 100644 tests/integration/test_lang/test_interface.py create mode 100644 tests/integration/test_lang/test_primitives.py create mode 100644 tests/integration/test_lang/test_static_array.py rename tests/integration/{test_static_checking.py => test_lang/test_struct.py} (54%) rename tests/integration/{test_constants.py => test_lang/test_tuple.py} (56%) rename tests/integration/{ => test_lang}/test_type_checks.py (100%) delete mode 100644 tests/integration/test_runtime_checks.py delete mode 100644 tests/integration/test_simple.py create mode 100644 tests/integration/test_stdlib/__init__.py rename tests/integration/{test_stdlib_alloc.py => test_stdlib/test_alloc.py} (95%) diff --git a/tests/integration/constants.py b/tests/integration/constants.py new file mode 100644 index 0000000..07dacbe --- /dev/null +++ b/tests/integration/constants.py @@ -0,0 +1,16 @@ +""" +Constants for use in the tests +""" + +ALL_INT_TYPES = ['u8', 'u32', 'u64', 'i32', 'i64'] +COMPLETE_INT_TYPES = ['u32', 'u64', 'i32', 'i64'] + +ALL_FLOAT_TYPES = ['f32', 'f64'] +COMPLETE_FLOAT_TYPES = ALL_FLOAT_TYPES + +TYPE_MAP = { + **{x: int for x in ALL_INT_TYPES}, + **{x: float for x in ALL_FLOAT_TYPES}, +} + +COMPLETE_PRIMITIVE_TYPES = COMPLETE_INT_TYPES + COMPLETE_FLOAT_TYPES diff --git a/tests/integration/test_examples/__init__.py b/tests/integration/test_examples/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples/test_crc32.py similarity index 93% rename from tests/integration/test_examples.py rename to tests/integration/test_examples/test_crc32.py index b3b278d..d9f74bf 100644 --- a/tests/integration/test_examples.py +++ b/tests/integration/test_examples/test_crc32.py @@ -3,9 +3,9 @@ import struct import pytest -from .helpers import Suite +from ..helpers import Suite -@pytest.mark.integration_test +@pytest.mark.slow_integration_test def test_crc32(): # FIXME: Stub # crc = 0xFFFFFFFF diff --git a/tests/integration/test_fib.py b/tests/integration/test_examples/test_fib.py similarity index 94% rename from tests/integration/test_fib.py rename to tests/integration/test_examples/test_fib.py index 20e7e63..3f99a46 100644 --- a/tests/integration/test_fib.py +++ b/tests/integration/test_examples/test_fib.py @@ -1,6 +1,6 @@ import pytest -from .helpers import Suite +from ..helpers import Suite @pytest.mark.slow_integration_test def test_fib(): diff --git a/tests/integration/test_helper.py b/tests/integration/test_helper.py deleted file mode 100644 index cb44021..0000000 --- a/tests/integration/test_helper.py +++ /dev/null @@ -1,70 +0,0 @@ -import io - -import pytest - -from pywasm import binary -from pywasm import Runtime - -from wasmer import wat2wasm - -def run(code_wat): - code_wasm = wat2wasm(code_wat) - module = binary.Module.from_reader(io.BytesIO(code_wasm)) - - runtime = Runtime(module, {}, {}) - - out_put = runtime.exec('testEntry', []) - return (runtime, out_put) - -@pytest.mark.parametrize('size,offset,exp_out_put', [ - ('32', 0, 0x3020100), - ('32', 1, 0x4030201), - ('64', 0, 0x706050403020100), - ('64', 2, 0x908070605040302), -]) -def test_i32_64_load(size, offset, exp_out_put): - code_wat = f""" - (module - (memory 1) - (data (memory 0) (i32.const 0) "\\00\\01\\02\\03\\04\\05\\06\\07\\08\\09\\10") - - (func (export "testEntry") (result i{size}) - i32.const {offset} - i{size}.load - return )) -""" - - (_, out_put) = run(code_wat) - assert exp_out_put == out_put - -def test_load_then_store(): - code_wat = """ - (module - (memory 1) - (data (memory 0) (i32.const 0) "\\04\\00\\00\\00") - - (func (export "testEntry") (result i32) (local $my_memory_value i32) - ;; Load i32 from address 0 - i32.const 0 - i32.load - - ;; Add 8 to the loaded value - i32.const 8 - i32.add - - local.set $my_memory_value - - ;; Store back to the memory - i32.const 0 - local.get $my_memory_value - i32.store - - ;; Return something - i32.const 9 - return )) -""" - (runtime, out_put) = run(code_wat) - - assert 9 == out_put - - assert (b'\x0c'+ b'\00' * 23) == runtime.store.mems[0].data[:24] diff --git a/tests/integration/test_lang/__init__.py b/tests/integration/test_lang/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_builtins.py b/tests/integration/test_lang/test_builtins.py similarity index 95% rename from tests/integration/test_builtins.py rename to tests/integration/test_lang/test_builtins.py index 4b84197..2b06afd 100644 --- a/tests/integration/test_builtins.py +++ b/tests/integration/test_lang/test_builtins.py @@ -2,8 +2,8 @@ import sys import pytest -from .helpers import Suite, write_header -from .runners import RunnerPywasm +from ..helpers import Suite, write_header +from ..runners import RunnerPywasm def setup_interpreter(phash_code: str) -> RunnerPywasm: runner = RunnerPywasm(phash_code) diff --git a/tests/integration/test_lang/test_bytes.py b/tests/integration/test_lang/test_bytes.py new file mode 100644 index 0000000..45fc2a1 --- /dev/null +++ b/tests/integration/test_lang/test_bytes.py @@ -0,0 +1,53 @@ +import pytest + +from ..helpers import Suite + +@pytest.mark.integration_test +def test_bytes_address(): + code_py = """ +@exported +def testEntry(f: bytes) -> bytes: + return f +""" + + result = Suite(code_py).run_code(b'This is a test') + + # THIS DEPENDS ON THE ALLOCATOR + # A different allocator will return a different value + assert 20 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_length(): + code_py = """ +@exported +def testEntry(f: bytes) -> i32: + return len(f) +""" + + result = Suite(code_py).run_code(b'This is another test') + + assert 20 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_index(): + code_py = """ +@exported +def testEntry(f: bytes) -> u8: + return f[8] +""" + + result = Suite(code_py).run_code(b'This is another test') + + assert 0x61 == result.returned_value + +@pytest.mark.integration_test +def test_bytes_index_out_of_bounds(): + code_py = """ +@exported +def testEntry(f: bytes) -> u8: + return f[50] +""" + + result = Suite(code_py).run_code(b'Short', b'Long' * 100) + + assert 0 == result.returned_value diff --git a/tests/integration/test_lang/test_if.py b/tests/integration/test_lang/test_if.py new file mode 100644 index 0000000..5d77eb9 --- /dev/null +++ b/tests/integration/test_lang/test_if.py @@ -0,0 +1,71 @@ +import pytest + +from ..helpers import Suite + +@pytest.mark.integration_test +@pytest.mark.parametrize('inp', [9, 10, 11, 12]) +def test_if_simple(inp): + code_py = """ +@exported +def testEntry(a: i32) -> i32: + if a > 10: + return 15 + + return 3 +""" + exp_result = 15 if inp > 10 else 3 + + suite = Suite(code_py) + + result = suite.run_code(inp) + assert exp_result == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.skip('Such a return is not how things should be') +def test_if_complex(): + code_py = """ +@exported +def testEntry(a: i32) -> i32: + if a > 10: + return 10 + elif a > 0: + return a + else: + return 0 + + return -1 # Required due to function type +""" + + suite = Suite(code_py) + + assert 10 == suite.run_code(20).returned_value + assert 10 == suite.run_code(10).returned_value + + assert 8 == suite.run_code(8).returned_value + + assert 0 == suite.run_code(0).returned_value + assert 0 == suite.run_code(-1).returned_value + +@pytest.mark.integration_test +def test_if_nested(): + code_py = """ +@exported +def testEntry(a: i32, b: i32) -> i32: + if a > 11: + if b > 11: + return 3 + + return 2 + + if b > 11: + return 1 + + return 0 +""" + + suite = Suite(code_py) + + assert 3 == suite.run_code(20, 20).returned_value + assert 2 == suite.run_code(20, 10).returned_value + assert 1 == suite.run_code(10, 20).returned_value + assert 0 == suite.run_code(10, 10).returned_value diff --git a/tests/integration/test_lang/test_interface.py b/tests/integration/test_lang/test_interface.py new file mode 100644 index 0000000..cbacd73 --- /dev/null +++ b/tests/integration/test_lang/test_interface.py @@ -0,0 +1,27 @@ +import pytest + +from ..helpers import Suite + +@pytest.mark.integration_test +def test_imported(): + code_py = """ +@imported +def helper(mul: i32) -> i32: + pass + +@exported +def testEntry() -> i32: + return helper(2) +""" + + def helper(mul: int) -> int: + return 4238 * mul + + result = Suite(code_py).run_code( + runtime='wasmer', + imports={ + 'helper': helper, + } + ) + + assert 8476 == result.returned_value diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py new file mode 100644 index 0000000..d0eeb8f --- /dev/null +++ b/tests/integration/test_lang/test_primitives.py @@ -0,0 +1,324 @@ +import pytest + +from ..helpers import Suite +from ..constants import ALL_INT_TYPES, ALL_FLOAT_TYPES, COMPLETE_INT_TYPES, TYPE_MAP + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_expr_constant_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 13 +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_expr_constant_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 32.125 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_INT_TYPES) +def test_module_constant_int(type_): + code_py = f""" +CONSTANT: {type_} = 13 + +@exported +def testEntry() -> {type_}: + return CONSTANT +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_module_constant_float(type_): + code_py = f""" +CONSTANT: {type_} = 32.125 + +@exported +def testEntry() -> {type_}: + return CONSTANT +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation +def test_logical_left_shift(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 << 3 +""" + + result = Suite(code_py).run_code() + + assert 80 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_logical_right_shift(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 >> 3 +""" + + result = Suite(code_py).run_code() + + assert 1 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_bitwise_or(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 | 3 +""" + + result = Suite(code_py).run_code() + + assert 11 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_bitwise_xor(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 ^ 3 +""" + + result = Suite(code_py).run_code() + + assert 9 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) +def test_bitwise_and(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 & 3 +""" + + result = Suite(code_py).run_code() + + assert 2 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_addition_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 + 3 +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_addition_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 32.0 + 0.125 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_subtraction_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 10 - 3 +""" + + result = Suite(code_py).run_code() + + assert 7 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_subtraction_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return 100.0 - 67.875 +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +# TODO: Multiplication +# TODO: Division + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ['f32', 'f64']) +def test_buildins_sqrt(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return sqrt(25.0) +""" + + result = Suite(code_py).run_code() + + assert 5 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', TYPE_MAP.keys()) +def test_function_argument(type_): + code_py = f""" +@exported +def testEntry(a: {type_}) -> {type_}: + return a +""" + + result = Suite(code_py).run_code(125) + + assert 125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.skip('TODO') +def test_explicit_positive_number(): + code_py = """ +@exported +def testEntry() -> i32: + return +523 +""" + + result = Suite(code_py).run_code() + + assert 523 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.skip('TODO') +def test_explicit_negative_number(): + code_py = """ +@exported +def testEntry() -> i32: + return -19 +""" + + result = Suite(code_py).run_code() + + assert -19 == result.returned_value + +@pytest.mark.integration_test +def test_call_no_args(): + code_py = """ +def helper() -> i32: + return 19 + +@exported +def testEntry() -> i32: + return helper() +""" + + result = Suite(code_py).run_code() + + assert 19 == result.returned_value + +@pytest.mark.integration_test +def test_call_pre_defined(): + code_py = """ +def helper(left: i32, right: i32) -> i32: + return left + right + +@exported +def testEntry() -> i32: + return helper(10, 3) +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + +@pytest.mark.integration_test +def test_call_post_defined(): + code_py = """ +@exported +def testEntry() -> i32: + return helper(10, 3) + +def helper(left: i32, right: i32) -> i32: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 7 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_call_with_expression_int(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return helper(10 + 20, 3 + 5) + +def helper(left: {type_}, right: {type_}) -> {type_}: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 22 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_call_with_expression_float(type_): + code_py = f""" +@exported +def testEntry() -> {type_}: + return helper(10.078125 + 90.046875, 63.0 + 5.0) + +def helper(left: {type_}, right: {type_}) -> {type_}: + return left - right +""" + + result = Suite(code_py).run_code() + + assert 32.125 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py new file mode 100644 index 0000000..ced9277 --- /dev/null +++ b/tests/integration/test_lang/test_static_array.py @@ -0,0 +1,87 @@ +import pytest + +from phasm.exceptions import StaticError + +from ..constants import COMPLETE_PRIMITIVE_TYPES, TYPE_MAP +from ..helpers import Suite + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) +def test_static_array_module_constant(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24, 57, 80, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT) + +def helper(array: {type_}[3]) -> {type_}: + return array[0] + array[1] + array[2] +""" + + result = Suite(code_py).run_code() + + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) +def test_static_array_indexed(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24, 57, 80, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT, 0, 1, 2) + +def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}: + return array[i0] + array[i1] + array[i2] +""" + + result = Suite(code_py).run_code() + + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_static_array_constant_too_few_values(): + code_py = """ +CONSTANT: u8[3] = (24, 57, ) +""" + + with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): + phasm_parse(code_py) + +@pytest.mark.integration_test +def test_static_array_constant_too_many_values(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 1, 1, ) +""" + + with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): + phasm_parse(code_py) + +@pytest.mark.integration_test +def test_static_array_constant_type_mismatch(): + code_py = """ +CONSTANT: u8[3] = (24, 4000, 1, ) +""" + + with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): + phasm_parse(code_py) + +@pytest.mark.integration_test +def test_static_array_index_out_of_bounds(): + code_py = """ +CONSTANT0: u32[3] = (24, 57, 80, ) + +CONSTANT1: u32[16] = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ) + +@exported +def testEntry() -> u32: + return CONSTANT0[16] +""" + + result = Suite(code_py).run_code() + + assert 0 == result.returned_value diff --git a/tests/integration/test_static_checking.py b/tests/integration/test_lang/test_struct.py similarity index 54% rename from tests/integration/test_static_checking.py rename to tests/integration/test_lang/test_struct.py index 1544537..04846e7 100644 --- a/tests/integration/test_static_checking.py +++ b/tests/integration/test_lang/test_struct.py @@ -1,18 +1,65 @@ import pytest -from phasm.parser import phasm_parse -from phasm.exceptions import StaticError +from ..helpers import Suite @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) -def test_type_mismatch_function_argument(type_): +@pytest.mark.parametrize('type_', ('i32', 'f64', )) +def test_struct_0(type_): code_py = f""" -def helper(a: {type_}) -> (i32, i32, ): - return a +class CheckedValue: + value: {type_} + +@exported +def testEntry() -> {type_}: + return helper(CheckedValue(23)) + +def helper(cv: CheckedValue) -> {type_}: + return cv.value """ - with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), a is actually {type_}'): - phasm_parse(code_py) + result = Suite(code_py).run_code() + + assert 23 == result.returned_value + +@pytest.mark.integration_test +def test_struct_1(): + code_py = """ +class Rectangle: + height: i32 + width: i32 + border: i32 + +@exported +def testEntry() -> i32: + return helper(Rectangle(100, 150, 2)) + +def helper(shape: Rectangle) -> i32: + return shape.height + shape.width + shape.border +""" + + result = Suite(code_py).run_code() + + assert 252 == result.returned_value + +@pytest.mark.integration_test +def test_struct_2(): + code_py = """ +class Rectangle: + height: i32 + width: i32 + border: i32 + +@exported +def testEntry() -> i32: + return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3)) + +def helper(shape1: Rectangle, shape2: Rectangle) -> i32: + return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border +""" + + result = Suite(code_py).run_code() + + assert 545 == result.returned_value @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) @@ -39,21 +86,6 @@ def testEntry(arg: ({type_}, )) -> (i32, i32, ): with pytest.raises(StaticError, match=f'Static error on line 3: Expected \\(i32, i32, \\), arg\\[0\\] is actually {type_}'): phasm_parse(code_py) -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) -def test_type_mismatch_function_result(type_): - code_py = f""" -def helper() -> {type_}: - return 1 - -@exported -def testEntry() -> (i32, i32, ): - return helper() -""" - - with pytest.raises(StaticError, match=f'Static error on line 7: Expected \\(i32, i32, \\), helper actually returns {type_}'): - phasm_parse(code_py) - @pytest.mark.integration_test def test_tuple_constant_too_few_values(): code_py = """ @@ -80,30 +112,3 @@ CONSTANT: (u32, u8, u8, ) = (24, 4000, 1, ) with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): phasm_parse(code_py) - -@pytest.mark.integration_test -def test_static_array_constant_too_few_values(): - code_py = """ -CONSTANT: u8[3] = (24, 57, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_static_array_constant_too_many_values(): - code_py = """ -CONSTANT: u8[3] = (24, 57, 1, 1, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): - phasm_parse(code_py) - -@pytest.mark.integration_test -def test_static_array_constant_type_mismatch(): - code_py = """ -CONSTANT: u8[3] = (24, 4000, 1, ) -""" - - with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): - phasm_parse(code_py) diff --git a/tests/integration/test_constants.py b/tests/integration/test_lang/test_tuple.py similarity index 56% rename from tests/integration/test_constants.py rename to tests/integration/test_lang/test_tuple.py index accf9e2..5c5321e 100644 --- a/tests/integration/test_constants.py +++ b/tests/integration/test_lang/test_tuple.py @@ -1,34 +1,7 @@ import pytest -from .helpers import Suite - -@pytest.mark.integration_test -def test_i32_asis(): - code_py = """ -CONSTANT: i32 = 13 - -@exported -def testEntry() -> i32: - return CONSTANT -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - -@pytest.mark.integration_test -def test_i32_binop(): - code_py = """ -CONSTANT: i32 = 13 - -@exported -def testEntry() -> i32: - return CONSTANT * 5 -""" - - result = Suite(code_py).run_code() - - assert 65 == result.returned_value +from ..constants import COMPLETE_PRIMITIVE_TYPES, TYPE_MAP +from ..helpers import Suite @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ]) @@ -66,36 +39,46 @@ def helper(vector: (u8, u8, u32, u32, u64, u64, )) -> u32: assert 3333 == result.returned_value @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64', ]) -def test_static_array_1(type_): +@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) +def test_tuple_simple_constructor(type_): code_py = f""" -CONSTANT: {type_}[1] = (65, ) - @exported def testEntry() -> {type_}: - return helper(CONSTANT) + return helper((24, 57, 80, )) -def helper(vector: {type_}[1]) -> {type_}: - return vector[0] +def helper(vector: ({type_}, {type_}, {type_}, )) -> {type_}: + return vector[0] + vector[1] + vector[2] """ result = Suite(code_py).run_code() - assert 65 == result.returned_value + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -def test_static_array_6(): +def test_tuple_float(): code_py = """ -CONSTANT: u32[6] = (11, 22, 3333, 4444, 555555, 666666, ) - @exported -def testEntry() -> u32: - return helper(CONSTANT) +def testEntry() -> f32: + return helper((1.0, 2.0, 3.0, )) -def helper(vector: u32[6]) -> u32: - return vector[2] +def helper(v: (f32, f32, f32, )) -> f32: + return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]) """ result = Suite(code_py).run_code() - assert 3333 == result.returned_value + assert 3.74 < result.returned_value < 3.75 + +@pytest.mark.integration_test +@pytest.mark.skip('SIMD support is but a dream') +def test_tuple_i32x4(): + code_py = """ +@exported +def testEntry() -> i32x4: + return (51, 153, 204, 0, ) +""" + + result = Suite(code_py).run_code() + + assert (1, 2, 3, 0) == result.returned_value diff --git a/tests/integration/test_type_checks.py b/tests/integration/test_lang/test_type_checks.py similarity index 100% rename from tests/integration/test_type_checks.py rename to tests/integration/test_lang/test_type_checks.py diff --git a/tests/integration/test_runtime_checks.py b/tests/integration/test_runtime_checks.py deleted file mode 100644 index 97d6542..0000000 --- a/tests/integration/test_runtime_checks.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest - -from .helpers import Suite - -@pytest.mark.integration_test -def test_bytes_index_out_of_bounds(): - code_py = """ -@exported -def testEntry(f: bytes) -> u8: - return f[50] -""" - - result = Suite(code_py).run_code(b'Short', b'Long' * 100) - - assert 0 == result.returned_value - -@pytest.mark.integration_test -def test_static_array_index_out_of_bounds(): - code_py = """ -CONSTANT0: u32[3] = (24, 57, 80, ) - -CONSTANT1: u32[16] = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, ) - -@exported -def testEntry() -> u32: - return CONSTANT0[16] -""" - - result = Suite(code_py).run_code() - - assert 0 == result.returned_value diff --git a/tests/integration/test_simple.py b/tests/integration/test_simple.py deleted file mode 100644 index fdad1e4..0000000 --- a/tests/integration/test_simple.py +++ /dev/null @@ -1,644 +0,0 @@ -import pytest - -from .helpers import Suite - -ALL_INT_TYPES = ['u8', 'u32', 'u64', 'i32', 'i64'] -COMLETE_INT_TYPES = ['u32', 'u64', 'i32', 'i64'] -ALL_FLOAT_TYPES = ['f32', 'f64'] - -TYPE_MAP = { - **{x: int for x in ALL_INT_TYPES}, - **{x: float for x in ALL_FLOAT_TYPES}, -} - -COMPLETE_SIMPLE_TYPES = [ - 'u32', 'u64', - 'i32', 'i64', - 'f32', 'f64', -] - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ALL_INT_TYPES) -def test_return_int(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 13 -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) -def test_return_float(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 32.125 -""" - - result = Suite(code_py).run_code() - - assert 32.125 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMLETE_INT_TYPES) -def test_addition_int(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 + 3 -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) -def test_addition_float(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 32.0 + 0.125 -""" - - result = Suite(code_py).run_code() - - assert 32.125 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMLETE_INT_TYPES) -def test_subtraction_int(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 - 3 -""" - - result = Suite(code_py).run_code() - - assert 7 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) -def test_subtraction(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 100.0 - 67.875 -""" - - result = Suite(code_py).run_code() - - assert 32.125 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation -def test_logical_left_shift(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 << 3 -""" - - result = Suite(code_py).run_code() - - assert 80 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_logical_right_shift(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 >> 3 -""" - - result = Suite(code_py).run_code() - - assert 1 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_or(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 | 3 -""" - - result = Suite(code_py).run_code() - - assert 11 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_xor(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 ^ 3 -""" - - result = Suite(code_py).run_code() - - assert 9 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_bitwise_and(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return 10 & 3 -""" - - result = Suite(code_py).run_code() - - assert 2 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['f32', 'f64']) -def test_buildins_sqrt(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return sqrt(25.0) -""" - - result = Suite(code_py).run_code() - - assert 5 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', TYPE_MAP.keys()) -def test_arg(type_): - code_py = f""" -@exported -def testEntry(a: {type_}) -> {type_}: - return a -""" - - result = Suite(code_py).run_code(125) - - assert 125 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_i32_to_i64(): - code_py = """ -@exported -def testEntry(a: i32) -> i64: - return a -""" - - result = Suite(code_py).run_code(125) - - assert 125 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_i32_plus_i64(): - code_py = """ -@exported -def testEntry(a: i32, b: i64) -> i64: - return a + b -""" - - result = Suite(code_py).run_code(125, 100) - - assert 225 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_f32_to_f64(): - code_py = """ -@exported -def testEntry(a: f32) -> f64: - return a -""" - - result = Suite(code_py).run_code(125.5) - - assert 125.5 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Do we want it to work like this?') -def test_f32_plus_f64(): - code_py = """ -@exported -def testEntry(a: f32, b: f64) -> f64: - return a + b -""" - - result = Suite(code_py).run_code(125.5, 100.25) - - assert 225.75 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('TODO') -def test_uadd(): - code_py = """ -@exported -def testEntry() -> i32: - return +523 -""" - - result = Suite(code_py).run_code() - - assert 523 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('TODO') -def test_usub(): - code_py = """ -@exported -def testEntry() -> i32: - return -19 -""" - - result = Suite(code_py).run_code() - - assert -19 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('inp', [9, 10, 11, 12]) -def test_if_simple(inp): - code_py = """ -@exported -def testEntry(a: i32) -> i32: - if a > 10: - return 15 - - return 3 -""" - exp_result = 15 if inp > 10 else 3 - - suite = Suite(code_py) - - result = suite.run_code(inp) - assert exp_result == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('Such a return is not how things should be') -def test_if_complex(): - code_py = """ -@exported -def testEntry(a: i32) -> i32: - if a > 10: - return 10 - elif a > 0: - return a - else: - return 0 - - return -1 # Required due to function type -""" - - suite = Suite(code_py) - - assert 10 == suite.run_code(20).returned_value - assert 10 == suite.run_code(10).returned_value - - assert 8 == suite.run_code(8).returned_value - - assert 0 == suite.run_code(0).returned_value - assert 0 == suite.run_code(-1).returned_value - -@pytest.mark.integration_test -def test_if_nested(): - code_py = """ -@exported -def testEntry(a: i32, b: i32) -> i32: - if a > 11: - if b > 11: - return 3 - - return 2 - - if b > 11: - return 1 - - return 0 -""" - - suite = Suite(code_py) - - assert 3 == suite.run_code(20, 20).returned_value - assert 2 == suite.run_code(20, 10).returned_value - assert 1 == suite.run_code(10, 20).returned_value - assert 0 == suite.run_code(10, 10).returned_value - -@pytest.mark.integration_test -def test_call_no_args(): - code_py = """ -def helper() -> i32: - return 19 - -@exported -def testEntry() -> i32: - return helper() -""" - - result = Suite(code_py).run_code() - - assert 19 == result.returned_value - -@pytest.mark.integration_test -def test_call_pre_defined(): - code_py = """ -def helper(left: i32, right: i32) -> i32: - return left + right - -@exported -def testEntry() -> i32: - return helper(10, 3) -""" - - result = Suite(code_py).run_code() - - assert 13 == result.returned_value - -@pytest.mark.integration_test -def test_call_post_defined(): - code_py = """ -@exported -def testEntry() -> i32: - return helper(10, 3) - -def helper(left: i32, right: i32) -> i32: - return left - right -""" - - result = Suite(code_py).run_code() - - assert 7 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMLETE_INT_TYPES) -def test_call_with_expression_int(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return helper(10 + 20, 3 + 5) - -def helper(left: {type_}, right: {type_}) -> {type_}: - return left - right -""" - - result = Suite(code_py).run_code() - - assert 22 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) -def test_call_with_expression_float(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return helper(10.078125 + 90.046875, 63.0 + 5.0) - -def helper(left: {type_}, right: {type_}) -> {type_}: - return left - right -""" - - result = Suite(code_py).run_code() - - assert 32.125 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.skip('Not yet implemented') -def test_assign(): - code_py = """ - -@exported -def testEntry() -> i32: - a: i32 = 8947 - return a -""" - - result = Suite(code_py).run_code() - - assert 8947 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', TYPE_MAP.keys()) -def test_struct_0(type_): - code_py = f""" -class CheckedValue: - value: {type_} - -@exported -def testEntry() -> {type_}: - return helper(CheckedValue(23)) - -def helper(cv: CheckedValue) -> {type_}: - return cv.value -""" - - result = Suite(code_py).run_code() - - assert 23 == result.returned_value - -@pytest.mark.integration_test -def test_struct_1(): - code_py = """ -class Rectangle: - height: i32 - width: i32 - border: i32 - -@exported -def testEntry() -> i32: - return helper(Rectangle(100, 150, 2)) - -def helper(shape: Rectangle) -> i32: - return shape.height + shape.width + shape.border -""" - - result = Suite(code_py).run_code() - - assert 252 == result.returned_value - -@pytest.mark.integration_test -def test_struct_2(): - code_py = """ -class Rectangle: - height: i32 - width: i32 - border: i32 - -@exported -def testEntry() -> i32: - return helper(Rectangle(100, 150, 2), Rectangle(200, 90, 3)) - -def helper(shape1: Rectangle, shape2: Rectangle) -> i32: - return shape1.height + shape1.width + shape1.border + shape2.height + shape2.width + shape2.border -""" - - result = Suite(code_py).run_code() - - assert 545 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_tuple_simple_constructor(type_): - code_py = f""" -@exported -def testEntry() -> {type_}: - return helper((24, 57, 80, )) - -def helper(vector: ({type_}, {type_}, {type_}, )) -> {type_}: - return vector[0] + vector[1] + vector[2] -""" - - result = Suite(code_py).run_code() - - assert 161 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -def test_tuple_float(): - code_py = """ -@exported -def testEntry() -> f32: - return helper((1.0, 2.0, 3.0, )) - -def helper(v: (f32, f32, f32, )) -> f32: - return sqrt(v[0] * v[0] + v[1] * v[1] + v[2] * v[2]) -""" - - result = Suite(code_py).run_code() - - assert 3.74 < result.returned_value < 3.75 - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_static_array_module_constant(type_): - code_py = f""" -CONSTANT: {type_}[3] = (24, 57, 80, ) - -@exported -def testEntry() -> {type_}: - return helper(CONSTANT) - -def helper(array: {type_}[3]) -> {type_}: - return array[0] + array[1] + array[2] -""" - - result = Suite(code_py).run_code() - - assert 161 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_SIMPLE_TYPES) -def test_static_array_indexed(type_): - code_py = f""" -CONSTANT: {type_}[3] = (24, 57, 80, ) - -@exported -def testEntry() -> {type_}: - return helper(CONSTANT, 0, 1, 2) - -def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}: - return array[i0] + array[i1] + array[i2] -""" - - result = Suite(code_py).run_code() - - assert 161 == result.returned_value - assert TYPE_MAP[type_] == type(result.returned_value) - -@pytest.mark.integration_test -def test_bytes_address(): - code_py = """ -@exported -def testEntry(f: bytes) -> bytes: - return f -""" - - result = Suite(code_py).run_code(b'This is a test') - - # THIS DEPENDS ON THE ALLOCATOR - # A different allocator will return a different value - assert 20 == result.returned_value - -@pytest.mark.integration_test -def test_bytes_length(): - code_py = """ -@exported -def testEntry(f: bytes) -> i32: - return len(f) -""" - - result = Suite(code_py).run_code(b'This is another test') - - assert 20 == result.returned_value - -@pytest.mark.integration_test -def test_bytes_index(): - code_py = """ -@exported -def testEntry(f: bytes) -> u8: - return f[8] -""" - - result = Suite(code_py).run_code(b'This is another test') - - assert 0x61 == result.returned_value - -@pytest.mark.integration_test -@pytest.mark.skip('SIMD support is but a dream') -def test_tuple_i32x4(): - code_py = """ -@exported -def testEntry() -> i32x4: - return (51, 153, 204, 0, ) -""" - - result = Suite(code_py).run_code() - - assert (1, 2, 3, 0) == result.returned_value - -@pytest.mark.integration_test -def test_imported(): - code_py = """ -@imported -def helper(mul: i32) -> i32: - pass - -@exported -def testEntry() -> i32: - return helper(2) -""" - - def helper(mul: int) -> int: - return 4238 * mul - - result = Suite(code_py).run_code( - runtime='wasmer', - imports={ - 'helper': helper, - } - ) - - assert 8476 == result.returned_value diff --git a/tests/integration/test_stdlib/__init__.py b/tests/integration/test_stdlib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_stdlib_alloc.py b/tests/integration/test_stdlib/test_alloc.py similarity index 95% rename from tests/integration/test_stdlib_alloc.py rename to tests/integration/test_stdlib/test_alloc.py index da8ccea..96d1fd6 100644 --- a/tests/integration/test_stdlib_alloc.py +++ b/tests/integration/test_stdlib/test_alloc.py @@ -2,8 +2,8 @@ import sys import pytest -from .helpers import write_header -from .runners import RunnerPywasm3 as Runner +from ..helpers import write_header +from ..runners import RunnerPywasm3 as Runner def setup_interpreter(phash_code: str) -> Runner: runner = Runner(phash_code) -- 2.49.0 From 299551db1bd843a3d1b44f421b7117a47f642bd0 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 11:49:10 +0200 Subject: [PATCH 12/18] All primitive tests work again --- phasm/compiler.py | 2 - phasm/typer.py | 28 ++++++--- phasm/typing.py | 58 ++++++++++++------- .../integration/test_lang/test_primitives.py | 16 ++++- 4 files changed, 70 insertions(+), 34 deletions(-) diff --git a/phasm/compiler.py b/phasm/compiler.py index 7f6d53f..18a548a 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -72,8 +72,6 @@ U8_OPERATOR_MAP = { # Under the hood, this is an i32 # Implementing Right Shift XOR, OR, AND is fine since the 3 remaining # bytes stay zero after this operation - # Since it's unsigned an unsigned value, Logical or Arithmetic shift right - # are the same operation '>>': 'shr_u', '^': 'xor', '|': 'or', diff --git a/phasm/typer.py b/phasm/typer.py index 3041ec3..2b17f63 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -77,17 +77,29 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': return right if isinstance(inp, ourlang.BinaryOp): - # TODO: Simplified version - if inp.operator not in ('+', '-', '*', '|', '&', '^'): - raise NotImplementedError(expression, inp, inp.operator) + if inp.operator in ('+', '-', '*', '|', '&', '^'): + left = expression(ctx, inp.left) + right = expression(ctx, inp.right) + ctx.unify(left, right) - left = expression(ctx, inp.left) - right = expression(ctx, inp.right) - ctx.unify(left, right) + inp.type_var = left + return left - inp.type_var = left + if inp.operator in ('<<', '>>', ): + inp.type_var = ctx.new_var() + inp.type_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, ))) + inp.type_var.add_constraint(TypeConstraintSigned(False)) - return left + left = expression(ctx, inp.left) + right = expression(ctx, inp.right) + ctx.unify(left, right) + + ctx.unify(inp.type_var, left) + + return left + + raise NotImplementedError(expression, inp, inp.operator) if isinstance(inp, ourlang.FunctionCall): assert inp.function.returns_type_var is not None diff --git a/phasm/typing.py b/phasm/typing.py index de92a7b..0bebf8d 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -1,7 +1,7 @@ """ The phasm type system """ -from typing import Dict, Optional, List, Type +from typing import Dict, Iterable, Optional, List, Set, Type import enum @@ -217,35 +217,41 @@ class TypeConstraintBitWidth(TypeConstraintBase): """ Contraint on how many bits an expression has or can possibly have """ - __slots__ = ('minb', 'maxb', ) + __slots__ = ('oneof', ) - minb: int - maxb: int + oneof: Set[int] - def __init__(self, *, minb: int = 1, maxb: int = 64) -> None: - assert minb is not None or maxb is not None - assert maxb <= 64 # For now, support up to 64 bits values + def __init__(self, *, oneof: Optional[Iterable[int]] = None, minb: Optional[int] = None, maxb: Optional[int] = None) -> None: + # For now, support up to 64 bits values + self.oneof = set(oneof) if oneof is not None else set(range(1, 65)) - self.minb = minb - self.maxb = maxb + if minb is not None: + self.oneof = { + x + for x in self.oneof + if minb <= x + } + + if maxb is not None: + self.oneof = { + x + for x in self.oneof + if x <= maxb + } def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth': if not isinstance(other, TypeConstraintBitWidth): raise Exception('Invalid comparison') - if self.minb > other.maxb: - raise TypingNarrowProtoError('Min bitwidth exceeds other max bitwidth') + new_oneof = self.oneof & other.oneof - if other.minb > self.maxb: - raise TypingNarrowProtoError('Other min bitwidth exceeds max bitwidth') + if not new_oneof: + raise TypingNarrowProtoError('Memory width cannot be resolved') - return TypeConstraintBitWidth( - minb=max(self.minb, other.minb), - maxb=min(self.maxb, other.maxb), - ) + return TypeConstraintBitWidth(oneof=new_oneof) def __repr__(self) -> str: - return f'BitWidth={self.minb}..{self.maxb}' + return 'BitWidth=oneof(' + ','.join(map(str, sorted(self.oneof))) + ')' class TypeVar: """ @@ -380,11 +386,15 @@ def simplify(inp: TypeVar) -> Optional[str]: assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint assert isinstance(tc_sign, TypeConstraintSigned) # type hint - if tc_sign.signed is None or tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (8, 32, 64): + if tc_sign.signed is None or len(tc_bits.oneof) != 1: + return None + + bitwidth = next(iter(tc_bits.oneof)) + if bitwidth not in (8, 32, 64): return None base = 'i' if tc_sign.signed else 'u' - return f'{base}{tc_bits.minb}' + return f'{base}{bitwidth}' if primitive is TypeConstraintPrimitive.Primitive.FLOAT: if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint @@ -392,10 +402,14 @@ def simplify(inp: TypeVar) -> Optional[str]: assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint - if tc_bits.minb != tc_bits.maxb or tc_bits.minb not in (32, 64): + if len(tc_bits.oneof) != 1: return None - return f'f{tc_bits.minb}' + bitwidth = next(iter(tc_bits.oneof)) + if bitwidth not in (32, 64): + return None + + return f'f{bitwidth}' return None diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py index d0eeb8f..d63d4ac 100644 --- a/tests/integration/test_lang/test_primitives.py +++ b/tests/integration/test_lang/test_primitives.py @@ -76,8 +76,8 @@ def testEntry() -> {type_}: assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -@pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) -def test_logical_right_shift(type_): +@pytest.mark.parametrize('type_', ['u32', 'u64']) +def test_logical_right_shift_left_bit_zero(type_): code_py = f""" @exported def testEntry() -> {type_}: @@ -89,6 +89,18 @@ def testEntry() -> {type_}: assert 1 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +def test_logical_right_shift_left_bit_one(): + code_py = """ +@exported +def testEntry() -> u32: + return 4294967295 >> 16 +""" + + result = Suite(code_py).run_code() + + assert 0xFFFF == result.returned_value + @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['u8', 'u32', 'u64']) def test_bitwise_or(type_): -- 2.49.0 From 0097ce782d6f4bb00c22ea202002d72b76122c5e Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 12:15:03 +0200 Subject: [PATCH 13/18] First work on restoring StaticArray Removed the separate ModuleConstantRef since you can tell by the variable property of VariableReference. We'll also add local variables there later on. --- phasm/codestyle.py | 27 ++-- phasm/compiler.py | 91 +++++++------ phasm/ourlang.py | 73 +---------- phasm/parser.py | 124 ++++++++++-------- phasm/typer.py | 9 +- .../test_lang/test_static_array.py | 28 +++- 6 files changed, 163 insertions(+), 189 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 04ee043..500b878 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -95,26 +95,23 @@ def expression(inp: ourlang.Expression) -> str: # return f'({args}, )' return f'{inp.function.name}({args})' - - if isinstance(inp, ourlang.AccessBytesIndex): - return f'{expression(inp.varref)}[{expression(inp.index)}]' - - if isinstance(inp, ourlang.AccessStructMember): - return f'{expression(inp.varref)}.{inp.member.name}' - - if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )): - if isinstance(inp.member, ourlang.Expression): - return f'{expression(inp.varref)}[{expression(inp.member)}]' - - return f'{expression(inp.varref)}[{inp.member.idx}]' + # + # if isinstance(inp, ourlang.AccessBytesIndex): + # return f'{expression(inp.varref)}[{expression(inp.index)}]' + # + # if isinstance(inp, ourlang.AccessStructMember): + # return f'{expression(inp.varref)}.{inp.member.name}' + # + # if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )): + # if isinstance(inp.member, ourlang.Expression): + # return f'{expression(inp.varref)}[{expression(inp.member)}]' + # + # return f'{expression(inp.varref)}[{inp.member.idx}]' if isinstance(inp, ourlang.Fold): fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' - if isinstance(inp, ourlang.ModuleConstantReference): - return inp.definition.name - raise NotImplementedError(expression, inp) def statement(inp: ourlang.Statement) -> Statements: diff --git a/phasm/compiler.py b/phasm/compiler.py index 18a548a..826f21e 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -156,8 +156,39 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: raise NotImplementedError(f'Constants with type {stp}') if isinstance(inp, ourlang.VariableReference): - wgn.add_statement('local.get', '${}'.format(inp.variable.name)) - return + if isinstance(inp.variable, ourlang.FunctionParam): + wgn.add_statement('local.get', '${}'.format(inp.variable.name)) + return + + if isinstance(inp.variable, ourlang.ModuleConstantDef): + # FIXME: Tuple / Static Array broken after new type system + # if isinstance(inp.type, typing.TypeTuple): + # assert isinstance(inp.definition.constant, ourlang.ConstantTuple) + # assert inp.definition.data_block is not None, 'Combined values are memory stored' + # assert inp.definition.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.definition.data_block.address) + # return + # + # if isinstance(inp.type, typing.TypeStaticArray): + # assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) + # assert inp.definition.data_block is not None, 'Combined values are memory stored' + # assert inp.definition.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.definition.data_block.address) + # return + + assert inp.variable.data_block is None, 'Primitives are not memory stored' + + assert inp.variable.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.variable.type_var) + if mtyp is None: + # In the future might extend this by having structs or tuples + # as members of struct or tuples + raise NotImplementedError(expression, inp, inp.type_var) + + expression(wgn, inp.variable.constant) + return + + raise NotImplementedError(expression, inp.variable) if isinstance(inp, ourlang.BinaryOp): expression(wgn, inp.left) @@ -301,34 +332,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: expression_fold(wgn, inp) return - if isinstance(inp, ourlang.ModuleConstantReference): - # FIXME: Tuple / Static Array broken after new type system - # if isinstance(inp.type, typing.TypeTuple): - # assert isinstance(inp.definition.constant, ourlang.ConstantTuple) - # assert inp.definition.data_block is not None, 'Combined values are memory stored' - # assert inp.definition.data_block.address is not None, 'Value not allocated' - # wgn.i32.const(inp.definition.data_block.address) - # return - # - # if isinstance(inp.type, typing.TypeStaticArray): - # assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) - # assert inp.definition.data_block is not None, 'Combined values are memory stored' - # assert inp.definition.data_block.address is not None, 'Value not allocated' - # wgn.i32.const(inp.definition.data_block.address) - # return - - assert inp.definition.data_block is None, 'Primitives are not memory stored' - - assert inp.type_var is not None, typing.ASSERTION_ERROR - mtyp = typing.simplify(inp.type_var) - if mtyp is None: - # In the future might extend this by having structs or tuples - # as members of struct or tuples - raise NotImplementedError(expression, inp, inp.type_var) - - expression(wgn, inp.definition.constant) - return - raise NotImplementedError(expression, inp) def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: @@ -566,38 +569,46 @@ def module_data(inp: ourlang.ModuleData) -> bytes: data_list: List[bytes] = [] - raise NotImplementedError('Broken after new type system') - for constant in block.data: - if isinstance(constant, ourlang.ConstantUInt8): + assert constant.type_var is not None + mtyp = typing.simplify(constant.type_var) + + if mtyp == 'u8': + assert isinstance(constant.value, int) data_list.append(module_data_u8(constant.value)) continue - if isinstance(constant, ourlang.ConstantUInt32): + if mtyp == 'u32': + assert isinstance(constant.value, int) data_list.append(module_data_u32(constant.value)) continue - if isinstance(constant, ourlang.ConstantUInt64): + if mtyp == 'u64': + assert isinstance(constant.value, int) data_list.append(module_data_u64(constant.value)) continue - if isinstance(constant, ourlang.ConstantInt32): + if mtyp == 'i32': + assert isinstance(constant.value, int) data_list.append(module_data_i32(constant.value)) continue - if isinstance(constant, ourlang.ConstantInt64): + if mtyp == 'i64': + assert isinstance(constant.value, int) data_list.append(module_data_i64(constant.value)) continue - if isinstance(constant, ourlang.ConstantFloat32): + if mtyp == 'f32': + assert isinstance(constant.value, float) data_list.append(module_data_f32(constant.value)) continue - if isinstance(constant, ourlang.ConstantFloat64): + if mtyp == 'f64': + assert isinstance(constant.value, float) data_list.append(module_data_f64(constant.value)) continue - raise NotImplementedError(constant) + raise NotImplementedError(constant, mtyp) block_data = b''.join(data_list) diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 9e16e4b..9ffe2c4 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -11,10 +11,7 @@ WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) from .typing import ( - TypeBytes, - TypeTuple, TypeTupleMember, - TypeStaticArray, TypeStaticArrayMember, - TypeStruct, TypeStructMember, + TypeStruct, TypeVar, ) @@ -78,9 +75,9 @@ class VariableReference(Expression): """ __slots__ = ('variable', ) - variable: 'FunctionParam' # also possibly local + variable: Union['ModuleConstantDef', 'FunctionParam'] # also possibly local - def __init__(self, variable: 'FunctionParam') -> None: + def __init__(self, variable: Union['ModuleConstantDef', 'FunctionParam']) -> None: super().__init__() self.variable = variable @@ -131,9 +128,10 @@ class FunctionCall(Expression): self.function = function self.arguments = [] -class AccessBytesIndex(Expression): +class Subscript(Expression): """ - Access a bytes index for reading + A subscript, for example to refer to a static array or tuple + by index """ __slots__ = ('varref', 'index', ) @@ -146,53 +144,6 @@ class AccessBytesIndex(Expression): self.varref = varref self.index = index -class AccessStructMember(Expression): - """ - Access a struct member for reading of writing - """ - __slots__ = ('varref', 'member', ) - - varref: VariableReference - member: TypeStructMember - - def __init__(self, varref: VariableReference, member: TypeStructMember) -> None: - super().__init__() - - self.varref = varref - self.member = member - -class AccessTupleMember(Expression): - """ - Access a tuple member for reading of writing - """ - __slots__ = ('varref', 'member', ) - - varref: VariableReference - member: TypeTupleMember - - def __init__(self, varref: VariableReference, member: TypeTupleMember, ) -> None: - super().__init__() - - self.varref = varref - self.member = member - -class AccessStaticArrayMember(Expression): - """ - Access a tuple member for reading of writing - """ - __slots__ = ('varref', 'static_array', 'member', ) - - varref: Union['ModuleConstantReference', VariableReference] - static_array: TypeStaticArray - member: Union[Expression, TypeStaticArrayMember] - - def __init__(self, varref: Union['ModuleConstantReference', VariableReference], static_array: TypeStaticArray, member: Union[TypeStaticArrayMember, Expression], ) -> None: - super().__init__() - - self.varref = varref - self.static_array = static_array - self.member = member - class Fold(Expression): """ A (left or right) fold @@ -223,18 +174,6 @@ class Fold(Expression): self.base = base self.iter = iter_ -class ModuleConstantReference(Expression): - """ - An reference to a module constant expression within a statement - """ - __slots__ = ('definition', ) - - definition: 'ModuleConstantDef' - - def __init__(self, definition: 'ModuleConstantDef') -> None: - super().__init__() - self.definition = definition - class Statement: """ A statement within a function diff --git a/phasm/parser.py b/phasm/parser.py index 6fe96fc..00c4c16 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -22,15 +22,14 @@ from .ourlang import ( Function, Expression, - AccessBytesIndex, AccessStructMember, AccessTupleMember, AccessStaticArrayMember, BinaryOp, ConstantPrimitive, ConstantTuple, ConstantStaticArray, - FunctionCall, + FunctionCall, Subscript, # StructConstructor, TupleConstructor, UnaryOp, VariableReference, - Fold, ModuleConstantReference, + Fold, Statement, StatementIf, StatementPass, StatementReturn, @@ -206,6 +205,27 @@ class OurVisitor: None, ) + if isinstance(node.value, ast.Tuple): + tuple_data = [ + self.visit_Module_Constant(module, arg_node) + for arg_node in node.value.elts + if isinstance(arg_node, ast.Constant) + ] + if len(node.value.elts) != len(tuple_data): + _raise_static_error(node, 'Tuple arguments must be constants') + + # Allocate the data + data_block = ModuleDataBlock(tuple_data) + module.data.blocks.append(data_block) + + # Then return the constant as a pointer + return ModuleConstantDef( + node.target.id, + node.lineno, + ConstantTuple(tuple_data), + data_block, + ) + raise NotImplementedError('TODO: Broken after new typing system') # if isinstance(exp_type, TypeTuple): @@ -416,7 +436,7 @@ class OurVisitor: if node.id in module.constant_defs: cdef = module.constant_defs[node.id] - return ModuleConstantReference(cdef) + return VariableReference(cdef) _raise_static_error(node, f'Undefined variable {node.id}') @@ -454,13 +474,13 @@ class OurVisitor: if not isinstance(node.func.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.func.id in module.structs: - raise NotImplementedError('TODO: Broken after new type system') + # if node.func.id in module.structs: + # raise NotImplementedError('TODO: Broken after new type system') # struct = module.structs[node.func.id] # struct_constructor = StructConstructor(struct) # # func = module.functions[struct_constructor.name] - elif node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') @@ -533,61 +553,59 @@ class OurVisitor: def visit_Module_FunctionDef_Attribute(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Attribute) -> Expression: raise NotImplementedError('Broken after new type system') - del module - del function - - if not isinstance(node.value, ast.Name): - _raise_static_error(node, 'Must reference a name') - - if not isinstance(node.ctx, ast.Load): - _raise_static_error(node, 'Must be load context') - - if not node.value.id in our_locals: - _raise_static_error(node, f'Undefined variable {node.value.id}') - - param = our_locals[node.value.id] - - node_typ = param.type - if not isinstance(node_typ, TypeStruct): - _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') - - member = node_typ.get_member(node.attr) - if member is None: - _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') - - return AccessStructMember( - VariableReference(param), - member, - ) - - def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: - raise NotImplementedError('TODO: Broken after new type system') - + # del module + # del function + # # if not isinstance(node.value, ast.Name): # _raise_static_error(node, 'Must reference a name') # - # if not isinstance(node.slice, ast.Index): - # _raise_static_error(node, 'Must subscript using an index') - # # if not isinstance(node.ctx, ast.Load): # _raise_static_error(node, 'Must be load context') # - # varref: Union[ModuleConstantReference, VariableReference] - # if node.value.id in our_locals: - # param = our_locals[node.value.id] - # node_typ = param.type - # varref = VariableReference(param) - # elif node.value.id in module.constant_defs: - # constant_def = module.constant_defs[node.value.id] - # node_typ = constant_def.type - # varref = ModuleConstantReference(constant_def) - # else: + # if not node.value.id in our_locals: # _raise_static_error(node, f'Undefined variable {node.value.id}') # - # slice_expr = self.visit_Module_FunctionDef_expr( - # module, function, our_locals, node.slice.value, - # ) + # param = our_locals[node.value.id] # + # node_typ = param.type + # if not isinstance(node_typ, TypeStruct): + # _raise_static_error(node, f'Cannot take attribute of non-struct {node.value.id}') + # + # member = node_typ.get_member(node.attr) + # if member is None: + # _raise_static_error(node, f'{node_typ.name} has no attribute {node.attr}') + # + # return AccessStructMember( + # VariableReference(param), + # member, + # ) + + def visit_Module_FunctionDef_Subscript(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Subscript) -> Expression: + if not isinstance(node.value, ast.Name): + _raise_static_error(node, 'Must reference a name') + + if not isinstance(node.slice, ast.Index): + _raise_static_error(node, 'Must subscript using an index') + + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') + + varref: VariableReference + if node.value.id in our_locals: + param = our_locals[node.value.id] + varref = VariableReference(param) + elif node.value.id in module.constant_defs: + constant_def = module.constant_defs[node.value.id] + varref = VariableReference(constant_def) + else: + _raise_static_error(node, f'Undefined variable {node.value.id}') + + slice_expr = self.visit_Module_FunctionDef_expr( + module, function, our_locals, node.slice.value, + ) + + return Subscript(varref, slice_expr) + # if isinstance(node_typ, TypeBytes): # if isinstance(varref, ModuleConstantReference): # raise NotImplementedError(f'{node} from module constant') diff --git a/phasm/typer.py b/phasm/typer.py index 2b17f63..e57e043 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -62,7 +62,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': return constant(ctx, inp) if isinstance(inp, ourlang.VariableReference): - assert inp.variable.type_var is not None, inp + assert inp.variable.type_var is not None return inp.variable.type_var if isinstance(inp, ourlang.UnaryOp): @@ -112,13 +112,6 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': return inp.function.returns_type_var - if isinstance(inp, ourlang.ModuleConstantReference): - assert inp.definition.type_var is not None - - inp.type_var = inp.definition.type_var - - return inp.definition.type_var - raise NotImplementedError(expression, inp) def function(ctx: Context, inp: ourlang.Function) -> None: diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py index ced9277..68bfdd4 100644 --- a/tests/integration/test_lang/test_static_array.py +++ b/tests/integration/test_lang/test_static_array.py @@ -7,21 +7,18 @@ from ..helpers import Suite @pytest.mark.integration_test @pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) -def test_static_array_module_constant(type_): +def test_module_constant(type_): code_py = f""" CONSTANT: {type_}[3] = (24, 57, 80, ) @exported def testEntry() -> {type_}: - return helper(CONSTANT) - -def helper(array: {type_}[3]) -> {type_}: - return array[0] + array[1] + array[2] + return CONSTANT[0] """ result = Suite(code_py).run_code() - assert 161 == result.returned_value + assert 24 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test @@ -43,6 +40,25 @@ def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}: assert 161 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) +def test_function_call(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24, 57, 80, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT) + +def helper(array: {type_}[3]) -> {type_}: + return array[0] + array[1] + array[2] +""" + + result = Suite(code_py).run_code() + + assert 161 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + @pytest.mark.integration_test def test_static_array_constant_too_few_values(): code_py = """ -- 2.49.0 From 4f7608a60106a96c9258ae5ac48721403e4548c4 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 12:29:48 +0200 Subject: [PATCH 14/18] Fix: ModuleConstantDef type annotation was ignored --- phasm/ourlang.py | 6 ++++-- phasm/parser.py | 2 ++ phasm/typer.py | 5 ++++- tests/integration/test_lang/test_primitives.py | 15 +++++++++++++++ 4 files changed, 25 insertions(+), 3 deletions(-) diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 9ffe2c4..733476f 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -296,17 +296,19 @@ class ModuleConstantDef: """ A constant definition within a module """ - __slots__ = ('name', 'lineno', 'type_var', 'constant', 'data_block', ) + __slots__ = ('name', 'lineno', 'type_str', 'type_var', 'constant', 'data_block', ) name: str lineno: int + type_str: str type_var: Optional[TypeVar] constant: Constant data_block: Optional['ModuleDataBlock'] - def __init__(self, name: str, lineno: int, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: + def __init__(self, name: str, lineno: int, type_str: str, constant: Constant, data_block: Optional['ModuleDataBlock']) -> None: self.name = name self.lineno = lineno + self.type_str = type_str self.type_var = None self.constant = constant self.data_block = data_block diff --git a/phasm/parser.py b/phasm/parser.py index 00c4c16..f6570d1 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -201,6 +201,7 @@ class OurVisitor: return ModuleConstantDef( node.target.id, node.lineno, + self.visit_type(module, node.annotation), self.visit_Module_Constant(module, node.value), None, ) @@ -222,6 +223,7 @@ class OurVisitor: return ModuleConstantDef( node.target.id, node.lineno, + self.visit_type(module, node.annotation), ConstantTuple(tuple_data), data_block, ) diff --git a/phasm/typer.py b/phasm/typer.py index e57e043..ee356eb 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -125,7 +125,10 @@ def function(ctx: Context, inp: ourlang.Function) -> None: def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> None: constant(ctx, inp.constant) - inp.type_var = ctx.new_var() + if inp.type_str is None: + inp.type_var = ctx.new_var() + else: + inp.type_var = from_str(ctx, inp.type_str, inp.type_str) assert inp.constant.type_var is not None ctx.unify(inp.type_var, inp.constant.type_var) diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py index d63d4ac..441dcdb 100644 --- a/tests/integration/test_lang/test_primitives.py +++ b/tests/integration/test_lang/test_primitives.py @@ -1,5 +1,7 @@ import pytest +from phasm.exceptions import TypingError + from ..helpers import Suite from ..constants import ALL_INT_TYPES, ALL_FLOAT_TYPES, COMPLETE_INT_TYPES, TYPE_MAP @@ -61,6 +63,19 @@ def testEntry() -> {type_}: assert 32.125 == result.returned_value +@pytest.mark.integration_test +def test_module_constant_entanglement(): + code_py = """ +CONSTANT: u8 = 1000 + +@exported +def testEntry() -> u32: + return 14 +""" + + with pytest.raises(TypingError, match='u8.*1000'): + Suite(code_py).run_code() + @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['u32', 'u64']) # FIXME: Support u8, requires an extra AND operation def test_logical_left_shift(type_): -- 2.49.0 From 5da45e78c2d246f9606e6a53fa1ee4864b56504f Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 13:50:20 +0200 Subject: [PATCH 15/18] More work on StaticArray Also naming fix, buildin => builtin. Removes the use of ConstantStaticArray, as this was context dependent --- phasm/codestyle.py | 8 +- phasm/compiler.py | 6 +- phasm/ourlang.py | 16 +- phasm/parser.py | 61 ++--- phasm/typer.py | 49 +++- phasm/typing.py | 209 +++++++++++++----- pylintrc | 2 +- tests/integration/runners.py | 6 +- tests/integration/test_code/__init__.py | 0 tests/integration/test_code/test_typing.py | 20 ++ .../integration/test_lang/test_primitives.py | 2 +- .../test_lang/test_static_array.py | 41 +++- 12 files changed, 297 insertions(+), 123 deletions(-) create mode 100644 tests/integration/test_code/__init__.py create mode 100644 tests/integration/test_code/test_typing.py diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 500b878..d57c1db 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -54,7 +54,7 @@ def expression(inp: ourlang.Expression) -> str: # could not fit in the given float type return str(inp.value) - if isinstance(inp, (ourlang.ConstantTuple, ourlang.ConstantStaticArray, )): + if isinstance(inp, ourlang.ConstantTuple): return '(' + ', '.join( expression(x) for x in inp.value @@ -65,8 +65,8 @@ def expression(inp: ourlang.Expression) -> str: if isinstance(inp, ourlang.UnaryOp): if ( - inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS - or inp.operator in ourlang.WEBASSEMBLY_BUILDIN_BYTES_OPS): + inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS + or inp.operator in ourlang.WEBASSEMBLY_BUILTIN_BYTES_OPS): return f'{inp.operator}({expression(inp.right)})' if inp.operator == 'cast': @@ -186,7 +186,7 @@ def module(inp: ourlang.Module) -> str: for func in inp.functions.values(): if func.lineno < 0: - # Buildin (-2) or auto generated (-1) + # Builtin (-2) or auto generated (-1) continue if result: diff --git a/phasm/compiler.py b/phasm/compiler.py index 826f21e..d17c197 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -247,11 +247,11 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: mtyp = typing.simplify(inp.type_var) if mtyp == 'f32': - if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f32.{inp.operator}') return if mtyp == 'f64': - if inp.operator in ourlang.WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if inp.operator in ourlang.WEBASSEMBLY_BUILTIN_FLOAT_OPS: wgn.add_statement(f'f64.{inp.operator}') return @@ -608,7 +608,7 @@ def module_data(inp: ourlang.ModuleData) -> bytes: data_list.append(module_data_f64(constant.value)) continue - raise NotImplementedError(constant, mtyp) + raise NotImplementedError(constant, constant.type_var, mtyp) block_data = b''.join(data_list) diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 733476f..b0e605d 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -7,8 +7,8 @@ import enum from typing_extensions import Final -WEBASSEMBLY_BUILDIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) -WEBASSEMBLY_BUILDIN_BYTES_OPS: Final = ('len', ) +WEBASSEMBLY_BUILTIN_FLOAT_OPS: Final = ('abs', 'sqrt', 'ceil', 'floor', 'trunc', 'nearest', ) +WEBASSEMBLY_BUILTIN_BYTES_OPS: Final = ('len', ) from .typing import ( TypeStruct, @@ -57,18 +57,6 @@ class ConstantTuple(Constant): super().__init__() self.value = value -class ConstantStaticArray(Constant): - """ - A StaticArray constant value expression within a statement - """ - __slots__ = ('value', ) - - value: List[ConstantPrimitive] - - def __init__(self, value: List[ConstantPrimitive]) -> None: # FIXME: Arrays of arrays? - super().__init__() - self.value = value - class VariableReference(Expression): """ An variable reference expression within a statement diff --git a/phasm/parser.py b/phasm/parser.py index f6570d1..51019be 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -6,24 +6,22 @@ from typing import Any, Dict, NoReturn, Union import ast from .typing import ( + BUILTIN_TYPES, + TypeStruct, TypeStructMember, - TypeTuple, - TypeTupleMember, - TypeStaticArray, - TypeStaticArrayMember, ) from .exceptions import StaticError from .ourlang import ( - WEBASSEMBLY_BUILDIN_FLOAT_OPS, + WEBASSEMBLY_BUILTIN_FLOAT_OPS, Module, ModuleDataBlock, Function, Expression, BinaryOp, - ConstantPrimitive, ConstantTuple, ConstantStaticArray, + ConstantPrimitive, ConstantTuple, FunctionCall, Subscript, # StructConstructor, TupleConstructor, @@ -482,7 +480,7 @@ class OurVisitor: # struct_constructor = StructConstructor(struct) # # func = module.functions[struct_constructor.name] - if node.func.id in WEBASSEMBLY_BUILDIN_FLOAT_OPS: + if node.func.id in WEBASSEMBLY_BUILTIN_FLOAT_OPS: if 1 != len(node.args): _raise_static_error(node, f'Function {node.func.id} requires 1 arguments but {len(node.args)} are given') @@ -686,7 +684,7 @@ class OurVisitor: if not isinstance(node.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - if node.id in ('u8', 'u32', 'u64', 'i32', 'i64', 'f32', 'f64'): # FIXME: Source this list somewhere + if node.id in BUILTIN_TYPES: return node.id raise NotImplementedError('TODO: Broken after type system') @@ -697,40 +695,21 @@ class OurVisitor: _raise_static_error(node, f'Unrecognized type {node.id}') if isinstance(node, ast.Subscript): - raise NotImplementedError('TODO: Broken after new type system') + if not isinstance(node.value, ast.Name): + _raise_static_error(node, 'Must be name') + if not isinstance(node.slice, ast.Index): + _raise_static_error(node, 'Must subscript using an index') + if not isinstance(node.slice.value, ast.Constant): + _raise_static_error(node, 'Must subscript using a constant index') + if not isinstance(node.slice.value.value, int): + _raise_static_error(node, 'Must subscript using a constant integer index') + if not isinstance(node.ctx, ast.Load): + _raise_static_error(node, 'Must be load context') - # if not isinstance(node.value, ast.Name): - # _raise_static_error(node, 'Must be name') - # if not isinstance(node.slice, ast.Index): - # _raise_static_error(node, 'Must subscript using an index') - # if not isinstance(node.slice.value, ast.Constant): - # _raise_static_error(node, 'Must subscript using a constant index') - # if not isinstance(node.slice.value.value, int): - # _raise_static_error(node, 'Must subscript using a constant integer index') - # if not isinstance(node.ctx, ast.Load): - # _raise_static_error(node, 'Must be load context') - # - # if node.value.id in module.types: - # member_type = module.types[node.value.id] - # else: - # _raise_static_error(node, f'Unrecognized type {node.value.id}') - # - # type_static_array = TypeStaticArray(member_type) - # - # offset = 0 - # - # for idx in range(node.slice.value.value): - # static_array_member = TypeStaticArrayMember(idx, offset) - # - # type_static_array.members.append(static_array_member) - # offset += member_type.alloc_size() - # - # key = f'{node.value.id}[{node.slice.value.value}]' - # - # if key not in module.types: - # module.types[key] = type_static_array - # - # return module.types[key] + if node.value.id not in BUILTIN_TYPES: # FIXME: Tuple of tuples? + _raise_static_error(node, f'Unrecognized type {node.value.id}') + + return f'{node.value.id}[{node.slice.value.value}]' if isinstance(node, ast.Tuple): raise NotImplementedError('TODO: Broken after new type system') diff --git a/phasm/typer.py b/phasm/typer.py index ee356eb..d18aa1c 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -3,7 +3,13 @@ Type checks and enriches the given ast """ from . import ourlang -from .typing import Context, TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeVar, from_str +from .exceptions import TypingError +from .typing import ( + Context, + TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeConstraintSubscript, + TypeVar, + from_str, +) def phasm_type(inp: ourlang.Module) -> None: module(inp) @@ -55,6 +61,19 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar: raise NotImplementedError(constant, inp, inp.value) + if isinstance(inp, ourlang.ConstantTuple): + result = ctx.new_var() + + result.add_constraint(TypeConstraintSubscript(members=( + constant(ctx, x) + for x in inp.value + ))) + result.add_location(str(inp.value)) + + inp.type_var = result + + return result + raise NotImplementedError(constant, inp) def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': @@ -63,6 +82,8 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': if isinstance(inp, ourlang.VariableReference): assert inp.variable.type_var is not None + + inp.type_var = inp.variable.type_var return inp.variable.type_var if isinstance(inp, ourlang.UnaryOp): @@ -112,6 +133,32 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': return inp.function.returns_type_var + if isinstance(inp, ourlang.Subscript): + if not isinstance(inp.index, ourlang.ConstantPrimitive): + raise NotImplementedError(expression, inp, inp.index) + if not isinstance(inp.index.value, int): + raise NotImplementedError(expression, inp, inp.index.value) + + expression(ctx, inp.varref) + assert inp.varref.type_var is not None + + try: + # TODO: I'd much rather resolve this using the narrow functions + tc_subs = ctx.var_constraints[inp.varref.type_var.ctx_id][TypeConstraintSubscript] + except KeyError: + raise TypingError(f'Type cannot be subscripted: {inp.varref.type_var}') from None + + assert isinstance(tc_subs, TypeConstraintSubscript) # type hint + + try: + # TODO: I'd much rather resolve this using the narrow functions + member = tc_subs.members[inp.index.value] + except IndexError: + raise TypingError(f'Type cannot be subscripted with index {inp.index.value}: {inp.varref.type_var}') from None + + inp.type_var = member + return member + raise NotImplementedError(expression, inp) def function(ctx: Context, inp: ourlang.Function) -> None: diff --git a/phasm/typing.py b/phasm/typing.py index 0bebf8d..25f85a7 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -1,9 +1,10 @@ """ The phasm type system """ -from typing import Dict, Iterable, Optional, List, Set, Type +from typing import Callable, Dict, Iterable, Optional, List, Set, Type import enum +import re from .exceptions import TypingError @@ -151,7 +152,7 @@ class TypeConstraintBase: """ Base class for classes implementing a contraint on a type """ - def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase': raise NotImplementedError('narrow', self, other) class TypeConstraintPrimitive(TypeConstraintBase): @@ -172,7 +173,7 @@ class TypeConstraintPrimitive(TypeConstraintBase): def __init__(self, primitive: Primitive) -> None: self.primitive = primitive - def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive': + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive': if not isinstance(other, TypeConstraintPrimitive): raise Exception('Invalid comparison') @@ -196,7 +197,7 @@ class TypeConstraintSigned(TypeConstraintBase): def __init__(self, signed: Optional[bool]) -> None: self.signed = signed - def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintSigned': + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSigned': if not isinstance(other, TypeConstraintSigned): raise Exception('Invalid comparison') @@ -239,7 +240,7 @@ class TypeConstraintBitWidth(TypeConstraintBase): if x <= maxb } - def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth': + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth': if not isinstance(other, TypeConstraintBitWidth): raise Exception('Invalid comparison') @@ -251,7 +252,60 @@ class TypeConstraintBitWidth(TypeConstraintBase): return TypeConstraintBitWidth(oneof=new_oneof) def __repr__(self) -> str: - return 'BitWidth=oneof(' + ','.join(map(str, sorted(self.oneof))) + ')' + result = 'BitWidth=' + + items = list(sorted(self.oneof)) + if not items: + return result + + while items: + itm = items.pop(0) + result += str(itm) + + cnt = 0 + while cnt < len(items) and items[cnt] == itm + cnt + 1: + cnt += 1 + + if cnt == 1: + result += ',' + str(items[0]) + elif cnt > 1: + result += '..' + str(items[cnt - 1]) + + items = items[cnt:] + if items: + result += ',' + + return result + +class TypeConstraintSubscript(TypeConstraintBase): + """ + Contraint on allowing a type to be subscripted + """ + __slots__ = ('members', ) + + members: List['TypeVar'] + + def __init__(self, *, members: Iterable['TypeVar']) -> None: + self.members = list(members) + + def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSubscript': + if not isinstance(other, TypeConstraintSubscript): + raise Exception('Invalid comparison') + + if len(self.members) != len(other.members): + raise TypingNarrowProtoError('Member count does not match') + + newmembers = [] + for smb, omb in zip(self.members, other.members): + nmb = ctx.new_var() + ctx.unify(nmb, smb) + ctx.unify(nmb, omb) + newmembers.append(nmb) + + return TypeConstraintSubscript(members=newmembers) + + def __repr__(self) -> str: + return 'Subscript=(' + ','.join(map(repr, self.members)) + ')' class TypeVar: """ @@ -271,7 +325,7 @@ class TypeVar: csts = self.ctx.var_constraints[self.ctx_id] if newconst.__class__ in csts: - csts[newconst.__class__] = csts[newconst.__class__].narrow(newconst) + csts[newconst.__class__] = csts[newconst.__class__].narrow(self.ctx, newconst) else: csts[newconst.__class__] = newconst @@ -413,6 +467,93 @@ def simplify(inp: TypeVar) -> Optional[str]: return None +def make_u8(ctx: Context, location: str) -> TypeVar: + """ + Makes a u8 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location(location) + return result + +def make_u32(ctx: Context, location: str) -> TypeVar: + """ + Makes a u32 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location(location) + return result + +def make_u64(ctx: Context, location: str) -> TypeVar: + """ + Makes a u64 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location(location) + return result + +def make_i32(ctx: Context, location: str) -> TypeVar: + """ + Makes a i32 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location(location) + return result + +def make_i64(ctx: Context, location: str) -> TypeVar: + """ + Makes a i64 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) + result.add_constraint(TypeConstraintSigned(True)) + result.add_location(location) + return result + +def make_f32(ctx: Context, location: str) -> TypeVar: + """ + Makes a f32 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_location(location) + return result + +def make_f64(ctx: Context, location: str) -> TypeVar: + """ + Makes a f64 TypeVar + """ + result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_location(location) + return result + +BUILTIN_TYPES: Dict[str, Callable[[Context, str], TypeVar]] = { + 'u8': make_u8, + 'u32': make_u32, + 'u64': make_u64, + 'i32': make_i32, + 'i64': make_i64, + 'f32': make_f32, + 'f64': make_f64, +} + +TYPE_MATCH_STATIC_ARRAY = re.compile(r'^([uif][0-9]+)\[([0-9]+)\]') + def from_str(ctx: Context, inp: str, location: str) -> TypeVar: """ Creates a new TypeVar from the string @@ -425,53 +566,21 @@ def from_str(ctx: Context, inp: str, location: str) -> TypeVar: This could be conidered part of parsing. Though that would give trouble with the context creation. """ - result = ctx.new_var() + if inp in BUILTIN_TYPES: + return BUILTIN_TYPES[inp](ctx, location) - if inp == 'u8': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) - result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) - return result + match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp) + if match: + result = ctx.new_var() - if inp == 'u32': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_constraint(TypeConstraintSigned(False)) + result.add_constraint(TypeConstraintSubscript(members=( + # Make copies so they don't get entangled + # with each other. + from_str(ctx, match[1], match[1]) + for _ in range(int(match[2])) + ))) result.add_location(location) - return result - if inp == 'u64': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) - return result - - if inp == 'i32': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_constraint(TypeConstraintSigned(True)) - result.add_location(location) - return result - - if inp == 'i64': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) - result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_constraint(TypeConstraintSigned(True)) - result.add_location(location) - return result - - if inp == 'f32': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_location(location) - return result - - if inp == 'f64': - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) - result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_location(location) return result raise NotImplementedError(from_str, inp) diff --git a/pylintrc b/pylintrc index 82948bb..f872f8d 100644 --- a/pylintrc +++ b/pylintrc @@ -7,4 +7,4 @@ max-line-length=180 good-names=g [tests] -disable=C0116, +disable=C0116,R0201 diff --git a/tests/integration/runners.py b/tests/integration/runners.py index 005d44e..77cd3f5 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, Iterable, Optional, TextIO import ctypes import io -import warnings import pywasm.binary import wasm3 @@ -42,10 +41,7 @@ class RunnerBase: Parses the Phasm code into an AST """ self.phasm_ast = phasm_parse(self.phasm_code) - try: - phasm_type(self.phasm_ast) - except NotImplementedError as exc: - warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}') + phasm_type(self.phasm_ast) def compile_ast(self) -> None: """ diff --git a/tests/integration/test_code/__init__.py b/tests/integration/test_code/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_code/test_typing.py b/tests/integration/test_code/test_typing.py new file mode 100644 index 0000000..7be1140 --- /dev/null +++ b/tests/integration/test_code/test_typing.py @@ -0,0 +1,20 @@ +import pytest + +from phasm import typing as sut + +class TestTypeConstraintBitWidth: + @pytest.mark.parametrize('oneof,exp', [ + (set(), '', ), + ({1}, '1', ), + ({1,2}, '1,2', ), + ({1,2,3}, '1..3', ), + ({1,2,3,4}, '1..4', ), + + ({1,3}, '1,3', ), + ({1,4}, '1,4', ), + + ({1,2,3,4,6,7,8,9}, '1..4,6..9', ), + ]) + def test_repr(self, oneof, exp): + mut_self = sut.TypeConstraintBitWidth(oneof=oneof) + assert ('BitWidth=' + exp) == repr(mut_self) diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py index 441dcdb..5736f0b 100644 --- a/tests/integration/test_lang/test_primitives.py +++ b/tests/integration/test_lang/test_primitives.py @@ -219,7 +219,7 @@ def testEntry() -> {type_}: @pytest.mark.integration_test @pytest.mark.parametrize('type_', ['f32', 'f64']) -def test_buildins_sqrt(type_): +def test_builtins_sqrt(type_): code_py = f""" @exported def testEntry() -> {type_}: diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py index 68bfdd4..6ea5985 100644 --- a/tests/integration/test_lang/test_static_array.py +++ b/tests/integration/test_lang/test_static_array.py @@ -1,12 +1,12 @@ import pytest -from phasm.exceptions import StaticError +from phasm.exceptions import StaticError, TypingError -from ..constants import COMPLETE_PRIMITIVE_TYPES, TYPE_MAP +from ..constants import ALL_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP from ..helpers import Suite @pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) +@pytest.mark.parametrize('type_', ALL_INT_TYPES) def test_module_constant(type_): code_py = f""" CONSTANT: {type_}[3] = (24, 57, 80, ) @@ -59,6 +59,41 @@ def helper(array: {type_}[3]) -> {type_}: assert 161 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +def test_module_constant_type_mismatch_bitwidth(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 280, ) +""" + + with pytest.raises(TypingError, match='u8.*280'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_not_subscriptable(): + code_py = """ +CONSTANT: u8 = 24 + +@exported +def testEntry() -> u8: + return CONSTANT[0] +""" + + with pytest.raises(TypingError, match='Type cannot be subscripted:'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_module_constant_type_mismatch_index_out_of_range(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 80, ) + +@exported +def testEntry() -> u8: + return CONSTANT[3] +""" + + with pytest.raises(TypingError, match='Type cannot be subscripted with index 3:'): + Suite(code_py).run_code() + @pytest.mark.integration_test def test_static_array_constant_too_few_values(): code_py = """ -- 2.49.0 From 4d3c0c6c3ce97347627266d9f56c5815811ff360 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 14:43:15 +0200 Subject: [PATCH 16/18] StaticArray with constant index works again Also, fix issue with f64 being parsed as f32 --- phasm/codestyle.py | 18 ++-- phasm/compiler.py | 84 ++++++++++++++++--- phasm/typer.py | 9 +- phasm/typing.py | 49 +++++++---- .../test_lang/test_static_array.py | 41 +++++++-- 5 files changed, 148 insertions(+), 53 deletions(-) diff --git a/phasm/codestyle.py b/phasm/codestyle.py index d57c1db..1af79ca 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -95,18 +95,16 @@ def expression(inp: ourlang.Expression) -> str: # return f'({args}, )' return f'{inp.function.name}({args})' - # - # if isinstance(inp, ourlang.AccessBytesIndex): - # return f'{expression(inp.varref)}[{expression(inp.index)}]' - # + + if isinstance(inp, ourlang.Subscript): + varref = expression(inp.varref) + index = expression(inp.index) + + return f'{varref}[{index}]' + + # TODO: Broken after new type system # if isinstance(inp, ourlang.AccessStructMember): # return f'{expression(inp.varref)}.{inp.member.name}' - # - # if isinstance(inp, (ourlang.AccessTupleMember, ourlang.AccessStaticArrayMember, )): - # if isinstance(inp.member, ourlang.Expression): - # return f'{expression(inp.varref)}[{expression(inp.member)}]' - # - # return f'{expression(inp.varref)}[{inp.member.idx}]' if isinstance(inp, ourlang.Fold): fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' diff --git a/phasm/compiler.py b/phasm/compiler.py index d17c197..a83c2cf 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -24,6 +24,9 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module: def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType: """ Compile: type + + Types are used for example in WebAssembly function parameters + and return types. """ assert inp is not None, typing.ASSERTION_ERROR @@ -52,6 +55,16 @@ def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType: if mtyp == 'f64': return wasm.WasmTypeFloat64() + assert inp is not None, typing.ASSERTION_ERROR + tc_prim = inp.get_constraint(typing.TypeConstraintPrimitive) + if tc_prim is None: + raise NotImplementedError(type_var, inp) + + if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + # StaticArray, Tuples and Structs are passed as pointer + # And pointers are i32 + return wasm.WasmTypeInt32() + # TODO: Broken after new type system # if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): # # Structs and tuples are passed as pointer @@ -161,7 +174,12 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: return if isinstance(inp.variable, ourlang.ModuleConstantDef): - # FIXME: Tuple / Static Array broken after new type system + assert inp.variable.type_var is not None, typing.ASSERTION_ERROR + tc_prim = inp.variable.type_var.get_constraint(typing.TypeConstraintPrimitive) + if tc_prim is None: + raise NotImplementedError(expression, inp, inp.variable.type_var) + + # TODO: Broken after new type system # if isinstance(inp.type, typing.TypeTuple): # assert isinstance(inp.definition.constant, ourlang.ConstantTuple) # assert inp.definition.data_block is not None, 'Combined values are memory stored' @@ -169,12 +187,12 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: # wgn.i32.const(inp.definition.data_block.address) # return # - # if isinstance(inp.type, typing.TypeStaticArray): - # assert isinstance(inp.definition.constant, ourlang.ConstantStaticArray) - # assert inp.definition.data_block is not None, 'Combined values are memory stored' - # assert inp.definition.data_block.address is not None, 'Value not allocated' - # wgn.i32.const(inp.definition.data_block.address) - # return + + if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + assert inp.variable.data_block is not None, 'Combined values are memory stored' + assert inp.variable.data_block.address is not None, 'Value not allocated' + wgn.i32.const(inp.variable.data_block.address) + return assert inp.variable.data_block is None, 'Primitives are not memory stored' @@ -276,6 +294,53 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.add_statement('call', '${}'.format(inp.function.name)) return + if isinstance(inp, ourlang.Subscript): + assert inp.varref.type_var is not None, typing.ASSERTION_ERROR + tc_prim = inp.varref.type_var.get_constraint(typing.TypeConstraintPrimitive) + if tc_prim is None: + raise NotImplementedError(expression, inp, inp.varref.type_var) + + if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + if not isinstance(inp.index, ourlang.ConstantPrimitive): + raise NotImplementedError(expression, inp, inp.index) + if not isinstance(inp.index.value, int): + raise NotImplementedError(expression, inp, inp.index.value) + + assert inp.type_var is not None, typing.ASSERTION_ERROR + mtyp = typing.simplify(inp.type_var) + if mtyp is None: + raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp) + + if mtyp == 'u8': + # u8 operations are done using i32, since WASM does not have u8 operations + mtyp = 'i32' + elif mtyp == 'u32': + # u32 operations are done using i32, using _u operations + mtyp = 'i32' + elif mtyp == 'u64': + # u64 operations are done using i64, using _u operations + mtyp = 'i64' + + tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript) + if tc_subs is None: + raise NotImplementedError(expression, inp, inp.varref.type_var) + + assert 0 < len(tc_subs.members) + tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth) + if tc_bits is None or len(tc_bits.oneof) > 1: + raise NotImplementedError(expression, inp, inp.varref.type_var) + + bitwidth = next(iter(tc_bits.oneof)) + if bitwidth % 8 != 0: + raise NotImplementedError(expression, inp, inp.varref.type_var) + + expression(wgn, inp.varref) + wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value)) + return + + raise NotImplementedError(expression, inp, inp.varref.type_var) + + # TODO: Broken after new type system # if isinstance(inp, ourlang.AccessBytesIndex): # if not isinstance(inp.type, typing.TypeUInt8): @@ -315,11 +380,6 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: # # as members of static arrays # raise NotImplementedError(expression, inp, inp.member) # - # if isinstance(inp.member, typing.TypeStaticArrayMember): - # expression(wgn, inp.varref) - # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(inp.member.offset)) - # return - # # expression(wgn, inp.varref) # expression(wgn, inp.member) # wgn.i32.const(inp.static_array.member_type.alloc_size()) diff --git a/phasm/typer.py b/phasm/typer.py index d18aa1c..7bc6434 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -142,14 +142,11 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': expression(ctx, inp.varref) assert inp.varref.type_var is not None - try: - # TODO: I'd much rather resolve this using the narrow functions - tc_subs = ctx.var_constraints[inp.varref.type_var.ctx_id][TypeConstraintSubscript] - except KeyError: + # TODO: I'd much rather resolve this using the narrow functions + tc_subs = inp.varref.type_var.get_constraint(TypeConstraintSubscript) + if tc_subs is None: raise TypingError(f'Type cannot be subscripted: {inp.varref.type_var}') from None - assert isinstance(tc_subs, TypeConstraintSubscript) # type hint - try: # TODO: I'd much rather resolve this using the narrow functions member = tc_subs.members[inp.index.value] diff --git a/phasm/typing.py b/phasm/typing.py index 25f85a7..84e5666 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -2,6 +2,7 @@ The phasm type system """ from typing import Callable, Dict, Iterable, Optional, List, Set, Type +from typing import TypeVar as MyPyTypeVar import enum import re @@ -168,6 +169,8 @@ class TypeConstraintPrimitive(TypeConstraintBase): INT = 0 FLOAT = 1 + STATIC_ARRAY = 10 + primitive: Primitive def __init__(self, primitive: Primitive) -> None: @@ -307,6 +310,8 @@ class TypeConstraintSubscript(TypeConstraintBase): def __repr__(self) -> str: return 'Subscript=(' + ','.join(map(repr, self.members)) + ')' +TTypeConstraintClass = MyPyTypeVar('TTypeConstraintClass', bound=TypeConstraintBase) + class TypeVar: """ A type variable @@ -329,15 +334,22 @@ class TypeVar: else: csts[newconst.__class__] = newconst + def get_constraint(self, const_type: Type[TTypeConstraintClass]) -> Optional[TTypeConstraintClass]: + csts = self.ctx.var_constraints[self.ctx_id] + + res = csts.get(const_type, None) + assert res is None or isinstance(res, const_type) # type hint + return res + def add_location(self, ref: str) -> None: - self.ctx.var_locations[self.ctx_id].append(ref) + self.ctx.var_locations[self.ctx_id].add(ref) def __repr__(self) -> str: return ( 'TypeVar<' + '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values())) + '; locations: ' - + ', '.join(self.ctx.var_locations[self.ctx_id]) + + ', '.join(sorted(self.ctx.var_locations[self.ctx_id])) + '>' ) @@ -356,7 +368,7 @@ class Context: # Store the TypeVar properties as a lookup # so we can update these when unifying self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {} - self.var_locations: Dict[int, List[str]] = {} + self.var_locations: Dict[int, Set[str]] = {} def new_var(self) -> TypeVar: ctx_id = self.next_ctx_id @@ -366,7 +378,7 @@ class Context: self.vars_by_id[ctx_id] = [result] self.var_constraints[ctx_id] = {} - self.var_locations[ctx_id] = [] + self.var_locations[ctx_id] = set() return result @@ -395,8 +407,7 @@ class Context: except TypingNarrowProtoError as exc: raise TypingNarrowError(l, r, str(exc)) from None - self.var_locations[n.ctx_id].extend(self.var_locations[l_ctx_id]) - self.var_locations[n.ctx_id].extend(self.var_locations[r_ctx_id]) + self.var_locations[n.ctx_id] = self.var_locations[l_ctx_id] | self.var_locations[r_ctx_id] # ## # And unify (or entangle) the old ones @@ -424,22 +435,18 @@ def simplify(inp: TypeVar) -> Optional[str]: Should round trip with from_str """ - tc_prim = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintPrimitive) - tc_bits = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintBitWidth) - tc_sign = inp.ctx.var_constraints[inp.ctx_id].get(TypeConstraintSigned) + tc_prim = inp.get_constraint(TypeConstraintPrimitive) + tc_bits = inp.get_constraint(TypeConstraintBitWidth) + tc_sign = inp.get_constraint(TypeConstraintSigned) if tc_prim is None: return None - assert isinstance(tc_prim, TypeConstraintPrimitive) # type hint primitive = tc_prim.primitive if primitive is TypeConstraintPrimitive.Primitive.INT: if tc_bits is None or tc_sign is None: return None - assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint - assert isinstance(tc_sign, TypeConstraintSigned) # type hint - if tc_sign.signed is None or len(tc_bits.oneof) != 1: return None @@ -454,8 +461,6 @@ def simplify(inp: TypeVar) -> Optional[str]: if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint return None - assert isinstance(tc_bits, TypeConstraintBitWidth) # type hint - if len(tc_bits.oneof) != 1: return None @@ -465,6 +470,17 @@ def simplify(inp: TypeVar) -> Optional[str]: return f'f{bitwidth}' + if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + tc_subs = inp.get_constraint(TypeConstraintSubscript) + assert tc_subs is not None + assert tc_subs.members + + sab = simplify(tc_subs.members[0]) + if sab is None: + return None + + return f'{sab}[{len(tc_subs.members)}]' + return None def make_u8(ctx: Context, location: str) -> TypeVar: @@ -538,7 +554,7 @@ def make_f64(ctx: Context, location: str) -> TypeVar: """ result = ctx.new_var() result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) - result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) + result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_location(location) return result @@ -573,6 +589,7 @@ def from_str(ctx: Context, inp: str, location: str) -> TypeVar: if match: result = ctx.new_var() + result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY)) result.add_constraint(TypeConstraintSubscript(members=( # Make copies so they don't get entangled # with each other. diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py index 6ea5985..5708fb1 100644 --- a/tests/integration/test_lang/test_static_array.py +++ b/tests/integration/test_lang/test_static_array.py @@ -2,7 +2,9 @@ import pytest from phasm.exceptions import StaticError, TypingError -from ..constants import ALL_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP +from ..constants import ( + ALL_FLOAT_TYPES, ALL_INT_TYPES, COMPLETE_INT_TYPES, COMPLETE_PRIMITIVE_TYPES, TYPE_MAP +) from ..helpers import Suite @pytest.mark.integration_test @@ -22,6 +24,7 @@ def testEntry() -> {type_}: assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test +@pytest.mark.skip('To decide: What to do on out of index?') @pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) def test_static_array_indexed(type_): code_py = f""" @@ -41,8 +44,8 @@ def helper(array: {type_}[3], i0: u32, i1: u32, i2: u32) -> {type_}: assert TYPE_MAP[type_] == type(result.returned_value) @pytest.mark.integration_test -@pytest.mark.parametrize('type_', COMPLETE_PRIMITIVE_TYPES) -def test_function_call(type_): +@pytest.mark.parametrize('type_', COMPLETE_INT_TYPES) +def test_function_call_int(type_): code_py = f""" CONSTANT: {type_}[3] = (24, 57, 80, ) @@ -59,6 +62,25 @@ def helper(array: {type_}[3]) -> {type_}: assert 161 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +@pytest.mark.parametrize('type_', ALL_FLOAT_TYPES) +def test_function_call_float(type_): + code_py = f""" +CONSTANT: {type_}[3] = (24.0, 57.5, 80.75, ) + +@exported +def testEntry() -> {type_}: + return helper(CONSTANT) + +def helper(array: {type_}[3]) -> {type_}: + return array[0] + array[1] + array[2] +""" + + result = Suite(code_py).run_code() + + assert 162.25 == result.returned_value + assert TYPE_MAP[type_] == type(result.returned_value) + @pytest.mark.integration_test def test_module_constant_type_mismatch_bitwidth(): code_py = """ @@ -100,8 +122,8 @@ def test_static_array_constant_too_few_values(): CONSTANT: u8[3] = (24, 57, ) """ - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): - phasm_parse(code_py) + with pytest.raises(TypingError, match='Member count does not match'): + Suite(code_py).run_code() @pytest.mark.integration_test def test_static_array_constant_too_many_values(): @@ -109,8 +131,8 @@ def test_static_array_constant_too_many_values(): CONSTANT: u8[3] = (24, 57, 1, 1, ) """ - with pytest.raises(StaticError, match='Static error on line 2: Invalid number of static array values'): - phasm_parse(code_py) + with pytest.raises(TypingError, match='Member count does not match'): + Suite(code_py).run_code() @pytest.mark.integration_test def test_static_array_constant_type_mismatch(): @@ -118,10 +140,11 @@ def test_static_array_constant_type_mismatch(): CONSTANT: u8[3] = (24, 4000, 1, ) """ - with pytest.raises(StaticError, match='Static error on line 2: Integer value out of range; expected 0..255, actual 4000'): - phasm_parse(code_py) + with pytest.raises(TypingError, match='u8.*4000'): + Suite(code_py).run_code() @pytest.mark.integration_test +@pytest.mark.skip('To decide: What to do on out of index?') def test_static_array_index_out_of_bounds(): code_py = """ CONSTANT0: u32[3] = (24, 57, 80, ) -- 2.49.0 From 2a6da91eb9e27978b3b3893f17971974c27c1efa Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 14:53:22 +0200 Subject: [PATCH 17/18] Simplified locations, adds typing tests --- phasm/typer.py | 2 +- phasm/typing.py | 43 +++++++++++-------- .../integration/test_lang/test_primitives.py | 11 +++++ .../test_lang/test_static_array.py | 12 ++++++ 4 files changed, 49 insertions(+), 19 deletions(-) diff --git a/phasm/typer.py b/phasm/typer.py index 7bc6434..56b1f05 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -172,7 +172,7 @@ def module_constant_def(ctx: Context, inp: ourlang.ModuleConstantDef) -> None: if inp.type_str is None: inp.type_var = ctx.new_var() else: - inp.type_var = from_str(ctx, inp.type_str, inp.type_str) + inp.type_var = from_str(ctx, inp.type_str) assert inp.constant.type_var is not None ctx.unify(inp.type_var, inp.constant.type_var) diff --git a/phasm/typing.py b/phasm/typing.py index 84e5666..468e537 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -483,7 +483,7 @@ def simplify(inp: TypeVar) -> Optional[str]: return None -def make_u8(ctx: Context, location: str) -> TypeVar: +def make_u8(ctx: Context) -> TypeVar: """ Makes a u8 TypeVar """ @@ -491,10 +491,10 @@ def make_u8(ctx: Context, location: str) -> TypeVar: result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) + result.add_location('u8') return result -def make_u32(ctx: Context, location: str) -> TypeVar: +def make_u32(ctx: Context) -> TypeVar: """ Makes a u32 TypeVar """ @@ -502,10 +502,10 @@ def make_u32(ctx: Context, location: str) -> TypeVar: result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) + result.add_location('u32') return result -def make_u64(ctx: Context, location: str) -> TypeVar: +def make_u64(ctx: Context) -> TypeVar: """ Makes a u64 TypeVar """ @@ -513,10 +513,10 @@ def make_u64(ctx: Context, location: str) -> TypeVar: result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(False)) - result.add_location(location) + result.add_location('u64') return result -def make_i32(ctx: Context, location: str) -> TypeVar: +def make_i32(ctx: Context) -> TypeVar: """ Makes a i32 TypeVar """ @@ -524,10 +524,10 @@ def make_i32(ctx: Context, location: str) -> TypeVar: result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(True)) - result.add_location(location) + result.add_location('i32') return result -def make_i64(ctx: Context, location: str) -> TypeVar: +def make_i64(ctx: Context) -> TypeVar: """ Makes a i64 TypeVar """ @@ -535,30 +535,30 @@ def make_i64(ctx: Context, location: str) -> TypeVar: result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(True)) - result.add_location(location) + result.add_location('i64') return result -def make_f32(ctx: Context, location: str) -> TypeVar: +def make_f32(ctx: Context) -> TypeVar: """ Makes a f32 TypeVar """ result = ctx.new_var() result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) - result.add_location(location) + result.add_location('f32') return result -def make_f64(ctx: Context, location: str) -> TypeVar: +def make_f64(ctx: Context) -> TypeVar: """ Makes a f64 TypeVar """ result = ctx.new_var() result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) - result.add_location(location) + result.add_location('f64') return result -BUILTIN_TYPES: Dict[str, Callable[[Context, str], TypeVar]] = { +BUILTIN_TYPES: Dict[str, Callable[[Context], TypeVar]] = { 'u8': make_u8, 'u32': make_u32, 'u64': make_u64, @@ -570,7 +570,7 @@ BUILTIN_TYPES: Dict[str, Callable[[Context, str], TypeVar]] = { TYPE_MATCH_STATIC_ARRAY = re.compile(r'^([uif][0-9]+)\[([0-9]+)\]') -def from_str(ctx: Context, inp: str, location: str) -> TypeVar: +def from_str(ctx: Context, inp: str, location: Optional[str] = None) -> TypeVar: """ Creates a new TypeVar from the string @@ -583,7 +583,10 @@ def from_str(ctx: Context, inp: str, location: str) -> TypeVar: with the context creation. """ if inp in BUILTIN_TYPES: - return BUILTIN_TYPES[inp](ctx, location) + result = BUILTIN_TYPES[inp](ctx) + if location is not None: + result.add_location(location) + return result match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp) if match: @@ -596,7 +599,11 @@ def from_str(ctx: Context, inp: str, location: str) -> TypeVar: from_str(ctx, match[1], match[1]) for _ in range(int(match[2])) ))) - result.add_location(location) + + result.add_location(inp) + + if location is not None: + result.add_location(location) return result diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py index 5736f0b..076b308 100644 --- a/tests/integration/test_lang/test_primitives.py +++ b/tests/integration/test_lang/test_primitives.py @@ -33,6 +33,17 @@ def testEntry() -> {type_}: assert 32.125 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) +@pytest.mark.integration_test +def test_expr_constant_entanglement(): + code_py = """ +@exported +def testEntry() -> u8: + return 1000 +""" + + with pytest.raises(TypingError, match='u8.*1000'): + Suite(code_py).run_code() + @pytest.mark.integration_test @pytest.mark.parametrize('type_', ALL_INT_TYPES) def test_module_constant_int(type_): diff --git a/tests/integration/test_lang/test_static_array.py b/tests/integration/test_lang/test_static_array.py index 5708fb1..9f549da 100644 --- a/tests/integration/test_lang/test_static_array.py +++ b/tests/integration/test_lang/test_static_array.py @@ -90,6 +90,18 @@ CONSTANT: u8[3] = (24, 57, 280, ) with pytest.raises(TypingError, match='u8.*280'): Suite(code_py).run_code() +@pytest.mark.integration_test +def test_return_as_int(): + code_py = """ +CONSTANT: u8[3] = (24, 57, 80, ) + +def testEntry() -> u32: + return CONSTANT +""" + + with pytest.raises(TypingError, match=r'u32.*u8\[3\]'): + Suite(code_py).run_code() + @pytest.mark.integration_test def test_module_constant_type_mismatch_not_subscriptable(): code_py = """ -- 2.49.0 From 977c449c3fd9600c6fab291ec14aba44ae10358e Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Mon, 19 Sep 2022 14:55:05 +0200 Subject: [PATCH 18/18] Removed redundant file --- .../integration/test_lang/test_primitives.py | 14 +++++++++ .../integration/test_lang/test_type_checks.py | 31 ------------------- 2 files changed, 14 insertions(+), 31 deletions(-) delete mode 100644 tests/integration/test_lang/test_type_checks.py diff --git a/tests/integration/test_lang/test_primitives.py b/tests/integration/test_lang/test_primitives.py index 076b308..01df9cf 100644 --- a/tests/integration/test_lang/test_primitives.py +++ b/tests/integration/test_lang/test_primitives.py @@ -360,3 +360,17 @@ def helper(left: {type_}, right: {type_}) -> {type_}: assert 32.125 == result.returned_value assert TYPE_MAP[type_] == type(result.returned_value) + +@pytest.mark.integration_test +def test_call_invalid_type(): + code_py = """ +def helper() -> i64: + return 19 + +@exported +def testEntry() -> i32: + return helper() +""" + + with pytest.raises(TypingError, match=r'i32.*i64'): + Suite(code_py).run_code() diff --git a/tests/integration/test_lang/test_type_checks.py b/tests/integration/test_lang/test_type_checks.py deleted file mode 100644 index 1389c2f..0000000 --- a/tests/integration/test_lang/test_type_checks.py +++ /dev/null @@ -1,31 +0,0 @@ -import pytest - -from phasm.parser import phasm_parse -from phasm.typer import phasm_type -from phasm.exceptions import TypingError - -@pytest.mark.integration_test -def test_constant_too_wide(): - code_py = """ -def func_const() -> u8: - return 0xFFF -""" - - ast = phasm_parse(code_py) - with pytest.raises(TypingError, match='Other min bitwidth exceeds max bitwidth'): - phasm_type(ast) - -@pytest.mark.integration_test -@pytest.mark.parametrize('type_', [32, 64]) -def test_signed_mismatch(type_): - code_py = f""" -def func_const() -> u{type_}: - return 0 - -def func_call() -> i{type_}: - return func_const() -""" - - ast = phasm_parse(code_py) - with pytest.raises(TypingError, match='Signed does not match'): - phasm_type(ast) -- 2.49.0