Started work on applied type constraints

This commit is contained in:
Johan B.W. de Vries 2022-12-31 15:31:39 +01:00
parent e936d6e885
commit 7f0acf00fe
3 changed files with 88 additions and 28 deletions

View File

@ -104,27 +104,58 @@ class SameTypeConstraint(ConstraintBase):
self.type_list = [*type_list] self.type_list = [*type_list]
def check(self, smap: SubstitutionMap) -> CheckResult: def check(self, smap: SubstitutionMap) -> CheckResult:
known_types = [] known_types: List[types.Type3] = []
placeholders = [] placeholders = []
do_applied_placeholder_check: bool = False
for typ in self.type_list: for typ in self.type_list:
if isinstance(typ, types.Type3): if isinstance(typ, types.PrimitiveType3):
known_types.append(typ) known_types.append(typ)
continue continue
if typ in smap: if isinstance(typ, types.AppliedType3):
known_types.append(smap[typ]) known_types.append(typ)
do_applied_placeholder_check = True
continue continue
placeholders.append(typ) if isinstance(typ, types.PlaceholderForType):
if typ in smap:
known_types.append(smap[typ])
else:
placeholders.append(typ)
continue
raise NotImplementedError(typ)
if not known_types: if not known_types:
return RequireTypeSubstitutes() return RequireTypeSubstitutes()
new_constraint_list: List[ConstraintBase] = []
first_type = known_types[0] first_type = known_types[0]
for typ in known_types[1:]: for typ in known_types[1:]:
if isinstance(first_type, types.AppliedType3) and isinstance(typ, types.AppliedType3):
if len(first_type.args) != len(typ.args):
return Error('Mismatch between applied types argument count')
if first_type.base != typ.base:
return Error('Mismatch between applied types base')
for first_type_arg, typ_arg in zip(first_type.args, typ.args):
new_constraint_list.append(SameTypeConstraint(
first_type_arg, typ_arg
))
continue
if typ != first_type: if typ != first_type:
return Error(f'{typ:s} must be {first_type:s} instead') return Error(f'{typ:s} must be {first_type:s} instead')
if new_constraint_list:
# If this happens, make CheckResult a class that can have both
assert not placeholders, 'Cannot (yet) return both new placeholders and new constraints'
return new_constraint_list
if not placeholders: if not placeholders:
return None return None

View File

