idea: Actual type class

This commit is contained in:
Johan B.W. de Vries 2022-10-23 13:52:48 +02:00
parent 312f7949bd
commit 42c9ff6ca7
3 changed files with 128 additions and 137 deletions

View File

@ -56,14 +56,14 @@ def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
return wasm.WasmTypeFloat64()
assert inp is not None, typing.ASSERTION_ERROR
tc_prim = inp.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
tc_type = inp.get_type()
if tc_type is None:
raise NotImplementedError(type_var, inp)
if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# StaticArray, Tuples and Structs are passed as pointer
# And pointers are i32
return wasm.WasmTypeInt32()
# if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# # StaticArray, Tuples and Structs are passed as pointer
# # And pointers are i32
# return wasm.WasmTypeInt32()
# TODO: Broken after new type system
# if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)):
@ -187,8 +187,8 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if isinstance(inp.variable, ourlang.ModuleConstantDef):
assert inp.variable.type_var is not None, typing.ASSERTION_ERROR
tc_prim = inp.variable.type_var.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
tc_type = inp.variable.type_var.get_type()
if tc_type is None:
raise NotImplementedError(expression, inp, inp.variable.type_var)
# TODO: Broken after new type system
@ -200,11 +200,11 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
# return
#
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
assert inp.variable.data_block is not None, 'Combined values are memory stored'
assert inp.variable.data_block.address is not None, 'Value not allocated'
wgn.i32.const(inp.variable.data_block.address)
return
# if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# assert inp.variable.data_block is not None, 'Combined values are memory stored'
# assert inp.variable.data_block.address is not None, 'Value not allocated'
# wgn.i32.const(inp.variable.data_block.address)
# return
assert inp.variable.data_block is None, 'Primitives are not memory stored'
@ -314,47 +314,47 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if isinstance(inp, ourlang.Subscript):
assert inp.varref.type_var is not None, typing.ASSERTION_ERROR
tc_prim = inp.varref.type_var.get_constraint(typing.TypeConstraintPrimitive)
if tc_prim is None:
tc_type = inp.varref.type_var.get_type()
if tc_type is None:
raise NotImplementedError(expression, inp, inp.varref.type_var)
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
if not isinstance(inp.index, ourlang.ConstantPrimitive):
raise NotImplementedError(expression, inp, inp.index)
if not isinstance(inp.index.value, int):
raise NotImplementedError(expression, inp, inp.index.value)
assert inp.type_var is not None, typing.ASSERTION_ERROR
mtyp = typing.simplify(inp.type_var)
if mtyp is None:
raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp)
if mtyp == 'u8':
# u8 operations are done using i32, since WASM does not have u8 operations
mtyp = 'i32'
elif mtyp == 'u32':
# u32 operations are done using i32, using _u operations
mtyp = 'i32'
elif mtyp == 'u64':
# u64 operations are done using i64, using _u operations
mtyp = 'i64'
tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript)
if tc_subs is None:
raise NotImplementedError(expression, inp, inp.varref.type_var)
assert 0 < len(tc_subs.members)
tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth)
if tc_bits is None or len(tc_bits.oneof) > 1:
raise NotImplementedError(expression, inp, inp.varref.type_var)
bitwidth = next(iter(tc_bits.oneof))
if bitwidth % 8 != 0:
raise NotImplementedError(expression, inp, inp.varref.type_var)
expression(wgn, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value))
return
# if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# if not isinstance(inp.index, ourlang.ConstantPrimitive):
# raise NotImplementedError(expression, inp, inp.index)
# if not isinstance(inp.index.value, int):
# raise NotImplementedError(expression, inp, inp.index.value)
#
# assert inp.type_var is not None, typing.ASSERTION_ERROR
# mtyp = typing.simplify(inp.type_var)
# if mtyp is None:
# raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp)
#
# if mtyp == 'u8':
# # u8 operations are done using i32, since WASM does not have u8 operations
# mtyp = 'i32'
# elif mtyp == 'u32':
# # u32 operations are done using i32, using _u operations
# mtyp = 'i32'
# elif mtyp == 'u64':
# # u64 operations are done using i64, using _u operations
# mtyp = 'i64'
#
# tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript)
# if tc_subs is None:
# raise NotImplementedError(expression, inp, inp.varref.type_var)
#
# assert 0 < len(tc_subs.members)
# tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth)
# if tc_bits is None or len(tc_bits.oneof) > 1:
# raise NotImplementedError(expression, inp, inp.varref.type_var)
#
# bitwidth = next(iter(tc_bits.oneof))
# if bitwidth % 8 != 0:
# raise NotImplementedError(expression, inp, inp.varref.type_var)
#
# expression(wgn, inp.varref)
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value))
# return
raise NotImplementedError(expression, inp, inp.varref.type_var)

