604 lines
18 KiB
Python
604 lines
18 KiB
Python
"""
The phasm type system
"""
|
|
from typing import Callable, Dict, Iterable, Optional, List, Set, Type
|
|
from typing import TypeVar as MyPyTypeVar
|
|
|
|
import enum
|
|
import re
|
|
|
|
from .exceptions import TypingError
|
|
|
|
class TypeBase:
    """
    Common base class for the type classes in this module

    Declares no instance attributes of its own (empty __slots__).
    """
    __slots__ = ()

    def alloc_size(self) -> int:
        """
        Number of bytes to reserve when allocating this type in memory

        The base implementation always raises NotImplementedError, with
        the instance and the method name as arguments; subclasses that
        can be allocated provide their own implementation.
        """
        raise NotImplementedError(self, 'alloc_size')
|
|
|
|
class TypeBytes(TypeBase):
    """
    The bytes type

    Carries no state and does not override alloc_size, so calling it
    runs the TypeBase implementation, which raises NotImplementedError.
    """
    __slots__ = ()
|
|
|
|
class TypeTupleMember:
    """
    Represents a single member of a tuple

    Stores the member's index, its type and its offset exactly as given.
    """
    def __init__(self, idx: int, type_: TypeBase, offset: int) -> None:
        self.idx, self.type, self.offset = idx, type_, offset
|
|
|
|
class TypeTuple(TypeBase):
    """
    The tuple type

    Holds a list of TypeTupleMember entries. The list starts out empty
    and is filled in after construction.
    """
    __slots__ = ('members', )

    members: List[TypeTupleMember]

    def __init__(self) -> None:
        self.members = []

    def render_internal_name(self) -> str:
        """
        Generates an internal name for this tuple

        The name is 'tuple@' followed by one '?' placeholder per member,
        with the placeholders separated by '@'.
        """
        # FIXME: Should not be a questionmark
        placeholders = '@'.join(['?'] * len(self.members))
        assert ' ' not in placeholders, 'Not implement yet: subtuples'
        return f'tuple@{placeholders}'

    def alloc_size(self) -> int:
        """
        Sum of the alloc_size() of every member's type
        """
        total = 0
        for member in self.members:
            total = total + member.type.alloc_size()
        return total
|
|
|
|
class TypeStaticArrayMember:
    """
    Represents a single member of a static array

    Stores the member's index and its offset exactly as given.
    """
    def __init__(self, idx: int, offset: int) -> None:
        self.idx, self.offset = idx, offset
|
|
|
|
class TypeStaticArray(TypeBase):
    """
    The static array type

    All members share one member type; the member list starts out empty
    and is filled in after construction.
    """
    __slots__ = ('member_type', 'members', )

    member_type: TypeBase
    members: List[TypeStaticArrayMember]

    def __init__(self, member_type: TypeBase) -> None:
        self.member_type = member_type
        self.members = []

    def alloc_size(self) -> int:
        """
        Size of one member (per the member type) times the member count
        """
        per_member = self.member_type.alloc_size()
        member_count = len(self.members)
        return per_member * member_count
|
|
|
|
class TypeStructMember:
    """
    Represents a single member of a struct

    Stores the member's name, its type and its offset exactly as given.
    """
    def __init__(self, name: str, type_: TypeBase, offset: int) -> None:
        self.name, self.type, self.offset = name, type_, offset
|
|
|
|
class TypeStruct(TypeBase):
    """
    A struct has named properties

    Besides its members it records its own name and a line number.
    The member list starts out empty and is filled in after construction.
    """
    __slots__ = ('name', 'lineno', 'members', )

    name: str
    lineno: int
    members: List[TypeStructMember]

    def __init__(self, name: str, lineno: int) -> None:
        self.name = name
        self.lineno = lineno
        self.members = []

    def get_member(self, name: str) -> Optional[TypeStructMember]:
        """
        Returns the first member with the given name, or None if no
        member has that name
        """
        matching = (member for member in self.members if member.name == name)
        return next(matching, None)

    def alloc_size(self) -> int:
        """
        Sum of the alloc_size() of every member's type
        """
        total = 0
        for member in self.members:
            total = total + member.type.alloc_size()
        return total