@ -8,7 +8,7 @@ from .. import ourlang
from .constraints import ConstraintBase, Error, RequireTypeSubstitutes, SameTypeConstraint, SubstitutionMap from .constraints import ConstraintBase, Error, RequireTypeSubstitutes, SameTypeConstraint, SubstitutionMap
from .constraintsgenerator import phasm_type3_generate_constraints from .constraintsgenerator import phasm_type3_generate_constraints
from .types import PlaceholderForType, Type3 from .types import AppliedType3, PlaceholderForType, PrimitiveType3, Type3, Type3OrPlaceholder
MAX_RESTACK_COUNT = 100 MAX_RESTACK_COUNT = 100
@ -111,14 +111,8 @@ def print_constraint(placeholder_id_map: Dict[int, str], constraint: ConstraintB
if isinstance(fmt_val, ourlang.Expression): if isinstance(fmt_val, ourlang.Expression):
fmt_val = codestyle.expression(fmt_val) fmt_val = codestyle.expression(fmt_val)
if isinstance(fmt_val, Type3): if isinstance(fmt_val, Type3) or isinstance(fmt_val, PlaceholderForType):
fmt_val = fmt_val.name fmt_val = get_printable_type_name(fmt_val, placeholder_id_map)
if isinstance(fmt_val, PlaceholderForType):
placeholder_id = id(fmt_val)
if placeholder_id not in placeholder_id_map:
placeholder_id_map[placeholder_id] = 'T' + str(len(placeholder_id_map) + 1)
fmt_val = placeholder_id_map[placeholder_id]
if not isinstance(fmt_val, str): if not isinstance(fmt_val, str):
fmt_val = repr(fmt_val) fmt_val = repr(fmt_val)
@ -130,6 +124,26 @@ def print_constraint(placeholder_id_map: Dict[int, str], constraint: ConstraintB
else: else:
print('- ' + txt.format(**act_fmt)) print('- ' + txt.format(**act_fmt))
def get_printable_type_name(inp: Type3OrPlaceholder, placeholder_id_map: Dict[int, str]) -> str:
if isinstance(inp, PrimitiveType3):
return inp.name
if isinstance(inp, PlaceholderForType):
placeholder_id = id(inp)
if placeholder_id not in placeholder_id_map:
placeholder_id_map[placeholder_id] = 'T' + str(len(placeholder_id_map) + 1)
return placeholder_id_map[placeholder_id]
if isinstance(inp, AppliedType3):
return (
get_printable_type_name(inp.base, placeholder_id_map)
+ ' ('
+ ') ('.join(get_printable_type_name(x, placeholder_id_map) for x in inp.args)
+ ')'
)
raise NotImplementedError(inp)
def print_constraint_list(placeholder_id_map: Dict[int, str], constraint_list: List[ConstraintBase], placeholder_substitutes: SubstitutionMap) -> None: def print_constraint_list(placeholder_id_map: Dict[int, str], constraint_list: List[ConstraintBase], placeholder_substitutes: SubstitutionMap) -> None:
print('=== v type3 constraint_list v === ') print('=== v type3 constraint_list v === ')
for psk, psv in placeholder_substitutes.items(): for psk, psv in placeholder_substitutes.items():

View File

@ -62,6 +62,13 @@ class Type3:
def __bool__(self) -> bool: def __bool__(self) -> bool:
raise NotImplementedError raise NotImplementedError
class PrimitiveType3(Type3):
"""
Intermediate class to tell primitive types from others
"""
__slots__ = ()
class PlaceholderForType: class PlaceholderForType:
""" """
A placeholder type, for when we don't know the final type yet A placeholder type, for when we don't know the final type yet
@ -113,7 +120,7 @@ class AppliedType3(Type3):
""" """
__slots__ = ('base', 'args', ) __slots__ = ('base', 'args', )
base: Type3 base: PrimitiveType3
""" """
The base type The base type
""" """
@ -123,8 +130,9 @@ class AppliedType3(Type3):
The applied types (or placeholders there for) The applied types (or placeholders there for)
""" """
def __init__(self, base: Type3, args: Iterable[Type3OrPlaceholder]) -> None: def __init__(self, base: PrimitiveType3, args: Iterable[Type3OrPlaceholder]) -> None:
args = [*args] args = [*args]
assert args, 'Must at least one argument'
super().__init__( super().__init__(
base.name base.name
@ -136,6 +144,13 @@ class AppliedType3(Type3):
self.base = base self.base = base
self.args = args self.args = args
@property
def has_placeholders(self) -> bool:
return any(
isinstance(x, PlaceholderForType)
for x in self.args
)
def __eq__(self, other: Any) -> bool: def __eq__(self, other: Any) -> bool:
if not isinstance(other, Type3): if not isinstance(other, Type3):
raise NotImplementedError raise NotImplementedError
@ -180,33 +195,33 @@ class StructType3(Type3):
def __repr__(self) -> str: def __repr__(self) -> str:
return f'StructType3(repr({self.name}), repr({self.members}))' return f'StructType3(repr({self.name}), repr({self.members}))'
none = Type3('none') none = PrimitiveType3('none')
""" """
The none type, for when functions simply don't return anything. e.g., IO(). The none type, for when functions simply don't return anything. e.g., IO().
""" """
u8 = Type3('u8') u8 = PrimitiveType3('u8')
""" """
The unsigned 8-bit integer type. The unsigned 8-bit integer type.
Operations on variables employ modular arithmetic, with modulus 2^8. Operations on variables employ modular arithmetic, with modulus 2^8.
""" """
u32 = Type3('u32') u32 = PrimitiveType3('u32')
""" """
The unsigned 32-bit integer type. The unsigned 32-bit integer type.
Operations on variables employ modular arithmetic, with modulus 2^32. Operations on variables employ modular arithmetic, with modulus 2^32.
""" """
u64 = Type3('u64') u64 = PrimitiveType3('u64')
""" """
The unsigned 64-bit integer type. The unsigned 64-bit integer type.
Operations on variables employ modular arithmetic, with modulus 2^64. Operations on variables employ modular arithmetic, with modulus 2^64.
""" """
i8 = Type3('i8') i8 = PrimitiveType3('i8')
""" """
The signed 8-bit integer type. The signed 8-bit integer type.
@ -214,7 +229,7 @@ Operations on variables employ modular arithmetic, with modulus 2^8, but
with the middel point being 0. with the middel point being 0.
""" """
i32 = Type3('i32') i32 = PrimitiveType3('i32')
""" """
The unsigned 32-bit integer type. The unsigned 32-bit integer type.
@ -222,7 +237,7 @@ Operations on variables employ modular arithmetic, with modulus 2^32, but
with the middel point being 0. with the middel point being 0.
""" """
i64 = Type3('i64') i64 = PrimitiveType3('i64')
""" """
The unsigned 64-bit integer type. The unsigned 64-bit integer type.
@ -230,22 +245,22 @@ Operations on variables employ modular arithmetic, with modulus 2^64, but
with the middel point being 0. with the middel point being 0.
""" """
f32 = Type3('f32') f32 = PrimitiveType3('f32')
""" """
A 32-bits IEEE 754 float, of 32 bits width. A 32-bits IEEE 754 float, of 32 bits width.
""" """
f64 = Type3('f64') f64 = PrimitiveType3('f64')
""" """
A 32-bits IEEE 754 float, of 64 bits width. A 32-bits IEEE 754 float, of 64 bits width.
""" """
bytes = Type3('bytes') bytes = PrimitiveType3('bytes')
""" """
This is a runtime-determined length piece of memory that can be indexed at runtime. This is a runtime-determined length piece of memory that can be indexed at runtime.
""" """
static_array = Type3('static_array') static_array = PrimitiveType3('static_array')
""" """
This is a fixed length piece of memory that can be indexed at runtime. This is a fixed length piece of memory that can be indexed at runtime.
@ -253,7 +268,7 @@ It should be applied with one argument. It has a runtime-dynamic length
of the same type repeated. of the same type repeated.
""" """
tuple = Type3('tuple') # pylint: disable=W0622 tuple = PrimitiveType3('tuple') # pylint: disable=W0622
""" """
This is a fixed length piece of memory. This is a fixed length piece of memory.