phasm/phasm/typing.py
Johan B.W. de Vries 4d3c0c6c3c StaticArray with constant index works again
Also, fix issue with f64 being parsed as f32
2022-09-19 14:43:15 +02:00

604 lines
18 KiB
Python

"""
The phasm type system
"""
from typing import Callable, Dict, Iterable, Optional, List, Set, Type
from typing import TypeVar as MyPyTypeVar
import enum
import re
from .exceptions import TypingError
class TypeBase:
    """
    Common ancestor of every phasm type.

    Concrete subclasses describe one kind of type and report how much
    memory a value of that type occupies.
    """
    __slots__ = ()

    def alloc_size(self) -> int:
        """
        Return how many bytes must be reserved to allocate this type.

        The base class has no size of its own; subclasses override this.

        :raises NotImplementedError: always, on the base class.
        """
        method_name = 'alloc_size'
        raise NotImplementedError(self, method_name)
class TypeBytes(TypeBase):
    """
    Represents the built-in bytes type.

    It has no extra state; alloc_size is inherited unchanged from TypeBase.
    """
    __slots__ = ()
class TypeTupleMember:
    """
    One positional member of a tuple type.

    Stores the member's index within the tuple, its type, and its offset
    within the tuple.
    """
    def __init__(self, idx: int, type_: TypeBase, offset: int) -> None:
        self.idx, self.type, self.offset = idx, type_, offset
class TypeTuple(TypeBase):
    """
    A tuple type: an ordered, fixed collection of members, each of which
    may have its own type.
    """
    __slots__ = ('members', )

    members: List[TypeTupleMember]

    def __init__(self) -> None:
        self.members = []

    def render_internal_name(self) -> str:
        """
        Build the internal name used to identify this tuple type.

        Each member currently renders as a '?' placeholder; placeholders
        are joined with '@' and the result is prefixed with 'tuple@'.
        """
        # FIXME: Members should render their real type, not a question mark
        placeholders = ['?'] * len(self.members)
        mems = '@'.join(placeholders)
        assert ' ' not in mems, 'Not implement yet: subtuples'
        return 'tuple@' + mems

    def alloc_size(self) -> int:
        """
        Total bytes for the tuple: the sum of every member type's size.
        """
        total = 0
        for member in self.members:
            total += member.type.alloc_size()
        return total
class TypeStaticArrayMember:
    """
    One element slot of a static array: its index and its offset.

    The element type is not stored per slot; it lives on the owning
    TypeStaticArray as member_type.
    """
    def __init__(self, idx: int, offset: int) -> None:
        self.offset = offset
        self.idx = idx
class TypeStaticArray(TypeBase):
    """
    A fixed-length array whose elements all share one type.

    member_type is the shared element type; members holds one
    TypeStaticArrayMember per element slot.
    """
    __slots__ = ('member_type', 'members', )

    member_type: TypeBase
    members: List[TypeStaticArrayMember]

    def __init__(self, member_type: TypeBase) -> None:
        self.members = []
        self.member_type = member_type

    def alloc_size(self) -> int:
        """
        Bytes needed: element size times the number of element slots.
        """
        element_size = self.member_type.alloc_size()
        return len(self.members) * element_size
class TypeStructMember:
    """
    One named field of a struct type, with its type and its offset
    within the struct.
    """
    def __init__(self, name: str, type_: TypeBase, offset: int) -> None:
        self.name, self.type, self.offset = name, type_, offset
class TypeStruct(TypeBase):
    """
    A struct type: a named record whose members are looked up by name.

    lineno is stored exactly as given (presumably the source line of the
    struct declaration; confirm at call sites).
    """
    __slots__ = ('name', 'lineno', 'members', )

    name: str
    lineno: int
    members: List[TypeStructMember]

    def __init__(self, name: str, lineno: int) -> None:
        self.name = name
        self.lineno = lineno
        self.members = []

    def get_member(self, name: str) -> Optional[TypeStructMember]:
        """
        Return the first member called `name`, or None if there is none.
        """
        return next((mem for mem in self.members if mem.name == name), None)

    def alloc_size(self) -> int:
        """
        Total bytes for the struct: the sum of every member type's size.
        """
        sizes = [member.type.alloc_size() for member in self.members]
        return sum(sizes)