|
|
|
|
## NEW STUFF BELOW
|
|
|
|
# Message used when an operation is handed a missing (None) type variable.
# This error can also mean that the typer somewhere forgot to write a type
# back to the AST. If so, we need to fix the typer.
ASSERTION_ERROR = (
    'You must call phasm_type after calling phasm_parse '
    'before you can call any other method'
)
|
|
|
|
|
|
class TypingNarrowProtoError(TypingError):
    """
    A proto error raised while narrowing two type constraints

    It only carries the reason for the mismatch; the unify method
    catches it and re-raises it as a TypingNarrowError that also names
    the two type variables involved.
    """
    # FIXME: Use consistent naming for unify / narrow / entangle
|
|
|
|
class TypingNarrowError(TypingError):
    """
    An error when trying to unify two Type Variables

    The message contains the repr of both type variables followed by
    the reason the narrowing failed.
    """
    def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None:
        message = f'Cannot narrow types {l} and {r}: {msg}'
        super().__init__(message)
|
|
|
|
class TypeConstraintBase:
    """
    Base class for classes implementing a constraint on a type
    """
    # Without an (empty) __slots__ here every instance still gets a
    # __dict__, which makes the __slots__ declarations on the subclasses
    # ineffective. Subclasses that do not declare __slots__ themselves
    # are unaffected and keep their __dict__.
    __slots__ = ()

    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase':
        """
        Combines this constraint with another one of the same kind

        Must be overridden by subclasses; this base implementation
        always raises NotImplementedError, with the method name and both
        constraints as arguments.
        """
        raise NotImplementedError('narrow', self, other)
|
|
|
|
class TypeConstraintPrimitive(TypeConstraintBase):
    """
    This constraint on a type defines its primitive shape
    """
    __slots__ = ('primitive', )

    class Primitive(enum.Enum):
        """
        The primitive ID
        """
        INT = 0
        FLOAT = 1

        STATIC_ARRAY = 10

    primitive: Primitive

    def __init__(self, primitive: Primitive) -> None:
        self.primitive = primitive

    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
        """
        Combines two primitive constraints

        Returns a fresh constraint with the shared primitive. Raises
        TypingNarrowProtoError when the primitives differ, and a plain
        Exception when other is not a primitive constraint at all.
        """
        if not isinstance(other, TypeConstraintPrimitive):
            raise Exception('Invalid comparison')

        if self.primitive == other.primitive:
            return TypeConstraintPrimitive(self.primitive)

        raise TypingNarrowProtoError('Primitive does not match')

    def __repr__(self) -> str:
        return f'Primitive={self.primitive.name}'
|
|
|
|
class TypeConstraintSigned(TypeConstraintBase):
    """
    Constraint on whether a signed value can be used or not, or whether
    a value can be used in a signed expression

    A signed value of None means the signedness is not decided yet.
    """
    __slots__ = ('signed', )

    signed: Optional[bool]

    def __init__(self, signed: Optional[bool]) -> None:
        self.signed = signed

    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSigned':
        """
        Combines two signedness constraints

        An undecided side (None) adopts the value of the other side.
        Raises TypingNarrowProtoError when both sides are decided and
        disagree, and a plain Exception when other is not a signedness
        constraint at all.
        """
        if not isinstance(other, TypeConstraintSigned):
            raise Exception('Invalid comparison')

        mine, theirs = self.signed, other.signed

        if theirs is None:
            return TypeConstraintSigned(mine)
        if mine is None:
            return TypeConstraintSigned(theirs)

        if mine is theirs:
            return TypeConstraintSigned(mine)

        raise TypingNarrowProtoError('Signed does not match')

    def __repr__(self) -> str:
        return f'Signed={self.signed}'
