idea: Actual type class

This commit is contained in:
Johan B.W. de Vries 2022-10-23 13:52:48 +02:00
parent 312f7949bd
commit 42c9ff6ca7
3 changed files with 128 additions and 137 deletions

View File

@ -56,14 +56,14 @@ def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
return wasm.WasmTypeFloat64() return wasm.WasmTypeFloat64()
assert inp is not None, typing.ASSERTION_ERROR assert inp is not None, typing.ASSERTION_ERROR
tc_prim = inp.get_constraint(typing.TypeConstraintPrimitive) tc_type = inp.get_type()
if tc_prim is None: if tc_type is None:
raise NotImplementedError(type_var, inp) raise NotImplementedError(type_var, inp)
if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: # if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# StaticArray, Tuples and Structs are passed as pointer # # StaticArray, Tuples and Structs are passed as pointer
# And pointers are i32 # # And pointers are i32
return wasm.WasmTypeInt32() # return wasm.WasmTypeInt32()
# TODO: Broken after new type system # TODO: Broken after new type system
# if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)): # if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)):
@ -187,8 +187,8 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if isinstance(inp.variable, ourlang.ModuleConstantDef): if isinstance(inp.variable, ourlang.ModuleConstantDef):
assert inp.variable.type_var is not None, typing.ASSERTION_ERROR assert inp.variable.type_var is not None, typing.ASSERTION_ERROR
tc_prim = inp.variable.type_var.get_constraint(typing.TypeConstraintPrimitive) tc_type = inp.variable.type_var.get_type()
if tc_prim is None: if tc_type is None:
raise NotImplementedError(expression, inp, inp.variable.type_var) raise NotImplementedError(expression, inp, inp.variable.type_var)
# TODO: Broken after new type system # TODO: Broken after new type system
@ -200,11 +200,11 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
# return # return
# #
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: # if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
assert inp.variable.data_block is not None, 'Combined values are memory stored' # assert inp.variable.data_block is not None, 'Combined values are memory stored'
assert inp.variable.data_block.address is not None, 'Value not allocated' # assert inp.variable.data_block.address is not None, 'Value not allocated'
wgn.i32.const(inp.variable.data_block.address) # wgn.i32.const(inp.variable.data_block.address)
return # return
assert inp.variable.data_block is None, 'Primitives are not memory stored' assert inp.variable.data_block is None, 'Primitives are not memory stored'
@ -314,47 +314,47 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if isinstance(inp, ourlang.Subscript): if isinstance(inp, ourlang.Subscript):
assert inp.varref.type_var is not None, typing.ASSERTION_ERROR assert inp.varref.type_var is not None, typing.ASSERTION_ERROR
tc_prim = inp.varref.type_var.get_constraint(typing.TypeConstraintPrimitive) tc_type = inp.varref.type_var.get_type()
if tc_prim is None: if tc_type is None:
raise NotImplementedError(expression, inp, inp.varref.type_var) raise NotImplementedError(expression, inp, inp.varref.type_var)
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY: # if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
if not isinstance(inp.index, ourlang.ConstantPrimitive): # if not isinstance(inp.index, ourlang.ConstantPrimitive):
raise NotImplementedError(expression, inp, inp.index) # raise NotImplementedError(expression, inp, inp.index)
if not isinstance(inp.index.value, int): # if not isinstance(inp.index.value, int):
raise NotImplementedError(expression, inp, inp.index.value) # raise NotImplementedError(expression, inp, inp.index.value)
#
assert inp.type_var is not None, typing.ASSERTION_ERROR # assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var) # mtyp = typing.simplify(inp.type_var)
if mtyp is None: # if mtyp is None:
raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp) # raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp)
#
if mtyp == 'u8': # if mtyp == 'u8':
# u8 operations are done using i32, since WASM does not have u8 operations # # u8 operations are done using i32, since WASM does not have u8 operations
mtyp = 'i32' # mtyp = 'i32'
elif mtyp == 'u32': # elif mtyp == 'u32':
# u32 operations are done using i32, using _u operations # # u32 operations are done using i32, using _u operations
mtyp = 'i32' # mtyp = 'i32'
elif mtyp == 'u64': # elif mtyp == 'u64':
# u64 operations are done using i64, using _u operations # # u64 operations are done using i64, using _u operations
mtyp = 'i64' # mtyp = 'i64'
#
tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript) # tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript)
if tc_subs is None: # if tc_subs is None:
raise NotImplementedError(expression, inp, inp.varref.type_var) # raise NotImplementedError(expression, inp, inp.varref.type_var)
#
assert 0 < len(tc_subs.members) # assert 0 < len(tc_subs.members)
tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth) # tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth)
if tc_bits is None or len(tc_bits.oneof) > 1: # if tc_bits is None or len(tc_bits.oneof) > 1:
raise NotImplementedError(expression, inp, inp.varref.type_var) # raise NotImplementedError(expression, inp, inp.varref.type_var)
#
bitwidth = next(iter(tc_bits.oneof)) # bitwidth = next(iter(tc_bits.oneof))
if bitwidth % 8 != 0: # if bitwidth % 8 != 0:
raise NotImplementedError(expression, inp, inp.varref.type_var) # raise NotImplementedError(expression, inp, inp.varref.type_var)
#
expression(wgn, inp.varref) # expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value)) # wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value))
return # return
raise NotImplementedError(expression, inp, inp.varref.type_var) raise NotImplementedError(expression, inp, inp.varref.type_var)

