phasm/phasm/type3/constraints.py
Johan B.W. de Vries faaf7912b1 Various cleanup to type system
- Make struct into a type constructor
- Rework placeholders
- Got rid of 'PrimitiveType' as a concept
- Moved out the prelude to its own folder
2025-04-21 16:49:04 +02:00

527 lines
18 KiB
Python

"""
This module contains possible constraints generated based on the AST
These need to be resolved before the program can be compiled.
"""
from typing import Dict, List, Optional, Tuple, Union
from .. import ourlang, prelude
from . import placeholders, typeclasses, types
class Error:
    """
    An error returned by the check function of a constraint

    This means the programmer has to make some kind of change to the
    typing of their program before the compiler can do its thing.
    """
    def __init__(self, msg: str, *, comment: Optional[str] = None) -> None:
        # `comment` is keyword-only; it carries extra context for the message.
        self.comment = comment
        self.msg = msg

    def __repr__(self) -> str:
        return f'Error({self.msg!r}, comment={self.comment!r})'
class RequireTypeSubstitutes:
    """
    Marker result for a constraint whose types are not all substituted yet

    A check function returns an instance of this class when it cannot decide
    yet. Hopefully, another constraint will provide the missing information
    about the typing of the program, after which this constraint can be
    checked again.
    """
# Maps a placeholder to the type it should be substituted with.
SubstitutionMap = Dict[placeholders.PlaceholderForType, types.Type3]

# A list of follow-up constraints that replace the constraint that was checked.
NewConstraintList = List['ConstraintBase']

# Every kind of value a constraint's check function may return.
CheckResult = Union[
    None,
    SubstitutionMap,
    Error,
    NewConstraintList,
    RequireTypeSubstitutes,
]

# A format string, plus the values for the named fields in that format string.
HumanReadableRet = Tuple[
    str,
    Dict[
        str,
        Union[str, ourlang.Expression, types.Type3, placeholders.PlaceholderForType],
    ],
]
class Context:
    """
    Context for constraints

    Declares an empty `__slots__`, so instances carry no attributes.
    """
    __slots__ = ()
class ConstraintBase:
    """
    Base class for constraints
    """
    __slots__ = ('comment', )

    # A comment to help the programmer with debugging the types in their program
    comment: Optional[str]

    def __init__(self, comment: Optional[str] = None) -> None:
        self.comment = comment

    def check(self) -> CheckResult:
        """
        Checks if the constraint holds

        Possible results:

        - An error, if the constraint does not hold. This indicates an error
          in the typing of the input program.
        - RequireTypeSubstitutes(), if not all the types can be deduced yet.
        - A SubstitutionMap, if new types were discovered while evaluating
          the constraint. In this case, the constraint is expected to hold.
        - None, if the constraint holds, but no new information was deduced
          from evaluating it.

        Subclasses must override this; the base implementation always raises
        NotImplementedError.
        """
        raise NotImplementedError(self.__class__, self.check)

    def human_readable(self) -> HumanReadableRet:
        """
        Returns a more human readable form of this constraint

        The base implementation falls back to repr() with no format values.
        """
        return (repr(self), {})