View File

@ -6,7 +6,8 @@ from . import ourlang
from .exceptions import TypingError
from .typing import (
Context,
TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeConstraintSubscript,
PhasmTypeInteger, PhasmTypeReal,
TypeConstraintBitWidth, TypeConstraintSigned, TypeConstraintSubscript,
TypeVar,
from_str,
)
@ -19,7 +20,7 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
result = ctx.new_var()
if isinstance(inp.value, int):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result.set_type(PhasmTypeInteger)
# Need at least this many bits to store this constant value
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2))
@ -34,7 +35,7 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
return result
if isinstance(inp.value, float):
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result.set_type(PhasmTypeReal)
# We don't have fancy logic here to detect if the float constant
# fits in the given type. There a number of edge cases to consider,
@ -107,8 +108,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
return left
if inp.operator in ('<<', '>>', ):
inp.type_var = ctx.new_var()
inp.type_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
inp.type_var = ctx.new_var(PhasmTypeInteger)
inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, )))
inp.type_var.add_constraint(TypeConstraintSigned(False))

View File

@ -131,6 +131,18 @@ class TypeStruct(TypeBase):
# back to the AST. If so, we need to fix the typer.
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
class PhasmType:
__slots__ = ('name', )
name: str
def __init__(self, name: str) -> None:
self.name = name
def __repr__(self) -> str:
return 'PhasmType' + self.name
PhasmTypeInteger = PhasmType('Integer')
PhasmTypeReal = PhasmType('Real')
class TypingNarrowProtoError(TypingError):
"""
@ -156,38 +168,6 @@ class TypeConstraintBase:
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase':
raise NotImplementedError('narrow', self, other)
class TypeConstraintPrimitive(TypeConstraintBase):
"""
This contraint on a type defines its primitive shape
"""
__slots__ = ('primitive', )
class Primitive(enum.Enum):
"""
The primitive ID
"""
INT = 0
FLOAT = 1
STATIC_ARRAY = 10
primitive: Primitive
def __init__(self, primitive: Primitive) -> None:
self.primitive = primitive
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
if not isinstance(other, TypeConstraintPrimitive):
raise Exception('Invalid comparison')
if self.primitive != other.primitive:
raise TypingNarrowProtoError('Primitive does not match')
return TypeConstraintPrimitive(self.primitive)
def __repr__(self) -> str:
return f'Primitive={self.primitive.name}'
class TypeConstraintSigned(TypeConstraintBase):
"""
Contraint on whether a signed value can be used or not, or whether
@ -326,6 +306,13 @@ class TypeVar:
self.ctx = ctx
self.ctx_id = ctx_id
def get_type(self) -> Optional[PhasmType]:
return self.ctx.var_types[self.ctx_id]
def set_type(self, type_: PhasmType) -> None:
assert self.ctx.var_types[self.ctx_id] is None, 'Type already set'
self.ctx.var_types[self.ctx_id] = type_
def add_constraint(self, newconst: TypeConstraintBase) -> None:
csts = self.ctx.var_constraints[self.ctx_id]
@ -345,8 +332,12 @@ class TypeVar:
self.ctx.var_locations[self.ctx_id].add(ref)
def __repr__(self) -> str:
typ = self.ctx.var_types[self.ctx_id]
return (
'TypeVar<'
+ ('?' if typ is None else repr(typ))
+ '; '
+ '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
+ '; locations: '
+ ', '.join(sorted(self.ctx.var_locations[self.ctx_id]))
@ -367,16 +358,18 @@ class Context:
# Store the TypeVar properties as a lookup
# so we can update these when unifying
self.var_types: Dict[int, Optional[PhasmType]] = {}
self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
self.var_locations: Dict[int, Set[str]] = {}
def new_var(self) -> TypeVar:
def new_var(self, type_: Optional[PhasmType] = None) -> TypeVar:
ctx_id = self.next_ctx_id
self.next_ctx_id += 1
result = TypeVar(self, ctx_id)
self.vars_by_id[ctx_id] = [result]
self.var_types[ctx_id] = type_
self.var_constraints[ctx_id] = {}
self.var_locations[ctx_id] = set()
@ -393,12 +386,19 @@ class Context:
# Backup some values that we'll overwrite
l_ctx_id = l.ctx_id
r_ctx_id = r.ctx_id
l_type = self.var_types[l_ctx_id]
r_type = self.var_types[r_ctx_id]
l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id]
# Create a new TypeVar, with the combined contraints
# and locations of the old ones
n = self.new_var()
if l_type is not None and r_type is not None and l_type != r_type:
raise TypingNarrowError(l, r, 'Type does not match')
else:
self.var_types[n.ctx_id] = l_type
try:
for const in self.var_constraints[l_ctx_id].values():
n.add_constraint(const)
@ -435,15 +435,13 @@ def simplify(inp: TypeVar) -> Optional[str]:
Should round trip with from_str
"""
tc_prim = inp.get_constraint(TypeConstraintPrimitive)
tc_bits = inp.get_constraint(TypeConstraintBitWidth)
tc_sign = inp.get_constraint(TypeConstraintSigned)
if tc_prim is None:
if inp.get_type() is None:
return None
primitive = tc_prim.primitive
if primitive is TypeConstraintPrimitive.Primitive.INT:
if inp.get_type() is PhasmTypeInteger:
if tc_bits is None or tc_sign is None:
return None
@ -457,7 +455,7 @@ def simplify(inp: TypeVar) -> Optional[str]:
base = 'i' if tc_sign.signed else 'u'
return f'{base}{bitwidth}'
if primitive is TypeConstraintPrimitive.Primitive.FLOAT:
if inp.get_type() is PhasmTypeReal:
if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint
return None
@ -470,16 +468,16 @@ def simplify(inp: TypeVar) -> Optional[str]:
return f'f{bitwidth}'
if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
tc_subs = inp.get_constraint(TypeConstraintSubscript)
assert tc_subs is not None
assert tc_subs.members
sab = simplify(tc_subs.members[0])
if sab is None:
return None
return f'{sab}[{len(tc_subs.members)}]'
# if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
# tc_subs = inp.get_constraint(TypeConstraintSubscript)
# assert tc_subs is not None
# assert tc_subs.members
#
# sab = simplify(tc_subs.members[0])
# if sab is None:
# return None
#
# return f'{sab}[{len(tc_subs.members)}]'
return None
@ -487,8 +485,7 @@ def make_u8(ctx: Context) -> TypeVar:
"""
Makes a u8 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u8')
@ -498,8 +495,7 @@ def make_u32(ctx: Context) -> TypeVar:
"""
Makes a u32 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u32')
@ -509,8 +505,7 @@ def make_u64(ctx: Context) -> TypeVar:
"""
Makes a u64 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(False))
result.add_location('u64')
@ -520,8 +515,7 @@ def make_i32(ctx: Context) -> TypeVar:
"""
Makes a i32 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_constraint(TypeConstraintSigned(True))
result.add_location('i32')
@ -531,8 +525,7 @@ def make_i64(ctx: Context) -> TypeVar:
"""
Makes a i64 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
result = ctx.new_var(PhasmTypeInteger)
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_constraint(TypeConstraintSigned(True))
result.add_location('i64')
@ -542,8 +535,7 @@ def make_f32(ctx: Context) -> TypeVar:
"""
Makes a f32 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result = ctx.new_var(PhasmTypeReal)
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
result.add_location('f32')
return result
@ -552,8 +544,7 @@ def make_f64(ctx: Context) -> TypeVar:
"""
Makes a f64 TypeVar
"""
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
result = ctx.new_var(PhasmTypeReal)
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
result.add_location('f64')
return result
@ -588,23 +579,23 @@ def from_str(ctx: Context, inp: str, location: Optional[str] = None) -> TypeVar:
result.add_location(location)
return result
match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
if match:
result = ctx.new_var()
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))
result.add_constraint(TypeConstraintSubscript(members=(
# Make copies so they don't get entangled
# with each other.
from_str(ctx, match[1], match[1])
for _ in range(int(match[2]))
)))
result.add_location(inp)
if location is not None:
result.add_location(location)
return result
# match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
# if match:
# result = ctx.new_var()
#
# result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))
# result.add_constraint(TypeConstraintSubscript(members=(
# # Make copies so they don't get entangled
# # with each other.
# from_str(ctx, match[1], match[1])
# for _ in range(int(match[2]))
# )))
#
# result.add_location(inp)
#
# if location is not None:
# result.add_location(location)
#
# return result
raise NotImplementedError(from_str, inp)