## NEW STUFF BELOW

# Assertion message used when a TypeVar that should already be set is still
# None (Context.unify asserts both of its arguments with it).
# This error can also mean that the typer somewhere forgot to write a type
# back to the AST. If so, we need to fix the typer.
ASSERTION_ERROR = 'You must call phasm_type after calling phasm_parse before you can call any other method'
class TypingNarrowProtoError(TypingError):
    """
    Raised by a constraint's narrow() when two constraints conflict.

    Context.unify catches this and re-raises it as a TypingNarrowError,
    which also names the two type variables involved.
    """
# FIXME: Use consistent naming for unify / narrow / entangle
class TypingNarrowError(TypingError):
    """
    Raised when two type variables cannot be unified because their
    constraints contradict each other.

    :param l: the left-hand type variable
    :param r: the right-hand type variable
    :param msg: why narrowing failed
    """
    def __init__(self, l: 'TypeVar', r: 'TypeVar', msg: str) -> None:
        message = f'Cannot narrow types {l} and {r}: {msg}'
        super().__init__(message)
class TypeConstraintBase:
    """
    Base class for a single constraint on a type variable.
    """
    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBase':
        """
        Combine this constraint with `other` of the same kind, producing a
        constraint that satisfies both.

        Subclasses raise TypingNarrowProtoError when the two cannot be
        combined; the base class has no implementation.

        :raises NotImplementedError: always, on the base class.
        """
        raise NotImplementedError('narrow', self, other)
class TypeConstraintPrimitive(TypeConstraintBase):
    """
    Constrains which primitive kind a type is: integer, float or static array.
    """
    __slots__ = ('primitive', )

    class Primitive(enum.Enum):
        """
        Identifies a primitive kind
        """
        INT = 0
        FLOAT = 1

        STATIC_ARRAY = 10

    primitive: Primitive

    def __init__(self, primitive: Primitive) -> None:
        self.primitive = primitive

    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintPrimitive':
        """
        A primitive cannot be narrowed any further: both sides must agree exactly.

        :raises TypingNarrowProtoError: when the primitives differ.
        """
        if not isinstance(other, TypeConstraintPrimitive):
            raise Exception('Invalid comparison')

        if self.primitive == other.primitive:
            return TypeConstraintPrimitive(self.primitive)

        raise TypingNarrowProtoError('Primitive does not match')

    def __repr__(self) -> str:
        return 'Primitive=' + self.primitive.name
class TypeConstraintSigned(TypeConstraintBase):
    """
    Constraint on signedness: True means signed, False means unsigned,
    and None means not decided yet.
    """
    __slots__ = ('signed', )

    signed: Optional[bool]

    def __init__(self, signed: Optional[bool]) -> None:
        self.signed = signed

    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSigned':
        """
        An undecided side (None) takes the other side's value; two
        decided sides must match.

        :raises TypingNarrowProtoError: when one side is signed and the
            other unsigned.
        """
        if not isinstance(other, TypeConstraintSigned):
            raise Exception('Invalid comparison')

        mine, theirs = self.signed, other.signed
        if theirs is None:
            return TypeConstraintSigned(mine)
        if mine is None:
            return TypeConstraintSigned(theirs)
        if mine is theirs:
            return TypeConstraintSigned(mine)

        raise TypingNarrowProtoError('Signed does not match')

    def __repr__(self) -> str:
        return 'Signed=' + str(self.signed)