View File

@ -6,7 +6,8 @@ from . import ourlang
from .exceptions import TypingError from .exceptions import TypingError
from .typing import ( from .typing import (
Context, Context,
TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeConstraintSubscript, PhasmTypeInteger, PhasmTypeReal,
TypeConstraintBitWidth, TypeConstraintSigned, TypeConstraintSubscript,
TypeVar, TypeVar,
from_str, from_str,
) )
@ -19,7 +20,7 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
result = ctx.new_var() result = ctx.new_var()
if isinstance(inp.value, int): if isinstance(inp.value, int):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT)) result.set_type(PhasmTypeInteger)
# Need at least this many bits to store this constant value # Need at least this many bits to store this constant value
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2)) result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2))
@ -34,7 +35,7 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
return result return result
if isinstance(inp.value, float): if isinstance(inp.value, float):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT)) result.set_type(PhasmTypeReal)
# We don't have fancy logic here to detect if the float constant # We don't have fancy logic here to detect if the float constant
# fits in the given type. There a number of edge cases to consider, # fits in the given type. There a number of edge cases to consider,
@ -107,8 +108,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
return left return left
if inp.operator in ('<<', '>>', ): if inp.operator in ('<<', '>>', ):
inp.type_var = ctx.new_var() inp.type_var = ctx.new_var(PhasmTypeInteger)
inp.type_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, ))) inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, )))
inp.type_var.add_constraint(TypeConstraintSigned(False)) inp.type_var.add_constraint(TypeConstraintSigned(False))

View File

