phasm/phasm/type3/typeclasses.py
Johan B.W. de Vries 891f114edf Ideas [skip-ci]
2025-04-13 14:08:27 +02:00

159 lines
4.5 KiB
Python

from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
class TypeVariable:
    """
    A single-letter placeholder for a type inside a signature.

    Instances compare equal when their letters are equal and hash on
    that letter, so they can be used as dictionary keys and set members.
    """
    __slots__ = ('letter', )

    letter: str

    def __init__(self, letter: str) -> None:
        """
        Store the letter naming this variable.

        Raises AssertionError when `letter` is not exactly one character
        long (the check is an `assert`, so it is skipped under `-O`).
        """
        assert len(letter) == 1, f'{letter} is not a valid type variable'

        self.letter = letter

    def __hash__(self) -> int:
        return hash(self.letter)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, TypeVariable):
            # Signal "don't know how to compare" instead of raising, so
            # Python can try the other operand's comparison and otherwise
            # fall back to its default. Raising here would make plain
            # `==`, `!=` and `in` checks against foreign objects blow up.
            return NotImplemented

        return self.letter == other.letter

    def __repr__(self) -> str:
        return f'TypeVariable({self.letter!r})'
class TypeReference:
    """
    A reference to a type by name inside a signature.

    Instances compare equal when their names are equal and hash on that
    name, so they can be used as dictionary keys and set members.
    """
    __slots__ = ('name', )

    name: str

    def __init__(self, name: str) -> None:
        """
        Store the referenced type name.

        Raises AssertionError when `name` is not longer than one
        character (the check is an `assert`, so it is skipped under `-O`).
        """
        assert len(name) > 1, f'{name} is not a valid type reference'

        self.name = name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, TypeReference):
            # Signal "don't know how to compare" instead of raising, so
            # Python can try the other operand's comparison and otherwise
            # fall back to its default. Raising here would make plain
            # `==`, `!=` and `in` checks against foreign objects blow up.
            return NotImplemented

        return self.name == other.name

    def __repr__(self) -> str:
        return f'TypeReference({self.name!r})'
class Type3ClassMethod:
    """
    A named entry (method or operator) declared by a Type3Class.

    The signature is supplied as a string whose parts are separated by
    ' -> '. Each one-character part is stored as a TypeVariable; every
    other part is stored as a TypeReference.
    """
    __slots__ = ('type3_class', 'name', 'signature', )

    type3_class: 'Type3Class'
    name: str
    signature: List[Union[TypeReference, TypeVariable]]

    def __init__(self, type3_class: 'Type3Class', name: str, signature: str) -> None:
        self.type3_class = type3_class
        self.name = name

        # Build the parsed signature fully before storing it, so a part
        # that fails validation leaves `signature` unset.
        parsed: List[Union[TypeReference, TypeVariable]] = []
        for part in signature.split(' -> '):
            if len(part) == 1:
                parsed.append(TypeVariable(part))
            else:
                parsed.append(TypeReference(part))
        self.signature = parsed

    def __repr__(self) -> str:
        return f'Type3ClassMethod({self.type3_class!r}, {self.name!r}, {self.signature!r})'
class Type3Class:
    """
    A type class: a name, its type variable arguments, and the methods
    and operators it declares, plus the classes it inherits from.

    Methods and operators are passed as mappings from name to signature
    string; each entry is wrapped in a Type3ClassMethod bound to this
    class.
    """
    __slots__ = ('name', 'args', 'methods', 'operators', 'inherited_classes', )

    name: str
    args: List[TypeVariable]
    methods: Dict[str, Type3ClassMethod]
    operators: Dict[str, Type3ClassMethod]
    inherited_classes: List['Type3Class']

    def __init__(
        self,
        name: str,
        args: Iterable[str],
        methods: Mapping[str, str],
        operators: Mapping[str, str],
        inherited_classes: Optional[List['Type3Class']] = None,
    ) -> None:
        self.name = name
        self.args = list(map(TypeVariable, args))

        declared_methods: Dict[str, Type3ClassMethod] = {}
        for method_name, method_signature in methods.items():
            declared_methods[method_name] = Type3ClassMethod(self, method_name, method_signature)
        self.methods = declared_methods

        declared_operators: Dict[str, Type3ClassMethod] = {}
        for operator_name, operator_signature in operators.items():
            declared_operators[operator_name] = Type3ClassMethod(self, operator_name, operator_signature)
        self.operators = declared_operators

        # A non-empty list is stored as passed (not copied); None or an
        # empty list is replaced by a fresh empty list.
        if inherited_classes:
            self.inherited_classes = inherited_classes
        else:
            self.inherited_classes = []

    def __repr__(self) -> str:
        # The repr is just the bare class name.
        return self.name