class SameTypeConstraint(ConstraintBase):
    """
    Verifies that a number of types all are the same type
    """
    __slots__ = ('type_list', )

    type_list: List[placeholders.Type3OrPlaceholder]

    def __init__(self, *type_list: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
        super().__init__(comment=comment)

        # Comparing fewer than two types is meaningless
        assert len(type_list) > 1
        self.type_list = [*type_list]

    def check(self) -> CheckResult:
        """
        Compares all known types and resolves the unresolved placeholders

        Returns RequireTypeSubstitutes() when none of the types is known yet,
        an Error when two known types differ, None when everything is known
        and equal, and otherwise a SubstitutionMap assigning the known type
        to every unresolved placeholder.

        Side effect: sets `resolve_as` on every unresolved placeholder.
        """
        known_types: List[types.Type3] = []
        phft_list: List[placeholders.PlaceholderForType] = []
        for typ in self.type_list:
            if isinstance(typ, types.Type3):
                known_types.append(typ)
                continue

            if isinstance(typ, placeholders.PlaceholderForType):
                if typ.resolve_as is not None:
                    known_types.append(typ.resolve_as)
                else:
                    phft_list.append(typ)
                continue

            raise NotImplementedError(typ)

        if not known_types:
            return RequireTypeSubstitutes()

        first_type = known_types[0]
        for ktyp in known_types[1:]:
            if ktyp != first_type:
                return Error(f'{ktyp:s} must be {first_type:s} instead', comment=self.comment)

        # Test the list of unresolved placeholders, not the `placeholders`
        # module: a module object is always truthy, which made this branch
        # unreachable and returned an empty map instead of None.
        if not phft_list:
            return None

        for phft in phft_list:
            phft.resolve_as = first_type

        return {
            phft: first_type
            for phft in phft_list
        }

    def human_readable(self) -> HumanReadableRet:
        return (
            ' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))),
            {
                't' + str(idx): typ
                for idx, typ in enumerate(self.type_list)
            },
        )

    def __repr__(self) -> str:
        args = ', '.join(repr(x) for x in self.type_list)

        return f'SameTypeConstraint({args}, comment={repr(self.comment)})'
class TupleMatchConstraint(ConstraintBase):
    """
    Matches a list of member types against a static array or tuple type

    Once the expected type is known, this constraint is replaced by one
    SameTypeConstraint per member.
    """
    def __init__(self, exp_type: placeholders.Type3OrPlaceholder, args: List[placeholders.Type3OrPlaceholder], comment: str):
        super().__init__(comment=comment)

        self.exp_type = exp_type
        self.args = list(args)

    def check(self) -> CheckResult:
        resolved = self.exp_type
        if isinstance(resolved, placeholders.PlaceholderForType):
            resolved = resolved.resolve_as
            if resolved is None:
                # Nothing is known about the expected type yet
                return RequireTypeSubstitutes()

        assert isinstance(resolved, types.Type3)

        # Static array: every member must have the array's element type
        array_info = prelude.static_array.did_construct(resolved)
        if array_info is not None:
            element_type, element_count = array_info

            if element_count.value != len(self.args):
                return Error('Mismatch between applied types argument count', comment=self.comment)

            return [SameTypeConstraint(member, element_type) for member in self.args]

        # Tuple: every member must have the type at the same position
        tuple_members = prelude.tuple_.did_construct(resolved)
        if tuple_members is not None:
            if len(tuple_members) != len(self.args):
                return Error('Mismatch between applied types argument count', comment=self.comment)

            return [
                SameTypeConstraint(member, expected)
                for member, expected in zip(self.args, tuple_members)
            ]

        raise NotImplementedError(resolved)
class CastableConstraint(ConstraintBase):
    """
    A type can be cast to another type
    """
    __slots__ = ('from_type3', 'to_type3', )

    from_type3: placeholders.Type3OrPlaceholder
    to_type3: placeholders.Type3OrPlaceholder

    def __init__(self, from_type3: placeholders.Type3OrPlaceholder, to_type3: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
        super().__init__(comment=comment)

        self.from_type3 = from_type3
        self.to_type3 = to_type3

    def check(self) -> CheckResult:
        """
        Checks whether the source type can be cast to the target type

        Returns RequireTypeSubstitutes() while either side is an unresolved
        placeholder, None when the cast is allowed, and an Error otherwise.
        The only cast accepted here is u8 to u32.
        """
        ftyp = self.from_type3
        if isinstance(ftyp, placeholders.PlaceholderForType) and ftyp.resolve_as is not None:
            ftyp = ftyp.resolve_as

        ttyp = self.to_type3
        if isinstance(ttyp, placeholders.PlaceholderForType) and ttyp.resolve_as is not None:
            ttyp = ttyp.resolve_as

        if isinstance(ftyp, placeholders.PlaceholderForType) or isinstance(ttyp, placeholders.PlaceholderForType):
            return RequireTypeSubstitutes()

        if ftyp is prelude.u8 and ttyp is prelude.u32:
            return None

        # Pass the comment along so the programmer keeps the debugging context
        return Error(f'Cannot cast {ftyp.name} to {ttyp.name}', comment=self.comment)

    def human_readable(self) -> HumanReadableRet:
        return (
            '{to_type3}({from_type3})',
            {
                'to_type3': self.to_type3,
                'from_type3': self.from_type3,
            },
        )

    def __repr__(self) -> str:
        return f'CastableConstraint({repr(self.from_type3)}, {repr(self.to_type3)}, comment={repr(self.comment)})'
class MustImplementTypeClassConstraint(ConstraintBase):
    """
    A type must implement a given type class
    """
    __slots__ = ('type_class3', 'type3', )

    type_class3: Union[str, typeclasses.Type3Class]
    type3: placeholders.Type3OrPlaceholder

    # Type class names per type name; consulted when the type class is
    # given as a string rather than as a Type3Class.
    DATA = {
        'bytes': {'Foldable', 'Sized'},
    }

    def __init__(self, type_class3: Union[str, typeclasses.Type3Class], type3: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
        super().__init__(comment=comment)

        self.type_class3 = type_class3
        self.type3 = type3

    def check(self) -> CheckResult:
        """
        Checks whether the type implements the type class

        Returns RequireTypeSubstitutes() while the type is an unresolved
        placeholder, None when the type implements the type class, and an
        Error otherwise.
        """
        typ = self.type3
        if isinstance(typ, placeholders.PlaceholderForType) and typ.resolve_as is not None:
            typ = typ.resolve_as

        if isinstance(typ, placeholders.PlaceholderForType):
            return RequireTypeSubstitutes()

        if isinstance(self.type_class3, typeclasses.Type3Class):
            if self.type_class3 in typ.classes:
                return None
        else:
            if self.type_class3 in self.__class__.DATA.get(typ.name, set()):
                return None

        # Pass the comment along so the programmer keeps the debugging context
        return Error(f'{typ.name} does not implement the {self.type_class3} type class', comment=self.comment)

    def human_readable(self) -> HumanReadableRet:
        return (
            '{type3} derives {type_class3}',
            {
                'type_class3': str(self.type_class3),
                'type3': self.type3,
            },
        )

    def __repr__(self) -> str:
        return f'MustImplementTypeClassConstraint({repr(self.type_class3)}, {repr(self.type3)}, comment={repr(self.comment)})'
class LiteralFitsConstraint(ConstraintBase):
    """
    A literal value fits a given type

    Integer literals must fit the byte width and signedness of the type,
    float literals must be a Python float and bytes literals must be Python
    bytes. For tuples, static arrays and structs the check is handed off to
    new constraints for each member.
    """
    __slots__ = ('type3', 'literal', )

    type3: placeholders.Type3OrPlaceholder
    literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantBytes, ourlang.ConstantTuple, ourlang.ConstantStruct]

    def __init__(
            self,
            type3: placeholders.Type3OrPlaceholder,
            literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantBytes, ourlang.ConstantTuple, ourlang.ConstantStruct],
            comment: Optional[str] = None,
            ) -> None:
        super().__init__(comment=comment)

        self.type3 = type3
        self.literal = literal

    def check(self) -> CheckResult:
        """
        Checks whether the literal fits the type

        Returns RequireTypeSubstitutes() while the type is an unresolved
        placeholder, None when a primitive or bytes literal fits, an Error
        when it does not, and a list of new constraints for tuples, static
        arrays and structs.

        Side effect: a resolved placeholder in `self.type3` is replaced by
        the type it resolves to.
        """
        # Byte width and signedness per integer type name
        int_table: Dict[str, Tuple[int, bool]] = {
            'u8': (1, False),
            'u32': (4, False),
            'u64': (8, False),
            'i8': (1, True),
            'i32': (4, True),
            'i64': (8, True),
        }

        float_table: Dict[str, None] = {
            'f32': None,
            'f64': None,
        }

        if isinstance(self.type3, placeholders.PlaceholderForType):
            if self.type3.resolve_as is None:
                return RequireTypeSubstitutes()

            self.type3 = self.type3.resolve_as

        if self.type3.name in int_table:
            bts, sgn = int_table[self.type3.name]

            if isinstance(self.literal.value, int):
                try:
                    # to_bytes raises OverflowError when the value does not fit
                    self.literal.value.to_bytes(bts, 'big', signed=sgn)
                except OverflowError:
                    return Error(f'Must fit in {bts} byte(s)', comment=self.comment) # FIXME: Add line information

                return None

            return Error('Must be integer', comment=self.comment) # FIXME: Add line information

        if self.type3.name in float_table:
            if isinstance(self.literal.value, float):
                # FIXME: Bit check

                return None

            return Error('Must be real', comment=self.comment) # FIXME: Add line information

        if self.type3 is prelude.bytes_:
            if isinstance(self.literal.value, bytes):
                return None

            return Error('Must be bytes', comment=self.comment) # FIXME: Add line information

        res: NewConstraintList

        assert isinstance(self.type3, types.Type3)

        tp_args = prelude.tuple_.did_construct(self.type3)
        if tp_args is not None:
            if not isinstance(self.literal, ourlang.ConstantTuple):
                return Error('Must be tuple', comment=self.comment)

            if len(tp_args) != len(self.literal.value):
                return Error('Tuple element count mismatch', comment=self.comment)

            res = []

            res.extend(
                LiteralFitsConstraint(x, y)
                for x, y in zip(tp_args, self.literal.value)
            )
            res.extend(
                SameTypeConstraint(x, y.type3)
                for x, y in zip(tp_args, self.literal.value)
            )

            return res

        sa_args = prelude.static_array.did_construct(self.type3)
        if sa_args is not None:
            if not isinstance(self.literal, ourlang.ConstantTuple):
                return Error('Must be tuple', comment=self.comment)

            sa_type, sa_len = sa_args

            if sa_len.value != len(self.literal.value):
                return Error('Member count mismatch', comment=self.comment)

            res = []

            res.extend(
                LiteralFitsConstraint(sa_type, y)
                for y in self.literal.value
            )
            res.extend(
                SameTypeConstraint(sa_type, y.type3)
                for y in self.literal.value
            )

            return res

        st_args = prelude.struct.did_construct(self.type3)
        if st_args is not None:
            # Pass the comment along, like the tuple and static array
            # branches above do.
            if not isinstance(self.literal, ourlang.ConstantStruct):
                return Error('Must be struct', comment=self.comment)

            if self.literal.struct_name != self.type3.name:
                return Error('Struct mismatch', comment=self.comment)

            if len(st_args) != len(self.literal.value):
                return Error('Struct element count mismatch', comment=self.comment)

            res = []

            res.extend(
                LiteralFitsConstraint(x, y)
                for x, y in zip(st_args.values(), self.literal.value)
            )
            res.extend(
                SameTypeConstraint(x_t, y.type3, comment=f'{self.literal.struct_name}.{x_n}')
                for (x_n, x_t, ), y in zip(st_args.items(), self.literal.value)
            )

            return res

        raise NotImplementedError(self.type3, self.literal)

    def human_readable(self) -> HumanReadableRet:
        return (
            '{literal} : {type3}',
            {
                'literal': self.literal,
                'type3': self.type3,
            },
        )

    def __repr__(self) -> str:
        return f'LiteralFitsConstraint({repr(self.type3)}, {repr(self.literal)}, comment={repr(self.comment)})'
class CanBeSubscriptedConstraint(ConstraintBase):
    """
    A value that is subscripted, i.e. a[0] (tuple) or a[b] (static array)

    Static arrays and bytes are indexed with a u32; tuples can only be
    indexed with an integer literal.
    """
    __slots__ = ('ret_type3', 'type3', 'index', 'index_type3', )

    ret_type3: placeholders.Type3OrPlaceholder
    type3: placeholders.Type3OrPlaceholder
    index: ourlang.Expression
    index_type3: placeholders.Type3OrPlaceholder

    def __init__(self, ret_type3: placeholders.Type3OrPlaceholder, type3: placeholders.Type3OrPlaceholder, index: ourlang.Expression, comment: Optional[str] = None) -> None:
        super().__init__(comment=comment)

        self.ret_type3 = ret_type3
        self.type3 = type3
        self.index = index
        self.index_type3 = index.type3

    def check(self) -> CheckResult:
        """
        Checks whether the subscripted type supports the subscript

        Returns RequireTypeSubstitutes() while the subscripted type is an
        unresolved placeholder, an Error when the type cannot be subscripted
        or a literal index is invalid, and otherwise a list of new
        constraints for the index type and the result type.
        """
        exp_type = self.type3
        if isinstance(exp_type, placeholders.PlaceholderForType):
            if exp_type.resolve_as is None:
                return RequireTypeSubstitutes()

            exp_type = exp_type.resolve_as

        assert isinstance(exp_type, types.Type3)

        sa_args = prelude.static_array.did_construct(exp_type)
        if sa_args is not None:
            sa_type, sa_len = sa_args

            result: List[ConstraintBase] = [
                SameTypeConstraint(prelude.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'),
                SameTypeConstraint(sa_type, self.ret_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'),
            ]

            if isinstance(self.index, ourlang.ConstantPrimitive):
                # A non-integer literal is a typing error in the input
                # program, so report it rather than failing an assertion.
                if not isinstance(self.index.value, int):
                    return Error('Must index with integer literal', comment=self.comment)

                # NOTE(review): the message says 'Tuple' although this is the
                # static array branch; kept verbatim in case callers match on it.
                if self.index.value < 0 or sa_len.value <= self.index.value:
                    return Error('Tuple index out of range', comment=self.comment)

            return result

        # We special case tuples to allow for ease of use to the programmer
        # e.g. rather than having to do `fst a` and `snd a` and only have two-sized tuples
        # we use a[0] and a[1] and allow for a[2] and on.
        tp_args = prelude.tuple_.did_construct(exp_type)
        if tp_args is not None:
            if not isinstance(self.index, ourlang.ConstantPrimitive):
                return Error('Must index with literal', comment=self.comment)

            if not isinstance(self.index.value, int):
                return Error('Must index with integer literal', comment=self.comment)

            if self.index.value < 0 or len(tp_args) <= self.index.value:
                return Error('Tuple index out of range', comment=self.comment)

            return [
                SameTypeConstraint(prelude.u32, self.index_type3, comment=f'Tuple subscript index {self.index.value}'),
                SameTypeConstraint(tp_args[self.index.value], self.ret_type3, comment=f'Tuple subscript index {self.index.value}'),
            ]

        if exp_type is prelude.bytes_:
            return [
                SameTypeConstraint(prelude.u32, self.index_type3, comment='([]) :: bytes -> u32 -> u8'),
                SameTypeConstraint(prelude.u8, self.ret_type3, comment='([]) :: bytes -> u32 -> u8'),
            ]

        return Error(f'{exp_type.name} cannot be subscripted', comment=self.comment)

    def human_readable(self) -> HumanReadableRet:
        return (
            '{type3}[{index}]',
            {
                'type3': self.type3,
                'index': self.index,
            },
        )

    def __repr__(self) -> str:
        return f'CanBeSubscriptedConstraint({repr(self.type3)}, {repr(self.index)}, comment={repr(self.comment)})'