|
|
|
|
class TypeConstraintBitWidth(TypeConstraintBase):
    """
    Constraint on how many bits an expression has or can possibly have

    The allowed widths are kept as a set in the oneof attribute.
    """
    __slots__ = ('oneof', )

    oneof: Set[int]

    def __init__(self, *, oneof: Optional[Iterable[int]] = None, minb: Optional[int] = None, maxb: Optional[int] = None) -> None:
        """
        Builds the set of allowed widths

        Starts from oneof, or from every width 1 through 64 when oneof
        is not given, then drops the widths below minb and above maxb
        (each bound is inclusive and only applied when given).
        """
        # For now, support up to 64 bits values
        widths = set(range(1, 65)) if oneof is None else set(oneof)

        if minb is not None:
            widths = {width for width in widths if minb <= width}

        if maxb is not None:
            widths = {width for width in widths if width <= maxb}

        self.oneof = widths

    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth':
        """
        Combines two bit width constraints by intersecting their widths

        Raises TypingNarrowProtoError when no common width is left, and
        a plain Exception when other is not a bit width constraint.
        """
        if not isinstance(other, TypeConstraintBitWidth):
            raise Exception('Invalid comparison')

        shared = self.oneof & other.oneof
        if shared:
            return TypeConstraintBitWidth(oneof=shared)

        raise TypingNarrowProtoError('Memory width cannot be resolved')

    def __repr__(self) -> str:
        """
        Renders the widths in ascending order, comma separated

        A run of three or more consecutive widths is collapsed to
        'first..last'; a run of exactly two is written as 'first,second'.
        """
        ordered = sorted(self.oneof)
        pieces = []

        start = 0
        while start < len(ordered):
            # Extend the run for as long as the next width is consecutive
            stop = start
            while stop + 1 < len(ordered) and ordered[stop + 1] == ordered[stop] + 1:
                stop += 1

            followers = stop - start
            if followers == 0:
                pieces.append(str(ordered[start]))
            elif followers == 1:
                pieces.append(f'{ordered[start]},{ordered[stop]}')
            else:
                pieces.append(f'{ordered[start]}..{ordered[stop]}')

            start = stop + 1

        return 'BitWidth=' + ','.join(pieces)
|
|
|
|
class TypeConstraintSubscript(TypeConstraintBase):
    """
    Constraint on allowing a type to be subscripted

    Keeps one type variable per member that can be reached by subscript.
    """
    __slots__ = ('members', )

    members: List['TypeVar']

    def __init__(self, *, members: Iterable['TypeVar']) -> None:
        self.members = list(members)

    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSubscript':
        """
        Combines two subscript constraints member by member

        For every position a new type variable is created in ctx and
        unified with the member of both sides. Raises
        TypingNarrowProtoError when the member counts differ, and a
        plain Exception when other is not a subscript constraint.
        """
        if not isinstance(other, TypeConstraintSubscript):
            raise Exception('Invalid comparison')

        if len(self.members) != len(other.members):
            raise TypingNarrowProtoError('Member count does not match')

        merged = []
        for own_member, other_member in zip(self.members, other.members):
            joined = ctx.new_var()
            ctx.unify(joined, own_member)
            ctx.unify(joined, other_member)
            merged.append(joined)

        return TypeConstraintSubscript(members=merged)

    def __repr__(self) -> str:
        rendered = ','.join(map(repr, self.members))
        return 'Subscript=(' + rendered + ')'