# Declares no methods or operators of its own.
# NOTE(review): presumably a marker class for types passed by pointer - confirm.
InternalPassAsPointer = Type3Class(
    name='InternalPassAsPointer',
    args=['a'],
    methods={},
    operators={},
)
# Equality and inequality operators, both yielding bool.
Eq = Type3Class(
    name='Eq',
    args=['a'],
    methods={},
    operators={
        '==': 'a -> a -> bool',
        '!=': 'a -> a -> bool',
        # FIXME: Do we want to expose 'eqz'? Or is that a compiler optimization?
    },
)
# Ordering comparisons plus min/max; inherits from Eq.
Ord = Type3Class(
    name='Ord',
    args=['a'],
    methods={
        'min': 'a -> a -> a',
        'max': 'a -> a -> a',
    },
    operators={
        '<': 'a -> a -> bool',
        '<=': 'a -> a -> bool',
        '>': 'a -> a -> bool',
        '>=': 'a -> a -> bool',
    },
    inherited_classes=[Eq],
)
# Bit manipulation: shifts and rotates as methods, and/or/xor as operators.
Bits = Type3Class(
    name='Bits',
    args=['a'],
    methods={
        'shl': 'a -> u32 -> a',  # Logical shift left
        'shr': 'a -> u32 -> a',  # Logical shift right
        'rotl': 'a -> u32 -> a',  # Rotate bits left
        'rotr': 'a -> u32 -> a',  # Rotate bits right
        # FIXME: Do we want to expose clz, ctz, popcnt?
    },
    operators={
        '&': 'a -> a -> a',  # Bit-wise and
        '|': 'a -> a -> a',  # Bit-wise or
        '^': 'a -> a -> a',  # Bit-wise xor
    },
)
# Basic arithmetic operators and the shift operators.
NatNum = Type3Class(
    name='NatNum',
    args=['a'],
    methods={},
    operators={
        '+': 'a -> a -> a',
        '-': 'a -> a -> a',
        '*': 'a -> a -> a',
        '<<': 'a -> u32 -> a',  # Arithmetic shift left
        '>>': 'a -> u32 -> a',  # Arithmetic shift right
    },
)
# Adds abs and neg on top of NatNum.
IntNum = Type3Class(
    name='IntNum',
    args=['a'],
    methods={
        'abs': 'a -> a',
        'neg': 'a -> a',
    },
    operators={},
    inherited_classes=[NatNum],
)
# Adds the '//' and '%' operators on top of NatNum.
#
# The class name must be 'Integral': it was previously declared as 'Eq',
# which made this class's name and repr indistinguishable from the real
# Eq type class.
Integral = Type3Class(
    'Integral',
    ['a'],
    methods={},
    operators={
        '//': 'a -> a -> a',
        '%': 'a -> a -> a',
    },
    inherited_classes=[NatNum],
)
# Adds the '/' operator and rounding methods on top of NatNum.
Fractional = Type3Class(
    name='Fractional',
    args=['a'],
    methods={
        'ceil': 'a -> a',
        'floor': 'a -> a',
        'trunc': 'a -> a',
        'nearest': 'a -> a',
    },
    operators={
        '/': 'a -> a -> a',
    },
    inherited_classes=[NatNum],
)
# Adds sqrt on top of Fractional.
Floating = Type3Class(
    name='Floating',
    args=['a'],
    methods={
        'sqrt': 'a -> a',
    },
    operators={},
    inherited_classes=[Fractional],
)
# FIXME: Do we want to expose copysign?