idea: Actual type class
This commit is contained in:
parent
312f7949bd
commit
42c9ff6ca7
@ -56,14 +56,14 @@ def type_var(inp: Optional[typing.TypeVar]) -> wasm.WasmType:
|
||||
return wasm.WasmTypeFloat64()
|
||||
|
||||
assert inp is not None, typing.ASSERTION_ERROR
|
||||
tc_prim = inp.get_constraint(typing.TypeConstraintPrimitive)
|
||||
if tc_prim is None:
|
||||
tc_type = inp.get_type()
|
||||
if tc_type is None:
|
||||
raise NotImplementedError(type_var, inp)
|
||||
|
||||
if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
|
||||
# StaticArray, Tuples and Structs are passed as pointer
|
||||
# And pointers are i32
|
||||
return wasm.WasmTypeInt32()
|
||||
# if tc_prim.primitive is typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
|
||||
# # StaticArray, Tuples and Structs are passed as pointer
|
||||
# # And pointers are i32
|
||||
# return wasm.WasmTypeInt32()
|
||||
|
||||
# TODO: Broken after new type system
|
||||
# if isinstance(inp, (typing.TypeStruct, typing.TypeTuple, typing.TypeStaticArray, typing.TypeBytes)):
|
||||
@ -187,8 +187,8 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
|
||||
if isinstance(inp.variable, ourlang.ModuleConstantDef):
|
||||
assert inp.variable.type_var is not None, typing.ASSERTION_ERROR
|
||||
tc_prim = inp.variable.type_var.get_constraint(typing.TypeConstraintPrimitive)
|
||||
if tc_prim is None:
|
||||
tc_type = inp.variable.type_var.get_type()
|
||||
if tc_type is None:
|
||||
raise NotImplementedError(expression, inp, inp.variable.type_var)
|
||||
|
||||
# TODO: Broken after new type system
|
||||
@ -200,11 +200,11 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
# return
|
||||
#
|
||||
|
||||
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
|
||||
assert inp.variable.data_block is not None, 'Combined values are memory stored'
|
||||
assert inp.variable.data_block.address is not None, 'Value not allocated'
|
||||
wgn.i32.const(inp.variable.data_block.address)
|
||||
return
|
||||
# if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
|
||||
# assert inp.variable.data_block is not None, 'Combined values are memory stored'
|
||||
# assert inp.variable.data_block.address is not None, 'Value not allocated'
|
||||
# wgn.i32.const(inp.variable.data_block.address)
|
||||
# return
|
||||
|
||||
assert inp.variable.data_block is None, 'Primitives are not memory stored'
|
||||
|
||||
@ -314,47 +314,47 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
|
||||
if isinstance(inp, ourlang.Subscript):
|
||||
assert inp.varref.type_var is not None, typing.ASSERTION_ERROR
|
||||
tc_prim = inp.varref.type_var.get_constraint(typing.TypeConstraintPrimitive)
|
||||
if tc_prim is None:
|
||||
tc_type = inp.varref.type_var.get_type()
|
||||
if tc_type is None:
|
||||
raise NotImplementedError(expression, inp, inp.varref.type_var)
|
||||
|
||||
if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
|
||||
if not isinstance(inp.index, ourlang.ConstantPrimitive):
|
||||
raise NotImplementedError(expression, inp, inp.index)
|
||||
if not isinstance(inp.index.value, int):
|
||||
raise NotImplementedError(expression, inp, inp.index.value)
|
||||
|
||||
assert inp.type_var is not None, typing.ASSERTION_ERROR
|
||||
mtyp = typing.simplify(inp.type_var)
|
||||
if mtyp is None:
|
||||
raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp)
|
||||
|
||||
if mtyp == 'u8':
|
||||
# u8 operations are done using i32, since WASM does not have u8 operations
|
||||
mtyp = 'i32'
|
||||
elif mtyp == 'u32':
|
||||
# u32 operations are done using i32, using _u operations
|
||||
mtyp = 'i32'
|
||||
elif mtyp == 'u64':
|
||||
# u64 operations are done using i64, using _u operations
|
||||
mtyp = 'i64'
|
||||
|
||||
tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript)
|
||||
if tc_subs is None:
|
||||
raise NotImplementedError(expression, inp, inp.varref.type_var)
|
||||
|
||||
assert 0 < len(tc_subs.members)
|
||||
tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth)
|
||||
if tc_bits is None or len(tc_bits.oneof) > 1:
|
||||
raise NotImplementedError(expression, inp, inp.varref.type_var)
|
||||
|
||||
bitwidth = next(iter(tc_bits.oneof))
|
||||
if bitwidth % 8 != 0:
|
||||
raise NotImplementedError(expression, inp, inp.varref.type_var)
|
||||
|
||||
expression(wgn, inp.varref)
|
||||
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value))
|
||||
return
|
||||
# if tc_prim.primitive == typing.TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
|
||||
# if not isinstance(inp.index, ourlang.ConstantPrimitive):
|
||||
# raise NotImplementedError(expression, inp, inp.index)
|
||||
# if not isinstance(inp.index.value, int):
|
||||
# raise NotImplementedError(expression, inp, inp.index.value)
|
||||
#
|
||||
# assert inp.type_var is not None, typing.ASSERTION_ERROR
|
||||
# mtyp = typing.simplify(inp.type_var)
|
||||
# if mtyp is None:
|
||||
# raise NotImplementedError(expression, inp, inp.varref.type_var, mtyp)
|
||||
#
|
||||
# if mtyp == 'u8':
|
||||
# # u8 operations are done using i32, since WASM does not have u8 operations
|
||||
# mtyp = 'i32'
|
||||
# elif mtyp == 'u32':
|
||||
# # u32 operations are done using i32, using _u operations
|
||||
# mtyp = 'i32'
|
||||
# elif mtyp == 'u64':
|
||||
# # u64 operations are done using i64, using _u operations
|
||||
# mtyp = 'i64'
|
||||
#
|
||||
# tc_subs = inp.varref.type_var.get_constraint(typing.TypeConstraintSubscript)
|
||||
# if tc_subs is None:
|
||||
# raise NotImplementedError(expression, inp, inp.varref.type_var)
|
||||
#
|
||||
# assert 0 < len(tc_subs.members)
|
||||
# tc_bits = tc_subs.members[0].get_constraint(typing.TypeConstraintBitWidth)
|
||||
# if tc_bits is None or len(tc_bits.oneof) > 1:
|
||||
# raise NotImplementedError(expression, inp, inp.varref.type_var)
|
||||
#
|
||||
# bitwidth = next(iter(tc_bits.oneof))
|
||||
# if bitwidth % 8 != 0:
|
||||
# raise NotImplementedError(expression, inp, inp.varref.type_var)
|
||||
#
|
||||
# expression(wgn, inp.varref)
|
||||
# wgn.add_statement(f'{mtyp}.load', 'offset=' + str(bitwidth // 8 * inp.index.value))
|
||||
# return
|
||||
|
||||
raise NotImplementedError(expression, inp, inp.varref.type_var)
|
||||
|
||||
|
||||
@ -6,7 +6,8 @@ from . import ourlang
|
||||
from .exceptions import TypingError
|
||||
from .typing import (
|
||||
Context,
|
||||
TypeConstraintBitWidth, TypeConstraintPrimitive, TypeConstraintSigned, TypeConstraintSubscript,
|
||||
PhasmTypeInteger, PhasmTypeReal,
|
||||
TypeConstraintBitWidth, TypeConstraintSigned, TypeConstraintSubscript,
|
||||
TypeVar,
|
||||
from_str,
|
||||
)
|
||||
@ -19,7 +20,7 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
|
||||
result = ctx.new_var()
|
||||
|
||||
if isinstance(inp.value, int):
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
|
||||
result.set_type(PhasmTypeInteger)
|
||||
|
||||
# Need at least this many bits to store this constant value
|
||||
result.add_constraint(TypeConstraintBitWidth(minb=len(bin(inp.value)) - 2))
|
||||
@ -34,7 +35,7 @@ def constant(ctx: Context, inp: ourlang.Constant) -> TypeVar:
|
||||
return result
|
||||
|
||||
if isinstance(inp.value, float):
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
|
||||
result.set_type(PhasmTypeReal)
|
||||
|
||||
# We don't have fancy logic here to detect if the float constant
|
||||
# fits in the given type. There a number of edge cases to consider,
|
||||
@ -107,8 +108,7 @@ def expression(ctx: Context, inp: ourlang.Expression) -> 'TypeVar':
|
||||
return left
|
||||
|
||||
if inp.operator in ('<<', '>>', ):
|
||||
inp.type_var = ctx.new_var()
|
||||
inp.type_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
|
||||
inp.type_var = ctx.new_var(PhasmTypeInteger)
|
||||
inp.type_var.add_constraint(TypeConstraintBitWidth(oneof=(32, 64, )))
|
||||
inp.type_var.add_constraint(TypeConstraintSigned(False))
|
||||
|
||||
|
||||
151
phasm/typing.py
151
phasm/typing.py
@ -131,6 +131,18 @@ class TypeStruct(TypeBase):
|
||||
# back to the AST. If so, we need to fix the typer.
|
||||
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
|
||||
|
||||
class PhasmType:
|
||||
__slots__ = ('name', )
|
||||
name: str
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return 'PhasmType' + self.name
|
||||
|
||||
PhasmTypeInteger = PhasmType('Integer')
|
||||
PhasmTypeReal = PhasmType('Real')
|
||||
|
||||
class TypingNarrowProtoError(TypingError):
|
||||
"""
|
||||
@ -156,38 +168,6 @@ class TypeConstraintBase:
|
||||
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase':
|
||||
raise NotImplementedError('narrow', self, other)
|
||||
|
||||
class TypeConstraintPrimitive(TypeConstraintBase):
|
||||
"""
|
||||
This contraint on a type defines its primitive shape
|
||||
"""
|
||||
__slots__ = ('primitive', )
|
||||
|
||||
class Primitive(enum.Enum):
|
||||
"""
|
||||
The primitive ID
|
||||
"""
|
||||
INT = 0
|
||||
FLOAT = 1
|
||||
|
||||
STATIC_ARRAY = 10
|
||||
|
||||
primitive: Primitive
|
||||
|
||||
def __init__(self, primitive: Primitive) -> None:
|
||||
self.primitive = primitive
|
||||
|
||||
def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
|
||||
if not isinstance(other, TypeConstraintPrimitive):
|
||||
raise Exception('Invalid comparison')
|
||||
|
||||
if self.primitive != other.primitive:
|
||||
raise TypingNarrowProtoError('Primitive does not match')
|
||||
|
||||
return TypeConstraintPrimitive(self.primitive)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'Primitive={self.primitive.name}'
|
||||
|
||||
class TypeConstraintSigned(TypeConstraintBase):
|
||||
"""
|
||||
Contraint on whether a signed value can be used or not, or whether
|
||||
@ -326,6 +306,13 @@ class TypeVar:
|
||||
self.ctx = ctx
|
||||
self.ctx_id = ctx_id
|
||||
|
||||
def get_type(self) -> Optional[PhasmType]:
|
||||
return self.ctx.var_types[self.ctx_id]
|
||||
|
||||
def set_type(self, type_: PhasmType) -> None:
|
||||
assert self.ctx.var_types[self.ctx_id] is None, 'Type already set'
|
||||
self.ctx.var_types[self.ctx_id] = type_
|
||||
|
||||
def add_constraint(self, newconst: TypeConstraintBase) -> None:
|
||||
csts = self.ctx.var_constraints[self.ctx_id]
|
||||
|
||||
@ -345,8 +332,12 @@ class TypeVar:
|
||||
self.ctx.var_locations[self.ctx_id].add(ref)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
typ = self.ctx.var_types[self.ctx_id]
|
||||
|
||||
return (
|
||||
'TypeVar<'
|
||||
+ ('?' if typ is None else repr(typ))
|
||||
+ '; '
|
||||
+ '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
|
||||
+ '; locations: '
|
||||
+ ', '.join(sorted(self.ctx.var_locations[self.ctx_id]))
|
||||
@ -367,16 +358,18 @@ class Context:
|
||||
|
||||
# Store the TypeVar properties as a lookup
|
||||
# so we can update these when unifying
|
||||
self.var_types: Dict[int, Optional[PhasmType]] = {}
|
||||
self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
|
||||
self.var_locations: Dict[int, Set[str]] = {}
|
||||
|
||||
def new_var(self) -> TypeVar:
|
||||
def new_var(self, type_: Optional[PhasmType] = None) -> TypeVar:
|
||||
ctx_id = self.next_ctx_id
|
||||
self.next_ctx_id += 1
|
||||
|
||||
result = TypeVar(self, ctx_id)
|
||||
|
||||
self.vars_by_id[ctx_id] = [result]
|
||||
self.var_types[ctx_id] = type_
|
||||
self.var_constraints[ctx_id] = {}
|
||||
self.var_locations[ctx_id] = set()
|
||||
|
||||
@ -393,12 +386,19 @@ class Context:
|
||||
# Backup some values that we'll overwrite
|
||||
l_ctx_id = l.ctx_id
|
||||
r_ctx_id = r.ctx_id
|
||||
l_type = self.var_types[l_ctx_id]
|
||||
r_type = self.var_types[r_ctx_id]
|
||||
l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id]
|
||||
|
||||
# Create a new TypeVar, with the combined contraints
|
||||
# and locations of the old ones
|
||||
n = self.new_var()
|
||||
|
||||
if l_type is not None and r_type is not None and l_type != r_type:
|
||||
raise TypingNarrowError(l, r, 'Type does not match')
|
||||
else:
|
||||
self.var_types[n.ctx_id] = l_type
|
||||
|
||||
try:
|
||||
for const in self.var_constraints[l_ctx_id].values():
|
||||
n.add_constraint(const)
|
||||
@ -435,15 +435,13 @@ def simplify(inp: TypeVar) -> Optional[str]:
|
||||
|
||||
Should round trip with from_str
|
||||
"""
|
||||
tc_prim = inp.get_constraint(TypeConstraintPrimitive)
|
||||
tc_bits = inp.get_constraint(TypeConstraintBitWidth)
|
||||
tc_sign = inp.get_constraint(TypeConstraintSigned)
|
||||
|
||||
if tc_prim is None:
|
||||
if inp.get_type() is None:
|
||||
return None
|
||||
|
||||
primitive = tc_prim.primitive
|
||||
if primitive is TypeConstraintPrimitive.Primitive.INT:
|
||||
if inp.get_type() is PhasmTypeInteger:
|
||||
if tc_bits is None or tc_sign is None:
|
||||
return None
|
||||
|
||||
@ -457,7 +455,7 @@ def simplify(inp: TypeVar) -> Optional[str]:
|
||||
base = 'i' if tc_sign.signed else 'u'
|
||||
return f'{base}{bitwidth}'
|
||||
|
||||
if primitive is TypeConstraintPrimitive.Primitive.FLOAT:
|
||||
if inp.get_type() is PhasmTypeReal:
|
||||
if tc_bits is None or tc_sign is not None: # Floats should not hava sign contraint
|
||||
return None
|
||||
|
||||
@ -470,16 +468,16 @@ def simplify(inp: TypeVar) -> Optional[str]:
|
||||
|
||||
return f'f{bitwidth}'
|
||||
|
||||
if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
|
||||
tc_subs = inp.get_constraint(TypeConstraintSubscript)
|
||||
assert tc_subs is not None
|
||||
assert tc_subs.members
|
||||
|
||||
sab = simplify(tc_subs.members[0])
|
||||
if sab is None:
|
||||
return None
|
||||
|
||||
return f'{sab}[{len(tc_subs.members)}]'
|
||||
# if primitive is TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
|
||||
# tc_subs = inp.get_constraint(TypeConstraintSubscript)
|
||||
# assert tc_subs is not None
|
||||
# assert tc_subs.members
|
||||
#
|
||||
# sab = simplify(tc_subs.members[0])
|
||||
# if sab is None:
|
||||
# return None
|
||||
#
|
||||
# return f'{sab}[{len(tc_subs.members)}]'
|
||||
|
||||
return None
|
||||
|
||||
@ -487,8 +485,7 @@ def make_u8(ctx: Context) -> TypeVar:
|
||||
"""
|
||||
Makes a u8 TypeVar
|
||||
"""
|
||||
result = ctx.new_var()
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
|
||||
result = ctx.new_var(PhasmTypeInteger)
|
||||
result.add_constraint(TypeConstraintBitWidth(minb=8, maxb=8))
|
||||
result.add_constraint(TypeConstraintSigned(False))
|
||||
result.add_location('u8')
|
||||
@ -498,8 +495,7 @@ def make_u32(ctx: Context) -> TypeVar:
|
||||
"""
|
||||
Makes a u32 TypeVar
|
||||
"""
|
||||
result = ctx.new_var()
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
|
||||
result = ctx.new_var(PhasmTypeInteger)
|
||||
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
|
||||
result.add_constraint(TypeConstraintSigned(False))
|
||||
result.add_location('u32')
|
||||
@ -509,8 +505,7 @@ def make_u64(ctx: Context) -> TypeVar:
|
||||
"""
|
||||
Makes a u64 TypeVar
|
||||
"""
|
||||
result = ctx.new_var()
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
|
||||
result = ctx.new_var(PhasmTypeInteger)
|
||||
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
|
||||
result.add_constraint(TypeConstraintSigned(False))
|
||||
result.add_location('u64')
|
||||
@ -520,8 +515,7 @@ def make_i32(ctx: Context) -> TypeVar:
|
||||
"""
|
||||
Makes a i32 TypeVar
|
||||
"""
|
||||
result = ctx.new_var()
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
|
||||
result = ctx.new_var(PhasmTypeInteger)
|
||||
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
|
||||
result.add_constraint(TypeConstraintSigned(True))
|
||||
result.add_location('i32')
|
||||
@ -531,8 +525,7 @@ def make_i64(ctx: Context) -> TypeVar:
|
||||
"""
|
||||
Makes a i64 TypeVar
|
||||
"""
|
||||
result = ctx.new_var()
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT))
|
||||
result = ctx.new_var(PhasmTypeInteger)
|
||||
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
|
||||
result.add_constraint(TypeConstraintSigned(True))
|
||||
result.add_location('i64')
|
||||
@ -542,8 +535,7 @@ def make_f32(ctx: Context) -> TypeVar:
|
||||
"""
|
||||
Makes a f32 TypeVar
|
||||
"""
|
||||
result = ctx.new_var()
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
|
||||
result = ctx.new_var(PhasmTypeReal)
|
||||
result.add_constraint(TypeConstraintBitWidth(minb=32, maxb=32))
|
||||
result.add_location('f32')
|
||||
return result
|
||||
@ -552,8 +544,7 @@ def make_f64(ctx: Context) -> TypeVar:
|
||||
"""
|
||||
Makes a f64 TypeVar
|
||||
"""
|
||||
result = ctx.new_var()
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT))
|
||||
result = ctx.new_var(PhasmTypeReal)
|
||||
result.add_constraint(TypeConstraintBitWidth(minb=64, maxb=64))
|
||||
result.add_location('f64')
|
||||
return result
|
||||
@ -588,23 +579,23 @@ def from_str(ctx: Context, inp: str, location: Optional[str] = None) -> TypeVar:
|
||||
result.add_location(location)
|
||||
return result
|
||||
|
||||
match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
|
||||
if match:
|
||||
result = ctx.new_var()
|
||||
|
||||
result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))
|
||||
result.add_constraint(TypeConstraintSubscript(members=(
|
||||
# Make copies so they don't get entangled
|
||||
# with each other.
|
||||
from_str(ctx, match[1], match[1])
|
||||
for _ in range(int(match[2]))
|
||||
)))
|
||||
|
||||
result.add_location(inp)
|
||||
|
||||
if location is not None:
|
||||
result.add_location(location)
|
||||
|
||||
return result
|
||||
# match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
|
||||
# if match:
|
||||
# result = ctx.new_var()
|
||||
#
|
||||
# result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))
|
||||
# result.add_constraint(TypeConstraintSubscript(members=(
|
||||
# # Make copies so they don't get entangled
|
||||
# # with each other.
|
||||
# from_str(ctx, match[1], match[1])
|
||||
# for _ in range(int(match[2]))
|
||||
# )))
|
||||
#
|
||||
# result.add_location(inp)
|
||||
#
|
||||
# if location is not None:
|
||||
# result.add_location(location)
|
||||
#
|
||||
# return result
|
||||
|
||||
raise NotImplementedError(from_str, inp)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user