|
|
|
|
# Generic placeholder for "some TypeConstraintBase subclass"; it lets
# TypeVar.get_constraint declare that it returns an instance of the very
# constraint class it was asked for.
TTypeConstraintClass = MyPyTypeVar('TTypeConstraintClass', bound=TypeConstraintBase)
|
|
|
|
class TypeVar:
    """
    A type variable

    An instance only holds its context and its ID within that context.
    The constraints and locations are stored in the context, keyed by
    that ID, and are looked up on every access.
    """
    # FIXME: Explain the type system
    __slots__ = ('ctx', 'ctx_id', )

    ctx: 'Context'
    ctx_id: int

    def __init__(self, ctx: 'Context', ctx_id: int) -> None:
        self.ctx = ctx
        self.ctx_id = ctx_id

    def add_constraint(self, newconst: TypeConstraintBase) -> None:
        """
        Adds a constraint to this type variable

        When a constraint of the same class is already present, the two
        are narrowed into one; otherwise the new constraint is stored
        as is.
        """
        constraints = self.ctx.var_constraints[self.ctx_id]
        kind = newconst.__class__

        if kind not in constraints:
            constraints[kind] = newconst
            return

        constraints[kind] = constraints[kind].narrow(self.ctx, newconst)

    def get_constraint(self, const_type: Type[TTypeConstraintClass]) -> Optional[TTypeConstraintClass]:
        """
        Returns the constraint stored for the given constraint class,
        or None when this type variable has no such constraint
        """
        found = self.ctx.var_constraints[self.ctx_id].get(const_type, None)
        assert found is None or isinstance(found, const_type)  # type hint
        return found

    def add_location(self, ref: str) -> None:
        """
        Records a location reference for this type variable
        """
        self.ctx.var_locations[self.ctx_id].add(ref)

    def __repr__(self) -> str:
        constraints = '; '.join(map(repr, self.ctx.var_constraints[self.ctx_id].values()))
        locations = ', '.join(sorted(self.ctx.var_locations[self.ctx_id]))
        return f'TypeVar<{constraints}; locations: {locations}>'
|
|
|
|
class Context:
    """
    The context for a collection of type variables

    The context hands out type variables and stores their constraints
    and locations, keyed by the ID of the variable.
    """
    def __init__(self) -> None:
        # Variables are unified (or entangled, if you will)
        # that means that each TypeVar within a context has an ID,
        # and all TypeVars with the same ID are the same TypeVar,
        # even if they are a different instance
        self.next_ctx_id = 1
        self.vars_by_id: Dict[int, List[TypeVar]] = {}

        # Store the TypeVar properties as a lookup
        # so we can update these when unifying
        self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
        self.var_locations: Dict[int, Set[str]] = {}

    def new_var(self) -> TypeVar:
        """
        Creates a new type variable with a fresh ID, no constraints
        and no locations, and registers it in this context
        """
        ctx_id = self.next_ctx_id
        self.next_ctx_id += 1

        result = TypeVar(self, ctx_id)

        self.vars_by_id[ctx_id] = [result]
        self.var_constraints[ctx_id] = {}
        self.var_locations[ctx_id] = set()

        return result

    def unify(self, l: Optional[TypeVar], r: Optional[TypeVar]) -> None:
        """
        Unifies (entangles) two type variables

        Afterwards both variables, and every variable previously unified
        with either of them, share one new ID whose constraints are the
        narrowed combination of both sides and whose locations are the
        union of both sides.

        Unifying two variables that already share an ID is a no-op.

        Raises AssertionError when either variable is None, and
        TypingNarrowError when the constraints of both sides cannot be
        combined. In the latter case both variables keep their old ID
        and their old constraints.
        """
        assert l is not None, ASSERTION_ERROR
        assert r is not None, ASSERTION_ERROR

        if l.ctx_id == r.ctx_id:
            # Already the same variable, nothing to combine. Going on
            # would list every instance twice and delete the same
            # registry entries twice.
            return

        # Backup some values that we'll overwrite
        l_ctx_id = l.ctx_id
        r_ctx_id = r.ctx_id
        l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id]

        # Create a new TypeVar, with the combined constraints
        # and locations of the old ones
        n = self.new_var()

        try:
            for const in self.var_constraints[l_ctx_id].values():
                n.add_constraint(const)
            for const in self.var_constraints[r_ctx_id].values():
                n.add_constraint(const)
        except TypingNarrowProtoError as exc:
            raise TypingNarrowError(l, r, str(exc)) from None

        self.var_locations[n.ctx_id] = self.var_locations[l_ctx_id] | self.var_locations[r_ctx_id]

        # ##
        # And unify (or entangle) the old ones

        # First update the IDs, so they all point to the new list
        for type_var in l_r_var_list:
            type_var.ctx_id = n.ctx_id

        # Update our registry of TypeVars by ID, so we can find them
        # on the next unify
        self.vars_by_id[n.ctx_id].extend(l_r_var_list)

        # Then delete the old values for the now gone variables
        # Do this last, so exceptions thrown in the code above
        # still have a valid context
        del self.var_constraints[l_ctx_id]
        del self.var_constraints[r_ctx_id]
        del self.var_locations[l_ctx_id]
        del self.var_locations[r_ctx_id]

        # No variable carries the old IDs anymore, so their instance
        # lists would otherwise stay around forever
        del self.vars_by_id[l_ctx_id]
        del self.vars_by_id[r_ctx_id]