class TypeConstraintBitWidth(TypeConstraintBase):
    """
    Constraint on the set of bit widths an expression may have.
    """
    __slots__ = ('oneof', )

    oneof: Set[int]

    def __init__(self, *, oneof: Optional[Iterable[int]] = None, minb: Optional[int] = None, maxb: Optional[int] = None) -> None:
        """
        :param oneof: allowed widths; defaults to every width from 1 to 64.
        :param minb: if given, drop widths below this value.
        :param maxb: if given, drop widths above this value.
        """
        # For now, support up to 64 bits values
        candidates = set(range(1, 65)) if oneof is None else set(oneof)

        if minb is not None:
            candidates = {bits for bits in candidates if bits >= minb}

        if maxb is not None:
            candidates = {bits for bits in candidates if bits <= maxb}

        self.oneof = candidates

    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintBitWidth':
        """
        Keep only the widths allowed by both constraints.

        :raises TypingNarrowProtoError: when no width is allowed by both.
        """
        if not isinstance(other, TypeConstraintBitWidth):
            raise Exception('Invalid comparison')

        shared = self.oneof.intersection(other.oneof)
        if not shared:
            raise TypingNarrowProtoError('Memory width cannot be resolved')

        return TypeConstraintBitWidth(oneof=shared)

    def __repr__(self) -> str:
        # Render sorted widths, collapsing runs: a run of three or more
        # consecutive values becomes 'a..b', while a run of exactly two is
        # written out as 'a,b'.
        values = sorted(self.oneof)
        parts = []
        idx = 0
        while idx < len(values):
            start = values[idx]
            run = 0
            while idx + run + 1 < len(values) and values[idx + run + 1] == start + run + 1:
                run += 1

            if run == 0:
                parts.append(str(start))
            elif run == 1:
                parts.append(f'{start},{values[idx + 1]}')
            else:
                parts.append(f'{start}..{values[idx + run]}')

            idx += run + 1

        return 'BitWidth=' + ','.join(parts)
class TypeConstraintSubscript(TypeConstraintBase):
    """
    Constraint that allows a type to be subscripted; members holds one
    TypeVar per subscriptable position.
    """
    __slots__ = ('members', )

    members: List['TypeVar']

    def __init__(self, *, members: Iterable['TypeVar']) -> None:
        self.members = list(members)

    def narrow(self, ctx: 'Context', other: 'TypeConstraintBase') -> 'TypeConstraintSubscript':
        """
        Unify the members of two subscript constraints pairwise.

        For each position a fresh TypeVar is created and unified with the
        member from both sides, so all three end up entangled.

        :raises TypingNarrowProtoError: when the member counts differ.
        """
        if not isinstance(other, TypeConstraintSubscript):
            raise Exception('Invalid comparison')

        if len(self.members) != len(other.members):
            raise TypingNarrowProtoError('Member count does not match')

        newmembers = []
        for smb, omb in zip(self.members, other.members):
            nmb = ctx.new_var()
            ctx.unify(nmb, smb)
            # smb and omb may already be entangled, e.g. because an earlier
            # unify gave them the same ctx_id. In that case nmb now shares
            # that id as well, and unifying again would trip the
            # `l.ctx_id != r.ctx_id` assertion in Context.unify.
            if nmb.ctx_id != omb.ctx_id:
                ctx.unify(nmb, omb)
            newmembers.append(nmb)

        return TypeConstraintSubscript(members=newmembers)

    def __repr__(self) -> str:
        return 'Subscript=(' + ','.join(map(repr, self.members)) + ')'
