From 42c9ff6ca70a3a9fb6827a55168398d4fbebdb5e Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sun, 23 Oct 2022 13:52:48 +0200 Subject: [PATCH] idea: Actual type class --- phasm/compiler.py | 104 +++++++++++++++---------------- phasm/typer.py | 10 +-- phasm/typing.py | 151 ++++++++++++++++++++++------------------------ 3 files changed, 128 insertions(+), 137 deletions(-) diff --git a/phasm/compiler.py b/phasm/compiler.py index 22593b8..74b55d3 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -56,14 +56,14 @@ def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType: return wasm.WasmTypeFloat64() assert inp is not None, typing.ASSERTION_ERROR - tc_prim = inp.get_constraint(typing.TypeConstraintPrimitive) - if tc_prim is None: + tc_type = inp.get_type() + if tc_type is None: raise NotImplementedError(type_var, inp) - if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: - # StaticArray, Tuples and Structs are passed as pointer - # And pointers are i32 - return wasm.WasmTypeInt32() + # if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + # # StaticArray, Tuples and Structs are passed as pointer + # # And pointers are i32 + # return wasm.WasmTypeInt32() # TODO: Broken after new type system # if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): @@ -187,8 +187,8 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: if isinstance(inp.variable, ourlang.ModuleConstantDef): assert inp.variable.type_var is not None, typing.ASSERTION_ERROR - tc_prim = inp.variable.type_var.get_constraint(typing.TypeConstraintPrimitive) - if tc_prim is None: + tc_type = inp.variable.type_var.get_type() + if tc_type is None: raise NotImplementedError(expression, inp, inp.variable.type_var) # TODO: Broken after new type system @@ -200,11 +200,11 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: # return # - if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: - assert inp.variable.data_block is not None, 'Combined values are memory stored' - assert inp.variable.data_block.address is not None, 'Value not allocated' - wgn.i32.const(inp.variable.data_block.address) - return + # if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + # assert inp.variable.data_block is not None, 'Combined values are memory stored' + # assert inp.variable.data_block.address is not None, 'Value not allocated' + # wgn.i32.const(inp.variable.data_block.address) + # return assert inp.variable.data_block is None, 'Primitives are not memory stored' @@ -314,47 +314,47 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: if isinstance(inp, ourlang.Subscript): assert inp.varref.type_var is not None, typing.ASSERTION_ERROR - tc_prim = inp.varref.type_var.get_constraint(typing.TypeConstraintPrimitive) - if tc_prim is None: + tc_type = inp.varref.type_var.get_type() + if tc_type is None: raise NotImplementedError(expression, inp, inp.varref.type_var) - if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: - if not isinstance(inp.index, ourlang.ConstantPrimitive): - raise NotImplementedError(expression, inp, inp.index) - if not isinstance(inp.index.value, int): - raise NotImplementedError(expression, inp, inp.index.value) - - assert inp.type_var is not None, typing.ASSERTION_ERROR - mtyp = typing.simplify(inp.type_var) - if mtyp is None: - raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp) - - if mtyp == 'u8': - # u8 operations are done using i32, since WASM does not have u8 operations - mtyp = 'i32' - elif mtyp == 'u32': - # u32 operations are done using i32, using _u operations - mtyp = 'i32' - elif mtyp == 'u64': - # u64 operations are done using i64, using _u operations - mtyp = 'i64' - - tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript) - if tc_subs is None: - raise NotImplementedError(expression, inp, inp.varref.type_var) - - assert 0 < len(tc_subs.members) - tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth) - if tc_bits is None or len(tc_bits.oneof) > 1: - raise NotImplementedError(expression, inp, inp.varref.type_var) - - bitwidth = next(iter(tc_bits.oneof)) - if bitwidth % 8 != 0: - raise NotImplementedError(expression, inp, inp.varref.type_var) - - expression(wgn, inp.varref) - wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value)) - return + # if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + # if not isinstance(inp.index, ourlang.ConstantPrimitive): + # raise NotImplementedError(expression, inp, inp.index) + # if not isinstance(inp.index.value, int): + # raise NotImplementedError(expression, inp, inp.index.value) + # + # assert inp.type_var is not None, typing.ASSERTION_ERROR + # mtyp = typing.simplify(inp.type_var) + # if mtyp is None: + # raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp) + # + # if mtyp == 'u8': + # # u8 operations are done using i32, since WASM does not have u8 operations + # mtyp = 'i32' + # elif mtyp == 'u32': + # # u32 operations are done using i32, using _u operations + # mtyp = 'i32' + # elif mtyp == 'u64': + # # u64 operations are done using i64, using _u operations + # mtyp = 'i64' + # + # tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript) + # if tc_subs is None: + # raise NotImplementedError(expression, inp, inp.varref.type_var) + # + # assert 0 < len(tc_subs.members) + # tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth) + # if tc_bits is None or len(tc_bits.oneof) > 1: + # raise NotImplementedError(expression, inp, inp.varref.type_var) + # + # bitwidth = next(iter(tc_bits.oneof)) + # if bitwidth % 8 != 0: + # raise NotImplementedError(expression, inp, inp.varref.type_var) + # + # expression(wgn, inp.varref) + # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value)) + # return raise NotImplementedError(expression, inp, inp.varref.type_var) diff --git a/phasm/typer.py b/phasm/typer.py index 5081dac..abfb370 100644 --- a/phasm/typer.py +++ b/phasm/typer.py @@ -6,7 +6,8 @@ from . import ourlang from .exceptions import TypingError from .typing import ( Context, - TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeConstraintSubscript, + PhasmTypeInteger, PhasmTypeReal, + TypeConstraintBitWidth, TypeConstraintSigned, TypeConstraintSubscript, TypeVar, from_str, ) @@ -19,7 +20,7 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar: result = ctx.new_var() if isinstance(inp.value, int): - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result.set_type(PhasmTypeInteger) # Need at least this many bits to store this constant value result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) @@ -34,7 +35,7 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar: return result if isinstance(inp.value, float): - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result.set_type(PhasmTypeReal) # We don't have fancy logic here to detect if the float constant # fits in the given type. There a number of edge cases to consider, @@ -107,8 +108,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar': return left if inp.operator in ('<<', '>>', ): - inp.type_var = ctx.new_var() - inp.type_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + inp.type_var = ctx.new_var(PhasmTypeInteger) inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, ))) inp.type_var.add_constraint(TypeConstraintSigned(False)) diff --git a/phasm/typing.py b/phasm/typing.py index 468e537..c9ab59a 100644 --- a/phasm/typing.py +++ b/phasm/typing.py @@ -131,6 +131,18 @@ class TypeStruct(TypeBase): # back to the AST. If so, we need to fix the typer. ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method' +class PhasmType: + __slots__ = ('name', ) + name: str + + def __init__(self, name: str) -> None: + self.name = name + + def __repr__(self) -> str: + return 'PhasmType' + self.name + +PhasmTypeInteger = PhasmType('Integer') +PhasmTypeReal = PhasmType('Real') class TypingNarrowProtoError(TypingError): """ @@ -156,38 +168,6 @@ class TypeConstraintBase: def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase': raise NotImplementedError('narrow', self, other) -class TypeConstraintPrimitive(TypeConstraintBase): - """ - This contraint on a type defines its primitive shape - """ - __slots__ = ('primitive', ) - - class Primitive(enum.Enum): - """ - The primitive ID - """ - INT = 0 - FLOAT = 1 - - STATIC_ARRAY = 10 - - primitive: Primitive - - def __init__(self, primitive: Primitive) -> None: - self.primitive = primitive - - def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive': - if not isinstance(other, TypeConstraintPrimitive): - raise Exception('Invalid comparison') - - if self.primitive != other.primitive: - raise TypingNarrowProtoError('Primitive does not match') - - return TypeConstraintPrimitive(self.primitive) - - def __repr__(self) -> str: - return f'Primitive={self.primitive.name}' - class TypeConstraintSigned(TypeConstraintBase): """ Contraint on whether a signed value can be used or not, or whether @@ -326,6 +306,13 @@ class TypeVar: self.ctx = ctx self.ctx_id = ctx_id + def get_type(self) -> Optional[PhasmType]: + return self.ctx.var_types[self.ctx_id] + + def set_type(self, type_: PhasmType) -> None: + assert self.ctx.var_types[self.ctx_id] is None, 'Type already set' + self.ctx.var_types[self.ctx_id] = type_ + def add_constraint(self, newconst: TypeConstraintBase) -> None: csts = self.ctx.var_constraints[self.ctx_id] @@ -345,8 +332,12 @@ class TypeVar: self.ctx.var_locations[self.ctx_id].add(ref) def __repr__(self) -> str: + typ = self.ctx.var_types[self.ctx_id] + return ( 'TypeVar<' + + ('?' if typ is None else repr(typ)) + + '; ' + '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values())) + '; locations: ' + ', '.join(sorted(self.ctx.var_locations[self.ctx_id])) @@ -367,16 +358,18 @@ class Context: # Store the TypeVar properties as a lookup # so we can update these when unifying + self.var_types: Dict[int, Optional[PhasmType]] = {} self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {} self.var_locations: Dict[int, Set[str]] = {} - def new_var(self) -> TypeVar: + def new_var(self, type_: Optional[PhasmType] = None) -> TypeVar: ctx_id = self.next_ctx_id self.next_ctx_id += 1 result = TypeVar(self, ctx_id) self.vars_by_id[ctx_id] = [result] + self.var_types[ctx_id] = type_ self.var_constraints[ctx_id] = {} self.var_locations[ctx_id] = set() @@ -393,12 +386,19 @@ class Context: # Backup some values that we'll overwrite l_ctx_id = l.ctx_id r_ctx_id = r.ctx_id + l_type = self.var_types[l_ctx_id] + r_type = self.var_types[r_ctx_id] l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id] # Create a new TypeVar, with the combined contraints # and locations of the old ones n = self.new_var() + if l_type is not None and r_type is not None and l_type != r_type: + raise TypingNarrowError(l, r, 'Type does not match') + else: + self.var_types[n.ctx_id] = l_type + try: for const in self.var_constraints[l_ctx_id].values(): n.add_constraint(const) @@ -435,15 +435,13 @@ def simplify(inp: TypeVar) -> Optional[str]: Should round trip with from_str """ - tc_prim = inp.get_constraint(TypeConstraintPrimitive) tc_bits = inp.get_constraint(TypeConstraintBitWidth) tc_sign = inp.get_constraint(TypeConstraintSigned) - if tc_prim is None: + if inp.get_type() is None: return None - primitive = tc_prim.primitive - if primitive is TypeConstraintPrimitive.Primitive.INT: + if inp.get_type() is PhasmTypeInteger: if tc_bits is None or tc_sign is None: return None @@ -457,7 +455,7 @@ def simplify(inp: TypeVar) -> Optional[str]: base = 'i' if tc_sign.signed else 'u' return f'{base}{bitwidth}' - if primitive is TypeConstraintPrimitive.Primitive.FLOAT: + if inp.get_type() is PhasmTypeReal: if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint return None @@ -470,16 +468,16 @@ def simplify(inp: TypeVar) -> Optional[str]: return f'f{bitwidth}' - if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY: - tc_subs = inp.get_constraint(TypeConstraintSubscript) - assert tc_subs is not None - assert tc_subs.members - - sab = simplify(tc_subs.members[0]) - if sab is None: - return None - - return f'{sab}[{len(tc_subs.members)}]' + # if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY: + # tc_subs = inp.get_constraint(TypeConstraintSubscript) + # assert tc_subs is not None + # assert tc_subs.members + # + # sab = simplify(tc_subs.members[0]) + # if sab is None: + # return None + # + # return f'{sab}[{len(tc_subs.members)}]' return None @@ -487,8 +485,7 @@ def make_u8(ctx: Context) -> TypeVar: """ Makes a u8 TypeVar """ - result = ctx.new_var() - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result = ctx.new_var(PhasmTypeInteger) result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) result.add_constraint(TypeConstraintSigned(False)) result.add_location('u8') @@ -498,8 +495,7 @@ def make_u32(ctx: Context) -> TypeVar: """ Makes a u32 TypeVar """ - result = ctx.new_var() - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result = ctx.new_var(PhasmTypeInteger) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(False)) result.add_location('u32') @@ -509,8 +505,7 @@ def make_u64(ctx: Context) -> TypeVar: """ Makes a u64 TypeVar """ - result = ctx.new_var() - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result = ctx.new_var(PhasmTypeInteger) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(False)) result.add_location('u64') @@ -520,8 +515,7 @@ def make_i32(ctx: Context) -> TypeVar: """ Makes a i32 TypeVar """ - result = ctx.new_var() - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result = ctx.new_var(PhasmTypeInteger) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintSigned(True)) result.add_location('i32') @@ -531,8 +525,7 @@ def make_i64(ctx: Context) -> TypeVar: """ Makes a i64 TypeVar """ - result = ctx.new_var() - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) + result = ctx.new_var(PhasmTypeInteger) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintSigned(True)) result.add_location('i64') @@ -542,8 +535,7 @@ def make_f32(ctx: Context) -> TypeVar: """ Makes a f32 TypeVar """ - result = ctx.new_var() - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result = ctx.new_var(PhasmTypeReal) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_location('f32') return result @@ -552,8 +544,7 @@ def make_f64(ctx: Context) -> TypeVar: """ Makes a f64 TypeVar """ - result = ctx.new_var() - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) + result = ctx.new_var(PhasmTypeReal) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_location('f64') return result @@ -588,23 +579,23 @@ def from_str(ctx: Context, inp: str, location: Optional[str] = None) -> TypeVar: result.add_location(location) return result - match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp) - if match: - result = ctx.new_var() - - result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY)) - result.add_constraint(TypeConstraintSubscript(members=( - # Make copies so they don't get entangled - # with each other. - from_str(ctx, match[1], match[1]) - for _ in range(int(match[2])) - ))) - - result.add_location(inp) - - if location is not None: - result.add_location(location) - - return result + # match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp) + # if match: + # result = ctx.new_var() + # + # result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY)) + # result.add_constraint(TypeConstraintSubscript(members=( + # # Make copies so they don't get entangled + # # with each other. + # from_str(ctx, match[1], match[1]) + # for _ in range(int(match[2])) + # ))) + # + # result.add_location(inp) + # + # if location is not None: + # result.add_location(location) + # + # return result raise NotImplementedError(from_str, inp)