|
|
|
|
def simplify(inp: TypeVar) -> Optional[str]:
    """
    Simplifies a TypeVar into a string that wasm can work with
    and users can recognize

    Should round trip with from_str

    Returns None when the constraints do not pin the type variable
    down to exactly one type name:

    - integers need a decided signedness and exactly one bit width
      out of 8, 32 and 64; the result is 'i' or 'u' plus the width
    - floats need exactly one bit width out of 32 and 64 and must
      not carry a signedness constraint; the result is 'f' plus the width
    - static arrays are rendered as the simplified first member,
      followed by the member count in square brackets
    """
    prim_constraint = inp.get_constraint(TypeConstraintPrimitive)
    bits_constraint = inp.get_constraint(TypeConstraintBitWidth)
    sign_constraint = inp.get_constraint(TypeConstraintSigned)

    if prim_constraint is None:
        return None

    kinds = TypeConstraintPrimitive.Primitive
    kind = prim_constraint.primitive

    if kind is kinds.STATIC_ARRAY:
        subscript = inp.get_constraint(TypeConstraintSubscript)
        assert subscript is not None
        assert subscript.members

        # Only the first member is rendered
        member_name = simplify(subscript.members[0])
        if member_name is None:
            return None

        return f'{member_name}[{len(subscript.members)}]'

    if kind is not kinds.INT and kind is not kinds.FLOAT:
        return None

    # Both remaining primitives need exactly one possible bit width
    if bits_constraint is None or len(bits_constraint.oneof) != 1:
        return None
    (bitwidth, ) = bits_constraint.oneof

    if kind is kinds.INT:
        if sign_constraint is None or sign_constraint.signed is None:
            return None

        if bitwidth not in (8, 32, 64):
            return None

        prefix = 'i' if sign_constraint.signed else 'u'
        return f'{prefix}{bitwidth}'

    # Floats should not have a sign constraint
    if sign_constraint is not None:
        return None

    if bitwidth not in (32, 64):
        return None

    return f'f{bitwidth}'
|
|
|
|
def make_u8(ctx: Context, location: str) -> TypeVar:
    """
    Makes a u8 TypeVar

    The new variable is an unsigned integer of exactly 8 bits, with the
    given location recorded on it.
    """
    var = ctx.new_var()
    for constraint in (
        TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT),
        TypeConstraintBitWidth(minb=8, maxb=8),
        TypeConstraintSigned(False),
    ):
        var.add_constraint(constraint)
    var.add_location(location)
    return var
|
|
|
|
def make_u32(ctx: Context, location: str) -> TypeVar:
    """
    Makes a u32 TypeVar

    The new variable is an unsigned integer of exactly 32 bits, with
    the given location recorded on it.
    """
    var = ctx.new_var()
    for constraint in (
        TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT),
        TypeConstraintBitWidth(minb=32, maxb=32),
        TypeConstraintSigned(False),
    ):
        var.add_constraint(constraint)
    var.add_location(location)
    return var