# Generic parameter for TypeVar.get_constraint. It is bound to
# TypeConstraintBase so get_constraint can return the same constraint
# subclass it was asked for. typing.TypeVar is imported as MyPyTypeVar to
# avoid clashing with the phasm TypeVar class defined below.
TTypeConstraintClass = MyPyTypeVar('TTypeConstraintClass', bound=TypeConstraintBase)
class TypeVar:
    """
    Handle to a type variable owned by a Context.

    The handle stores only its Context and an id. Constraints and source
    locations live in the Context, keyed by that id, so every handle that
    shares an id refers to the same type variable. Context.unify rewrites
    ctx_id on each handle it merges.
    """
    # FIXME: Explain the type system

    __slots__ = ('ctx', 'ctx_id', )

    ctx: 'Context'
    ctx_id: int

    def __init__(self, ctx: 'Context', ctx_id: int) -> None:
        self.ctx = ctx
        self.ctx_id = ctx_id

    def add_constraint(self, newconst: TypeConstraintBase) -> None:
        """
        Attach newconst. If a constraint of the same class is already
        present, the two are narrowed into one; narrowing may raise.
        """
        constraints = self.ctx.var_constraints[self.ctx_id]
        kind = newconst.__class__
        if kind not in constraints:
            constraints[kind] = newconst
            return
        constraints[kind] = constraints[kind].narrow(self.ctx, newconst)

    def get_constraint(self, const_type: Type[TTypeConstraintClass]) -> Optional[TTypeConstraintClass]:
        """
        Return the constraint of exactly class const_type, or None.
        """
        found = self.ctx.var_constraints[self.ctx_id].get(const_type)
        assert found is None or isinstance(found, const_type) # type hint
        return found

    def add_location(self, ref: str) -> None:
        """
        Record a source reference at which this variable occurs.
        """
        self.ctx.var_locations[self.ctx_id].add(ref)

    def __repr__(self) -> str:
        constraints = '; '.join(repr(c) for c in self.ctx.var_constraints[self.ctx_id].values())
        locations = ', '.join(sorted(self.ctx.var_locations[self.ctx_id]))
        return f'TypeVar<{constraints}; locations: {locations}>'
class Context:
    """
    Owns a collection of type variables together with their constraints
    and source locations.

    Every TypeVar created here gets an integer id. All TypeVar handles
    sharing an id denote the same variable; unify() merges two variables
    by moving all of their handles onto a fresh, shared id.
    """
    def __init__(self) -> None:
        # Variables are unified (or entangled, if you will)
        # that means that each TypeVar within a context has an ID,
        # and all TypeVars with the same ID are the same TypeVar,
        # even if they are a different instance
        self.next_ctx_id = 1
        self.vars_by_id: Dict[int, List[TypeVar]] = {}

        # Store the TypeVar properties as a lookup
        # so we can update these when unifying
        self.var_constraints: Dict[int, Dict[Type[TypeConstraintBase], TypeConstraintBase]] = {}
        self.var_locations: Dict[int, Set[str]] = {}

    def new_var(self) -> TypeVar:
        """
        Create a fresh type variable with no constraints and no locations.
        """
        ctx_id = self.next_ctx_id
        self.next_ctx_id += 1

        result = TypeVar(self, ctx_id)

        self.vars_by_id[ctx_id] = [result]
        self.var_constraints[ctx_id] = {}
        self.var_locations[ctx_id] = set()

        return result

    def unify(self, l: Optional[TypeVar], r: Optional[TypeVar]) -> None:
        """
        Entangle two type variables so that they become one variable.

        A new variable is created carrying the constraints of both l and r
        (narrowed together) and the union of their locations. Every handle
        of l and r is then re-pointed at it.

        Unifying a variable with itself, or with a variable it is already
        entangled with, does nothing.

        :raises AssertionError: if l or r is None (see ASSERTION_ERROR).
        :raises TypingNarrowError: if the constraints of l and r conflict.
        """
        assert l is not None, ASSERTION_ERROR
        assert r is not None, ASSERTION_ERROR

        if l.ctx_id == r.ctx_id:
            # Already the same variable; there is nothing to merge
            return

        # Backup some values that we'll overwrite
        l_ctx_id = l.ctx_id
        r_ctx_id = r.ctx_id
        l_r_var_list = self.vars_by_id[l_ctx_id] + self.vars_by_id[r_ctx_id]

        # Create a new TypeVar, with the combined contraints
        # and locations of the old ones
        n = self.new_var()

        try:
            for const in self.var_constraints[l_ctx_id].values():
                n.add_constraint(const)
            for const in self.var_constraints[r_ctx_id].values():
                n.add_constraint(const)
        except TypingNarrowProtoError as exc:
            raise TypingNarrowError(l, r, str(exc)) from None

        self.var_locations[n.ctx_id] = self.var_locations[l_ctx_id] | self.var_locations[r_ctx_id]

        # ##
        # And unify (or entangle) the old ones

        # First update the IDs, so they all point to the new list
        for type_var in l_r_var_list:
            type_var.ctx_id = n.ctx_id

        # Update our registry of TypeVars by ID, so we can find them
        # on the next unify
        self.vars_by_id[n.ctx_id].extend(l_r_var_list)

        # Then delete the old values for the now gone variables.
        # Do this last, so exceptions thrown in the code above
        # still have a valid context. vars_by_id is cleaned up as well;
        # its old lists would otherwise hold handles whose ctx_id no
        # longer matches the key they are stored under.
        del self.vars_by_id[l_ctx_id]
        del self.vars_by_id[r_ctx_id]
        del self.var_constraints[l_ctx_id]
        del self.var_constraints[r_ctx_id]
        del self.var_locations[l_ctx_id]
        del self.var_locations[r_ctx_id]
