diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 5f19d2e..3338ce4 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -21,18 +21,22 @@ from .typing import ( TypeTuple, TypeTupleMember, TypeStaticArray, TypeStaticArrayMember, TypeStruct, TypeStructMember, + + TypeVar, ) class Expression: """ An expression within a statement """ - __slots__ = ('type', ) + __slots__ = ('type', 'type_var', ) type: TypeBase + type_var: Optional[TypeVar] def __init__(self, type_: TypeBase) -> None: self.type = type_ + self.type_var = None class Constant(Expression): """ diff --git a/phasm/typer.py b/phasm/typer.py new file mode 100644 index 0000000..d21f823 --- /dev/null +++ b/phasm/typer.py @@ -0,0 +1,71 @@ +""" +Type checks and enriches the given ast +""" +from math import ceil, log2 + +from . import ourlang + +from .typing import Context, TypeConstraintBitWidth, TypeConstraintSigned, TypeVar + +def phasm_type(inp: ourlang.Module) -> None: + module(inp) + +def constant(ctx: 'Context', inp: ourlang.Constant) -> 'TypeVar': + if getattr(inp, 'value', int): + result = ctx.new_var() + + # Need at least this many bits to store this constant value + result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) # type: ignore + # Don't dictate anything about signedness - you can use a signed + # constant in an unsigned variable if the bits fit + result.add_constraint(TypeConstraintSigned(None)) + + result.add_location(str(inp.value)) # type: ignore + + inp.type_var = result + + return result + + raise NotImplementedError(constant, inp) + +def expression(ctx: 'Context', inp: ourlang.Expression) -> 'TypeVar': + if isinstance(inp, ourlang.Constant): + return constant(ctx, inp) + + if isinstance(inp, ourlang.BinaryOp): + left = expression(ctx, inp.left) + right = expression(ctx, inp.right) + ctx.unify(left, right) + return left + + raise NotImplementedError(expression, inp) + +def function(ctx: 'Context', inp: ourlang.Function) -> None: + bctx = ctx.clone() # Clone whenever we go into a block + + assert len(inp.statements) == 1 # TODO + + assert isinstance(inp.statements[0], ourlang.StatementReturn) + typ = expression(ctx, inp.statements[0].value) + + ctx.unify(_convert_old_type(ctx, inp.returns), typ) + return + +def module(inp: ourlang.Module) -> None: + ctx = Context() + + for func in inp.functions.values(): + function(ctx, func) + +from . import typing + +def _convert_old_type(ctx: Context, inp: typing.TypeBase) -> TypeVar: + result = ctx.new_var() + + if isinstance(inp, typing.TypeUInt32): + result.add_constraint(TypeConstraintBitWidth(maxb=32)) + result.add_constraint(TypeConstraintSigned(False)) + result.add_location('u32') + return result + + raise NotImplementedError(_convert_old_type, inp) diff --git a/phasm/typing.py b/phasm/typing.py index e56f7a9..0cb213d 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -1,7 +1,7 @@ """ The phasm type system """ -from typing import Optional, List +from typing import Dict, Optional, List, Type class TypeBase: """ @@ -200,3 +200,133 @@ class TypeStruct(TypeBase): x.type.alloc_size() for x in self.members ) + +## NEW STUFF BELOW + +class TypingError(Exception): + pass + +class TypingNarrowProtoError(TypingError): + pass + +class TypingNarrowError(TypingError): + def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None: + super().__init__( + f'Cannot narrow types {l} and {r}: {msg}' + ) + +class TypeConstraintBase: + def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBase': + raise NotImplementedError('narrow', self, other) + +class TypeConstraintSigned(TypeConstraintBase): + __slots__ = ('signed', ) + + signed: Optional[bool] + + def __init__(self, signed: Optional[bool]) -> None: + self.signed = signed + + def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintSigned': + if not isinstance(other, TypeConstraintSigned): + raise Exception('Invalid comparison') + + if other.signed is None: + return TypeConstraintSigned(self.signed) + if self.signed is None: + return TypeConstraintSigned(other.signed) + + if self.signed is not other.signed: + raise TypeError() + + return TypeConstraintSigned(self.signed) + + def __repr__(self) -> str: + return f'Signed={self.signed}' + +class TypeConstraintBitWidth(TypeConstraintBase): + __slots__ = ('minb', 'maxb', ) + + minb: int + maxb: int + + def __init__(self, *, minb: int = 1, maxb: int = 64) -> None: + assert minb is not None or maxb is not None + assert maxb <= 64 # For now, support up to 64 bits values + + self.minb = minb + self.maxb = maxb + + def narrow(self, other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth': + if not isinstance(other, TypeConstraintBitWidth): + raise Exception('Invalid comparison') + + if self.minb > other.maxb: + raise TypingNarrowProtoError('Min bitwidth exceeds other max bitwidth') + + if other.minb > self.maxb: + raise TypingNarrowProtoError('Other min bitwidth exceeds max bitwidth') + + return TypeConstraintBitWidth( + minb=max(self.minb, other.minb), + maxb=min(self.maxb, other.maxb), + ) + + def __repr__(self) -> str: + return f'BitWidth={self.minb}..{self.maxb}' + +class TypeVar: + def __init__(self, ctx: 'Context') -> None: + self.context = ctx + self.constraints: Dict[Type[TypeConstraintBase], TypeConstraintBase] = {} + self.locations: List[str] = [] + + def add_constraint(self, newconst: TypeConstraintBase) -> None: + if newconst.__class__ in self.constraints: + self.constraints[newconst.__class__] = self.constraints[newconst.__class__].narrow(newconst) + else: + self.constraints[newconst.__class__] = newconst + + def add_location(self, ref: str) -> None: + self.locations.append(ref) + + def __repr__(self) -> str: + return ( + 'TypeVar<' + + '; '.join(map(repr, self.constraints.values())) + + '; locations: ' + + ', '.join(self.locations) + + '>' + ) + +class Context: + def clone(self) -> 'Context': + return self # TODO: STUB + + def new_var(self) -> TypeVar: + return TypeVar(self) + + def unify(self, l: 'TypeVar', r: 'TypeVar') -> None: + newtypevar = self.new_var() + + try: + for const in l.constraints.values(): + newtypevar.add_constraint(const) + for const in r.constraints.values(): + newtypevar.add_constraint(const) + except TypingNarrowProtoError as ex: + raise TypingNarrowError(l, r, str(ex)) from None + + newtypevar.locations.extend(l.locations) + newtypevar.locations.extend(r.locations) + + # Make pointer locations to the constraints and locations + # so they get linked together throughout the unification + + l.constraints = newtypevar.constraints + l.locations = newtypevar.locations + + r.constraints = newtypevar.constraints + r.locations = newtypevar.locations + + return diff --git a/tests/integration/runners.py b/tests/integration/runners.py index fd3a53e..57d2adb 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, TextIO import ctypes import io +import warnings import pywasm.binary import wasm3 @@ -13,6 +14,7 @@ import wasmtime from phasm.compiler import phasm_compile from phasm.parser import phasm_parse +from phasm.typer import phasm_type from phasm import ourlang from phasm import wasm @@ -40,6 +42,10 @@ class RunnerBase: Parses the Phasm code into an AST """ self.phasm_ast = phasm_parse(self.phasm_code) + try: + phasm_type(self.phasm_ast) + except NotImplementedError as exc: + warnings.warn(f'phash_type throws an NotImplementedError on this test: {exc}') def compile_ast(self) -> None: """