|
|
|
|
def make_u64(ctx: Context, location: str) -> TypeVar:
    """
    Makes a u64 TypeVar

    The new variable is an unsigned integer of exactly 64 bits, with
    the given location recorded on it.
    """
    var = ctx.new_var()
    for constraint in (
        TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT),
        TypeConstraintBitWidth(minb=64, maxb=64),
        TypeConstraintSigned(False),
    ):
        var.add_constraint(constraint)
    var.add_location(location)
    return var
|
|
|
|
def make_i32(ctx: Context, location: str) -> TypeVar:
    """
    Makes a i32 TypeVar

    The new variable is a signed integer of exactly 32 bits, with the
    given location recorded on it.
    """
    var = ctx.new_var()
    for constraint in (
        TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT),
        TypeConstraintBitWidth(minb=32, maxb=32),
        TypeConstraintSigned(True),
    ):
        var.add_constraint(constraint)
    var.add_location(location)
    return var
|
|
|
|
def make_i64(ctx: Context, location: str) -> TypeVar:
    """
    Makes a i64 TypeVar

    The new variable is a signed integer of exactly 64 bits, with the
    given location recorded on it.
    """
    var = ctx.new_var()
    for constraint in (
        TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.INT),
        TypeConstraintBitWidth(minb=64, maxb=64),
        TypeConstraintSigned(True),
    ):
        var.add_constraint(constraint)
    var.add_location(location)
    return var
|
|
|
|
def make_f32(ctx: Context, location: str) -> TypeVar:
    """
    Makes a f32 TypeVar

    The new variable is a float of exactly 32 bits, with the given
    location recorded on it. No signedness constraint is added.
    """
    var = ctx.new_var()
    for constraint in (
        TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT),
        TypeConstraintBitWidth(minb=32, maxb=32),
    ):
        var.add_constraint(constraint)
    var.add_location(location)
    return var
|
|
|
|
def make_f64(ctx: Context, location: str) -> TypeVar:
    """
    Makes a f64 TypeVar

    The new variable is a float of exactly 64 bits, with the given
    location recorded on it. No signedness constraint is added.
    """
    var = ctx.new_var()
    for constraint in (
        TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.FLOAT),
        TypeConstraintBitWidth(minb=64, maxb=64),
    ):
        var.add_constraint(constraint)
    var.add_location(location)
    return var
|
|
|
|
# Maps the name of every builtin type to the function that makes
# a new TypeVar for it, given a context and a location.
BUILTIN_TYPES: Dict[str, Callable[[Context, str], TypeVar]] = {
    # Unsigned integers
    'u8': make_u8,
    'u32': make_u32,
    'u64': make_u64,

    # Signed integers
    'i32': make_i32,
    'i64': make_i64,

    # Floats
    'f32': make_f32,
    'f64': make_f64,
}
|
|
|
|
# Matches a static array type string such as "u8[4]": group 1 is the
# member type name (one of the letters u, i or f followed by digits),
# group 2 is the member count between the square brackets.
TYPE_MATCH_STATIC_ARRAY = re.compile(r'^([uif][0-9]+)\[([0-9]+)\]')
|
|
|
|
def from_str(ctx: Context, inp: str, location: str) -> TypeVar:
    """
    Creates a new TypeVar from the string

    Should round trip with simplify

    The location is a reference to where you found the string
    in the source code.

    This could be considered part of parsing. Though that would give trouble
    with the context creation.

    Accepts the names in BUILTIN_TYPES and static arrays of the form
    matched by TYPE_MATCH_STATIC_ARRAY; anything else raises
    NotImplementedError.
    """
    factory = BUILTIN_TYPES.get(inp)
    if factory is not None:
        return factory(ctx, location)

    parsed = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
    if parsed is None:
        raise NotImplementedError(from_str, inp)

    array_var = ctx.new_var()
    array_var.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))

    # Make a separate TypeVar per member so they don't get entangled
    # with each other. The member type name doubles as their location.
    member_vars = [
        from_str(ctx, parsed[1], parsed[1])
        for _ in range(int(parsed[2]))
    ]
    array_var.add_constraint(TypeConstraintSubscript(members=member_vars))
    array_var.add_location(location)

    return array_var
|