def simplify(inp: TypeVar) -> Optional[str]:
    """
    Reduce a TypeVar to a type name that wasm can work with and users can
    recognize, such as 'u32', 'f64' or 'u8[4]'.

    Returns None when the constraints on inp are not (yet) specific enough
    to name exactly one supported type.

    Should round trip with from_str.
    """
    prim_const = inp.get_constraint(TypeConstraintPrimitive)
    bits_const = inp.get_constraint(TypeConstraintBitWidth)
    sign_const = inp.get_constraint(TypeConstraintSigned)

    if prim_const is None:
        return None

    kind = prim_const.primitive

    if kind is TypeConstraintPrimitive.Primitive.INT:
        # Integers need a decided sign and exactly one supported width
        if bits_const is None or sign_const is None:
            return None
        if sign_const.signed is None or len(bits_const.oneof) != 1:
            return None
        (width, ) = bits_const.oneof
        if width not in (8, 32, 64):
            return None
        return ('i' if sign_const.signed else 'u') + str(width)

    if kind is TypeConstraintPrimitive.Primitive.FLOAT:
        # Floats must carry no signedness constraint at all
        if bits_const is None or sign_const is not None:
            return None
        if len(bits_const.oneof) != 1:
            return None
        (width, ) = bits_const.oneof
        if width not in (32, 64):
            return None
        return f'f{width}'

    if kind is TypeConstraintPrimitive.Primitive.STATIC_ARRAY:
        # The element type is read from the first member only
        subs_const = inp.get_constraint(TypeConstraintSubscript)
        assert subs_const is not None
        assert subs_const.members

        element = simplify(subs_const.members[0])
        if element is None:
            return None
        return f'{element}[{len(subs_const.members)}]'

    return None
def _make_primitive(
    ctx: Context,
    location: str,
    primitive: TypeConstraintPrimitive.Primitive,
    bitwidth: int,
    signed: Optional[bool] = None,
) -> TypeVar:
    """
    Create a TypeVar pinned to one primitive kind and one exact bit width.

    Constraints are added in a fixed order: primitive, bit width, then
    signedness (if any). The location is recorded last.

    :param signed: True or False adds a TypeConstraintSigned with that
        value. None adds no signedness constraint at all, which floats
        require (simplify rejects a float that carries one).
    """
    result = ctx.new_var()
    result.add_constraint(TypeConstraintPrimitive(primitive))
    result.add_constraint(TypeConstraintBitWidth(minb=bitwidth, maxb=bitwidth))
    if signed is not None:
        result.add_constraint(TypeConstraintSigned(signed))
    result.add_location(location)
    return result