@ -131,6 +131,18 @@ class TypeStruct(TypeBase):
# back to the AST. If so, we need to fix the typer. # back to the AST. If so, we need to fix the typer.
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method' ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
class PhasmType:
__slots__ = ('name', )
name: str
def __init__(self, name: str) -> None:
self.name = name
def __repr__(self) -> str:
return 'PhasmType' + self.name
PhasmTypeInteger = PhasmType('Integer')
PhasmTypeReal = PhasmType('Real')
class TypingNarrowProtoError(TypingError): class TypingNarrowProtoError(TypingError):
""" """
@ -156,38 +168,6 @@ class TypeConstraintBase:
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase': def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase':
raise NotImplementedError('narrow', self, other) raise NotImplementedError('narrow', self, other)
class TypeConstraintPrimitive(TypeConstraintBase):
"""
This contraint on a type defines its primitive shape
"""
__slots__ = ('primitive', )
class Primitive(enum.Enum):
"""
The primitive ID
"""
INT = 0
FLOAT = 1
STATIC_ARRAY = 10
primitive: Primitive
def __init__(self, primitive: Primitive) -> None:
self.primitive = primitive
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
if not isinstance(other, TypeConstraintPrimitive):
raise Exception('Invalid comparison')
if self.primitive != other.primitive:
raise TypingNarrowProtoError('Primitive does not match')
return TypeConstraintPrimitive(self.primitive)
def __repr__(self) -> str:
return f'Primitive={self.primitive.name}'
class TypeConstraintSigned(TypeConstraintBase): class TypeConstraintSigned(TypeConstraintBase):
""" """
Contraint on whether a signed value can be used or not, or whether Contraint on whether a signed value can be used or not, or whether
@ -326,6 +306,13 @@ class TypeVar:
self.ctx = ctx self.ctx = ctx
self.ctx_id = ctx_id self.ctx_id = ctx_id
def get_type(self) -> Optional[PhasmType]:
return self.ctx.var_types[self.ctx_id]
def set_type(self, type_: PhasmType) -> None:
assert self.ctx.var_types[self.ctx_id] is None, 'Type already set'
self.ctx.var_types[self.ctx_id] = type_
def add_constraint(self, newconst: TypeConstraintBase) -> None: def add_constraint(self, newconst: TypeConstraintBase) -> None:
csts = self.ctx.var_constraints[self.ctx_id] csts = self.ctx.var_constraints[self.ctx_id]
@ -345,8 +332,12 @@ class TypeVar:
self.ctx.var_locations[self.ctx_id].add(ref) self.ctx.var_locations[self.ctx_id].add(ref)
def __repr__(self) -> str: def __repr__(self) -> str:
typ = self.ctx.var_types[self.ctx_id]
return ( return (
'TypeVar<' 'TypeVar<'
+ ('?' if typ is None else repr(typ))
+ '; '
+ '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values())) + '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
+ '; locations: ' + '; locations: '
+ ', '.join(sorted(self.ctx.var_locations[self.ctx_id])) + ', '.join(sorted(self.ctx.var_locations[self.ctx_id]))
@ -367,16 +358,18 @@ class Context:
# Store the TypeVar properties as a lookup # Store the TypeVar properties as a lookup
# so we can update these when unifying # so we can update these when unifying
self.var_types: Dict[int, Optional[PhasmType]] = {}
self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {} self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
self.var_locations: Dict[int, Set[str]] = {} self.var_locations: Dict[int, Set[str]] = {}
def new_var(self) -> TypeVar: def new_var(self, type_: Optional[PhasmType] = None) -> TypeVar:
ctx_id = self.next_ctx_id ctx_id = self.next_ctx_id
self.next_ctx_id += 1 self.next_ctx_id += 1
result = TypeVar(self, ctx_id) result = TypeVar(self, ctx_id)
self.vars_by_id[ctx_id] = [result] self.vars_by_id[ctx_id] = [result]
self.var_types[ctx_id] = type_
self.var_constraints[ctx_id] = {} self.var_constraints[ctx_id] = {}
self.var_locations[ctx_id] = set() self.var_locations[ctx_id] = set()
@ -393,12 +386,19 @@ class Context:
# Backup some values that we'll overwrite # Backup some values that we'll overwrite
l_ctx_id = l.ctx_id l_ctx_id = l.ctx_id
r_ctx_id = r.ctx_id r_ctx_id = r.ctx_id
l_type = self.var_types[l_ctx_id]
r_type = self.var_types[r_ctx_id]
l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id] l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id]
# Create a new TypeVar, with the combined contraints # Create a new TypeVar, with the combined contraints
# and locations of the old ones # and locations of the old ones
n = self.new_var() n = self.new_var()
if l_type is not None and r_type is not None and l_type != r_type:
raise TypingNarrowError(l, r, 'Type does not match')
else:
self.var_types[n.ctx_id] = l_type
try: try:
for const in self.var_constraints[l_ctx_id].values(): for const in self.var_constraints[l_ctx_id].values():
n.add_constraint(const) n.add_constraint(const)
@ -435,15 +435,13 @@ def simplify(inp: TypeVar) -> Optional[str]:
Should round trip with from_str Should round trip with from_str
""" """
tc_prim = inp.get_constraint(TypeConstraintPrimitive)
tc_bits = inp.get_constraint(TypeConstraintBitWidth) tc_bits = inp.get_constraint(TypeConstraintBitWidth)
tc_sign = inp.get_constraint(TypeConstraintSigned) tc_sign = inp.get_constraint(TypeConstraintSigned)
if tc_prim is None: if inp.get_type() is None:
return None return None
primitive = tc_prim.primitive if inp.get_type() is PhasmTypeInteger:
if primitive is TypeConstraintPrimitive.Primitive.INT:
if tc_bits is None or tc_sign is None: if tc_bits is None or tc_sign is None:
return None return None
@ -457,7 +455,7 @@ def simplify(inp: TypeVar) -> Optional[str]:
base = 'i' if tc_sign.signed else 'u' base = 'i' if tc_sign.signed else 'u'
return f'{base}{bitwidth}' return f'{base}{bitwidth}'
if primitive is TypeConstraintPrimitive.Primitive.FLOAT: if inp.get_type() is PhasmTypeReal:
if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint
return None return None
@ -470,16 +468,16 @@ def simplify(inp: TypeVar) -> Optional[str]:
return f'f{bitwidth}' return f'f{bitwidth}'
if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY: # if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
tc_subs = inp.get_constraint(TypeConstraintSubscript) # tc_subs = inp.get_constraint(TypeConstraintSubscript)
assert tc_subs is not None # assert tc_subs is not None
assert tc_subs.members # assert tc_subs.members
#
sab = simplify(tc_subs.members[0]) # sab = simplify(tc_subs.members[0])
if sab is None: # if sab is None:
return None # return None
#
return f'{sab}[{len(tc_subs.members)}]' # return f'{sab}[{len(tc_subs.members)}]'
return None return None
@ -487,8 +485,7 @@ def make_u8(ctx: Context) -> TypeVar:
""" """
Makes a u8 TypeVar Makes a u8 TypeVar
""" """
result = ctx.new_var() result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8)) result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location('u8') result.add_location('u8')
@ -498,8 +495,7 @@ def make_u32(ctx: Context) -> TypeVar:
""" """
Makes a u32 TypeVar Makes a u32 TypeVar
""" """
result = ctx.new_var() result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location('u32') result.add_location('u32')
@ -509,8 +505,7 @@ def make_u64(ctx: Context) -> TypeVar:
""" """
Makes a u64 TypeVar Makes a u64 TypeVar
""" """
result = ctx.new_var() result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False)) result.add_constraint(TypeConstraintSigned(False))
result.add_location('u64') result.add_location('u64')
@ -520,8 +515,7 @@ def make_i32(ctx: Context) -> TypeVar:
""" """
Makes a i32 TypeVar Makes a i32 TypeVar
""" """
result = ctx.new_var() result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True)) result.add_constraint(TypeConstraintSigned(True))
result.add_location('i32') result.add_location('i32')
@ -531,8 +525,7 @@ def make_i64(ctx: Context) -> TypeVar:
""" """
Makes a i64 TypeVar Makes a i64 TypeVar
""" """
result = ctx.new_var() result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True)) result.add_constraint(TypeConstraintSigned(True))
result.add_location('i64') result.add_location('i64')
@ -542,8 +535,7 @@ def make_f32(ctx: Context) -> TypeVar:
""" """
Makes a f32 TypeVar Makes a f32 TypeVar
""" """
result = ctx.new_var() result = ctx.new_var(PhasmTypeReal)
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32)) result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location('f32') result.add_location('f32')
return result return result
@ -552,8 +544,7 @@ def make_f64(ctx: Context) -> TypeVar:
""" """
Makes a f64 TypeVar Makes a f64 TypeVar
""" """
result = ctx.new_var() result = ctx.new_var(PhasmTypeReal)
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64)) result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location('f64') result.add_location('f64')
return result return result
@ -588,23 +579,23 @@ def from_str(ctx: Context, inp: str, location: Optional[str] = None) -> TypeVar:
result.add_location(location) result.add_location(location)
return result return result
match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp) # match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
if match: # if match:
result = ctx.new_var() # result = ctx.new_var()
#
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY)) # result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))
result.add_constraint(TypeConstraintSubscript(members=( # result.add_constraint(TypeConstraintSubscript(members=(
# Make copies so they don't get entangled # # Make copies so they don't get entangled
# with each other. # # with each other.
from_str(ctx, match[1], match[1]) # from_str(ctx, match[1], match[1])
for _ in range(int(match[2])) # for _ in range(int(match[2]))
))) # )))
#
result.add_location(inp) # result.add_location(inp)
#
if location is not None: # if location is not None:
result.add_location(location) # result.add_location(location)
#
return result # return result
raise NotImplementedError(from_str, inp) raise NotImplementedError(from_str, inp)