def make_u8(ctx: Context, location: str) -> TypeVar:
    """
    Makes a u8 TypeVar (unsigned 8-bit integer)
    """
    return _make_primitive(ctx, location, TypeConstraintPrimitive.Primitive.INT, 8, signed=False)


def make_u32(ctx: Context, location: str) -> TypeVar:
    """
    Makes a u32 TypeVar (unsigned 32-bit integer)
    """
    return _make_primitive(ctx, location, TypeConstraintPrimitive.Primitive.INT, 32, signed=False)


def make_u64(ctx: Context, location: str) -> TypeVar:
    """
    Makes a u64 TypeVar (unsigned 64-bit integer)
    """
    return _make_primitive(ctx, location, TypeConstraintPrimitive.Primitive.INT, 64, signed=False)


def make_i32(ctx: Context, location: str) -> TypeVar:
    """
    Makes a i32 TypeVar (signed 32-bit integer)
    """
    return _make_primitive(ctx, location, TypeConstraintPrimitive.Primitive.INT, 32, signed=True)


def make_i64(ctx: Context, location: str) -> TypeVar:
    """
    Makes a i64 TypeVar (signed 64-bit integer)
    """
    return _make_primitive(ctx, location, TypeConstraintPrimitive.Primitive.INT, 64, signed=True)


def make_f32(ctx: Context, location: str) -> TypeVar:
    """
    Makes a f32 TypeVar (32-bit float, no signedness constraint)
    """
    return _make_primitive(ctx, location, TypeConstraintPrimitive.Primitive.FLOAT, 32)


def make_f64(ctx: Context, location: str) -> TypeVar:
    """
    Makes a f64 TypeVar (64-bit float, no signedness constraint)
    """
    return _make_primitive(ctx, location, TypeConstraintPrimitive.Primitive.FLOAT, 64)
# Builtin type names accepted by from_str, each mapped to the factory
# that builds a fresh TypeVar for it.
BUILTIN_TYPES: Dict[str, Callable[[Context, str], TypeVar]] = dict(
    u8=make_u8,
    u32=make_u32,
    u64=make_u64,
    i32=make_i32,
    i64=make_i64,
    f32=make_f32,
    f64=make_f64,
)
# Matches static array type strings such as 'u8[4]' or 'f64[16]'.
# Group 1 is the element type name ('u', 'i' or 'f' followed by digits);
# group 2 is the element count. Only a single array dimension is supported.
TYPE_MATCH_STATIC_ARRAY = re.compile(r'^([uif][0-9]+)\[([0-9]+)\]')
def from_str(ctx: Context, inp: str, location: str) -> TypeVar:
    """
    Creates a new TypeVar from the string

    Should round trip with simplify

    The location is a reference to where you found the string
    in the source code.

    This could be considered part of parsing. Though that would give trouble
    with the context creation.

    :raises NotImplementedError: if inp is not a builtin type name or a
        static array of one, or if it is a zero-length static array (which
        simplify cannot render, so it would never round trip).
    """
    if inp in BUILTIN_TYPES:
        return BUILTIN_TYPES[inp](ctx, location)

    match = TYPE_MATCH_STATIC_ARRAY.fullmatch(inp)
    if match:
        member_count = int(match[2])
        if member_count < 1:
            # simplify() asserts that a static array has members, so a
            # zero-length array is rejected here like any unsupported type
            raise NotImplementedError(from_str, inp)

        result = ctx.new_var()

        result.add_constraint(TypeConstraintPrimitive(TypeConstraintPrimitive.Primitive.STATIC_ARRAY))
        result.add_constraint(TypeConstraintSubscript(members=(
            # Make copies so they don't get entangled
            # with each other.
            from_str(ctx, match[1], match[1])
            for _ in range(member_count)
        )))

        result.add_location(location)

        return result

    raise NotImplementedError(from_str, inp)