Compare commits

...

3 Commits

Author SHA1 Message Date
Johan B.W. de Vries
ac4b46bbe7 Fix: You could assign structs to each other
As long as the arguments matched at least.
2025-05-12 20:00:56 +02:00
Johan B.W. de Vries
67af569448 Cleanup CanBeSubscriptedConstraint
It was using an AST argument, and I'd rather not have those
in the typing system (except the generator).
2025-05-12 19:20:50 +02:00
Johan B.W. de Vries
df5c1911bf Cleans up imports
Rather than accessing the classes from the module, this MR
cleans up the code a bit to use the classes directly.
2025-05-12 18:51:19 +02:00
8 changed files with 175 additions and 122 deletions

View File

@ -63,7 +63,7 @@ def expression(inp: ourlang.Expression) -> str:
) + ', )' ) + ', )'
if isinstance(inp, ourlang.ConstantStruct): if isinstance(inp, ourlang.ConstantStruct):
return inp.struct_name + '(' + ', '.join( return inp.struct_type3.name + '(' + ', '.join(
expression(x) expression(x)
for x in inp.value for x in inp.value
) + ')' ) + ')'

View File

@ -8,10 +8,18 @@ from . import codestyle, ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types from .stdlib import types as stdlib_types
from .type3 import functions as type3functions from .type3.functions import TypeVariable
from .type3 import typeclasses as type3classes
from .type3 import types as type3types
from .type3.routers import NoRouteForTypeException, TypeApplicationRouter from .type3.routers import NoRouteForTypeException, TypeApplicationRouter
from .type3.typeclasses import Type3ClassMethod
from .type3.types import (
IntType3,
Type3,
TypeApplication_Struct,
TypeApplication_TypeInt,
TypeApplication_TypeStar,
TypeConstructor_StaticArray,
TypeConstructor_Tuple,
)
from .wasmgenerator import Generator as WasmGenerator from .wasmgenerator import Generator as WasmGenerator
TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before your program can be compiled' TYPE3_ASSERTION_ERROR = 'You must call phasm_type3 after calling phasm_parse before your program can be compiled'
@ -37,7 +45,7 @@ def phasm_compile(inp: ourlang.Module) -> wasm.Module:
""" """
return module(inp) return module(inp)
def type3(inp: type3types.Type3) -> wasm.WasmType: def type3(inp: Type3) -> wasm.WasmType:
""" """
Compile: type Compile: type
@ -98,18 +106,18 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) ->
""" """
assert inp.type3 is not None, TYPE3_ASSERTION_ERROR assert inp.type3 is not None, TYPE3_ASSERTION_ERROR
args: tuple[type3types.Type3, ...] args: tuple[Type3, ...]
if isinstance(inp.type3.application, type3types.TypeApplication_TypeStar): if isinstance(inp.type3.application, TypeApplication_TypeStar):
# Possibly paranoid assert. If we have a future variadic type, # Possibly paranoid assert. If we have a future variadic type,
# does it also do this tuple instantation like this? # does it also do this tuple instantation like this?
assert isinstance(inp.type3.application.constructor, type3types.TypeConstructor_Tuple) assert isinstance(inp.type3.application.constructor, TypeConstructor_Tuple)
args = inp.type3.application.arguments args = inp.type3.application.arguments
elif isinstance(inp.type3.application, type3types.TypeApplication_TypeInt): elif isinstance(inp.type3.application, TypeApplication_TypeInt):
# Possibly paranoid assert. If we have a future type of kind * -> Int -> *, # Possibly paranoid assert. If we have a future type of kind * -> Int -> *,
# does it also do this tuple instantation like this? # does it also do this tuple instantation like this?
assert isinstance(inp.type3.application.constructor, type3types.TypeConstructor_StaticArray) assert isinstance(inp.type3.application.constructor, TypeConstructor_StaticArray)
sa_type, sa_len = inp.type3.application.arguments sa_type, sa_len = inp.type3.application.arguments
@ -162,7 +170,7 @@ def expression_subscript_bytes(
def expression_subscript_static_array( def expression_subscript_static_array(
attrs: tuple[WasmGenerator, ourlang.Subscript], attrs: tuple[WasmGenerator, ourlang.Subscript],
args: tuple[type3types.Type3, type3types.IntType3], args: tuple[Type3, IntType3],
) -> None: ) -> None:
wgn, inp = attrs wgn, inp = attrs
@ -194,7 +202,7 @@ def expression_subscript_static_array(
def expression_subscript_tuple( def expression_subscript_tuple(
attrs: tuple[WasmGenerator, ourlang.Subscript], attrs: tuple[WasmGenerator, ourlang.Subscript],
args: tuple[type3types.Type3, ...], args: tuple[Type3, ...],
) -> None: ) -> None:
wgn, inp = attrs wgn, inp = attrs
@ -292,16 +300,16 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
expression(wgn, inp.left) expression(wgn, inp.left)
expression(wgn, inp.right) expression(wgn, inp.right)
type_var_map: dict[type3functions.TypeVariable, type3types.Type3] = {} type_var_map: dict[TypeVariable, Type3] = {}
for type_var, arg_expr in zip(inp.operator.signature.args, [inp.left, inp.right, inp], strict=True): for type_var, arg_expr in zip(inp.operator.signature.args, [inp.left, inp.right, inp], strict=True):
assert arg_expr.type3 is not None, TYPE3_ASSERTION_ERROR assert arg_expr.type3 is not None, TYPE3_ASSERTION_ERROR
if isinstance(type_var, type3types.Type3): if isinstance(type_var, Type3):
# Fixed type, not part of the lookup requirements # Fixed type, not part of the lookup requirements
continue continue
if isinstance(type_var, type3functions.TypeVariable): if isinstance(type_var, TypeVariable):
type_var_map[type_var] = arg_expr.type3 type_var_map[type_var] = arg_expr.type3
continue continue
@ -315,18 +323,18 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
for arg in inp.arguments: for arg in inp.arguments:
expression(wgn, arg) expression(wgn, arg)
if isinstance(inp.function, type3classes.Type3ClassMethod): if isinstance(inp.function, Type3ClassMethod):
# FIXME: Duplicate code with BinaryOp # FIXME: Duplicate code with BinaryOp
type_var_map = {} type_var_map = {}
for type_var, arg_expr in zip(inp.function.signature.args, inp.arguments + [inp], strict=True): for type_var, arg_expr in zip(inp.function.signature.args, inp.arguments + [inp], strict=True):
assert arg_expr.type3 is not None, TYPE3_ASSERTION_ERROR assert arg_expr.type3 is not None, TYPE3_ASSERTION_ERROR
if isinstance(type_var, type3types.Type3): if isinstance(type_var, Type3):
# Fixed type, not part of the lookup requirements # Fixed type, not part of the lookup requirements
continue continue
if isinstance(type_var, type3functions.TypeVariable): if isinstance(type_var, TypeVariable):
type_var_map[type_var] = arg_expr.type3 type_var_map[type_var] = arg_expr.type3
continue continue
@ -356,7 +364,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
assert inp.struct_type3 is not None, TYPE3_ASSERTION_ERROR assert inp.struct_type3 is not None, TYPE3_ASSERTION_ERROR
assert isinstance(inp.struct_type3.application, type3types.TypeApplication_Struct) assert isinstance(inp.struct_type3.application, TypeApplication_Struct)
member_type = dict(inp.struct_type3.application.arguments)[inp.member] member_type = dict(inp.struct_type3.application.arguments)[inp.member]
@ -739,7 +747,7 @@ def module(inp: ourlang.Module) -> wasm.Module:
return result return result
def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None:
assert isinstance(inp.struct_type3.application, type3types.TypeApplication_Struct) assert isinstance(inp.struct_type3.application, TypeApplication_Struct)
st_args = inp.struct_type3.application.arguments st_args = inp.struct_type3.application.arguments

View File

@ -98,21 +98,21 @@ class ConstantStruct(ConstantMemoryStored):
""" """
A Struct constant value expression within a statement A Struct constant value expression within a statement
""" """
__slots__ = ('struct_name', 'value', ) __slots__ = ('struct_type3', 'value', )
struct_name: str struct_type3: Type3
value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']] value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']]
def __init__(self, struct_name: str, value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']], data_block: 'ModuleDataBlock') -> None: def __init__(self, struct_type3: Type3, value: List[Union[ConstantPrimitive, ConstantBytes, ConstantTuple, 'ConstantStruct']], data_block: 'ModuleDataBlock') -> None:
super().__init__(data_block) super().__init__(data_block)
self.struct_name = struct_name self.struct_type3 = struct_type3
self.value = value self.value = value
def __repr__(self) -> str: def __repr__(self) -> str:
# Do not repr the whole ModuleDataBlock # Do not repr the whole ModuleDataBlock
# As this has a reference back to this constant for its data # As this has a reference back to this constant for its data
# which it needs to compile the data into the program # which it needs to compile the data into the program
return f'ConstantStruct({repr(self.struct_name)}, {repr(self.value)}, @{repr(self.data_block.address)})' return f'ConstantStruct({self.struct_type3!r}, {self.value!r}, @{self.data_block.address!r})'
class VariableReference(Expression): class VariableReference(Expression):
""" """

View File

@ -32,8 +32,8 @@ from .ourlang import (
VariableReference, VariableReference,
) )
from .prelude import PRELUDE_METHODS, PRELUDE_OPERATORS, PRELUDE_TYPES from .prelude import PRELUDE_METHODS, PRELUDE_OPERATORS, PRELUDE_TYPES
from .type3 import typeclasses as type3typeclasses from .type3.typeclasses import Type3ClassMethod
from .type3 import types as type3types from .type3.types import IntType3, Type3
def phasm_parse(source: str) -> Module: def phasm_parse(source: str) -> Module:
@ -226,7 +226,7 @@ class OurVisitor:
_not_implemented(not node.keywords, 'ClassDef.keywords') _not_implemented(not node.keywords, 'ClassDef.keywords')
_not_implemented(not node.decorator_list, 'ClassDef.decorator_list') _not_implemented(not node.decorator_list, 'ClassDef.decorator_list')
members: Dict[str, type3types.Type3] = {} members: Dict[str, Type3] = {}
for stmt in node.body: for stmt in node.body:
if not isinstance(stmt, ast.AnnAssign): if not isinstance(stmt, ast.AnnAssign):
@ -352,7 +352,7 @@ class OurVisitor:
def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression: def visit_Module_FunctionDef_expr(self, module: Module, function: Function, our_locals: OurLocals, node: ast.expr) -> Expression:
if isinstance(node, ast.BinOp): if isinstance(node, ast.BinOp):
operator: Union[str, type3typeclasses.Type3ClassMethod] operator: Union[str, Type3ClassMethod]
if isinstance(node.op, ast.Add): if isinstance(node.op, ast.Add):
operator = '+' operator = '+'
@ -471,7 +471,7 @@ class OurVisitor:
if not isinstance(node.func.ctx, ast.Load): if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
func: Union[Function, type3typeclasses.Type3ClassMethod] func: Union[Function, Type3ClassMethod]
if node.func.id in PRELUDE_METHODS: if node.func.id in PRELUDE_METHODS:
func = PRELUDE_METHODS[node.func.id] func = PRELUDE_METHODS[node.func.id]
@ -584,7 +584,8 @@ class OurVisitor:
if not isinstance(node.func.ctx, ast.Load): if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node.func, 'Must be load context') _raise_static_error(node.func, 'Must be load context')
if node.func.id not in module.struct_definitions: struct_def = module.struct_definitions.get(node.func.id)
if struct_def is None:
_raise_static_error(node.func, 'Undefined struct') _raise_static_error(node.func, 'Undefined struct')
if node.keywords: if node.keywords:
@ -600,7 +601,7 @@ class OurVisitor:
data_block = ModuleDataBlock(struct_data) data_block = ModuleDataBlock(struct_data)
module.data.blocks.append(data_block) module.data.blocks.append(data_block)
return ConstantStruct(node.func.id, struct_data, data_block) return ConstantStruct(struct_def.struct_type3, struct_data, data_block)
_not_implemented(node.kind is None, 'Constant.kind') _not_implemented(node.kind is None, 'Constant.kind')
@ -617,7 +618,7 @@ class OurVisitor:
raise NotImplementedError(f'{node.value} as constant') raise NotImplementedError(f'{node.value} as constant')
def visit_type(self, module: Module, node: ast.expr) -> type3types.Type3: def visit_type(self, module: Module, node: ast.expr) -> Type3:
if isinstance(node, ast.Constant): if isinstance(node, ast.Constant):
if node.value is None: if node.value is None:
return prelude.none return prelude.none
@ -645,7 +646,7 @@ class OurVisitor:
return prelude.static_array( return prelude.static_array(
self.visit_type(module, node.value), self.visit_type(module, node.value),
type3types.IntType3(node.slice.value), IntType3(node.slice.value),
) )
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):

View File

@ -6,9 +6,19 @@ These need to be resolved before the program can be compiled.
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from .. import ourlang, prelude from .. import ourlang, prelude
from . import placeholders, typeclasses, types from .placeholders import PlaceholderForType, Type3OrPlaceholder
from .placeholders import PlaceholderForType
from .routers import NoRouteForTypeException, TypeApplicationRouter from .routers import NoRouteForTypeException, TypeApplicationRouter
from .typeclasses import Type3Class
from .types import (
IntType3,
Type3,
TypeApplication_Nullary,
TypeApplication_Struct,
TypeApplication_TypeInt,
TypeApplication_TypeStar,
TypeConstructor_Base,
TypeConstructor_Struct,
)
class Error: class Error:
@ -34,13 +44,13 @@ class RequireTypeSubstitutes:
typing of the program, so this constraint can be updated. typing of the program, so this constraint can be updated.
""" """
SubstitutionMap = Dict[placeholders.PlaceholderForType, types.Type3] SubstitutionMap = Dict[PlaceholderForType, Type3]
NewConstraintList = List['ConstraintBase'] NewConstraintList = List['ConstraintBase']
CheckResult = Union[None, SubstitutionMap, Error, NewConstraintList, RequireTypeSubstitutes] CheckResult = Union[None, SubstitutionMap, Error, NewConstraintList, RequireTypeSubstitutes]
HumanReadableRet = Tuple[str, Dict[str, Union[str, ourlang.Expression, types.Type3, placeholders.PlaceholderForType]]] HumanReadableRet = Tuple[str, Dict[str, Union[None, int, str, ourlang.Expression, Type3, PlaceholderForType]]]
class Context: class Context:
""" """
@ -50,7 +60,7 @@ class Context:
__slots__ = ('type_class_instances_existing', ) __slots__ = ('type_class_instances_existing', )
# Constraint_TypeClassInstanceExists # Constraint_TypeClassInstanceExists
type_class_instances_existing: set[tuple[typeclasses.Type3Class, tuple[Union[types.Type3, types.TypeConstructor_Base[Any], types.TypeConstructor_Struct], ...]]] type_class_instances_existing: set[tuple[Type3Class, tuple[Union[Type3, TypeConstructor_Base[Any], TypeConstructor_Struct], ...]]]
def __init__(self) -> None: def __init__(self) -> None:
self.type_class_instances_existing = set() self.type_class_instances_existing = set()
@ -100,23 +110,23 @@ class SameTypeConstraint(ConstraintBase):
""" """
__slots__ = ('type_list', ) __slots__ = ('type_list', )
type_list: List[placeholders.Type3OrPlaceholder] type_list: List[Type3OrPlaceholder]
def __init__(self, *type_list: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None: def __init__(self, *type_list: Type3OrPlaceholder, comment: Optional[str] = None) -> None:
super().__init__(comment=comment) super().__init__(comment=comment)
assert len(type_list) > 1 assert len(type_list) > 1
self.type_list = [*type_list] self.type_list = [*type_list]
def check(self) -> CheckResult: def check(self) -> CheckResult:
known_types: List[types.Type3] = [] known_types: List[Type3] = []
phft_list = [] phft_list = []
for typ in self.type_list: for typ in self.type_list:
if isinstance(typ, types.Type3): if isinstance(typ, Type3):
known_types.append(typ) known_types.append(typ)
continue continue
if isinstance(typ, placeholders.PlaceholderForType): if isinstance(typ, PlaceholderForType):
if typ.resolve_as is not None: if typ.resolve_as is not None:
known_types.append(typ.resolve_as) known_types.append(typ.resolve_as)
else: else:
@ -133,7 +143,7 @@ class SameTypeConstraint(ConstraintBase):
if ktyp != first_type: if ktyp != first_type:
return Error(f'{ktyp:s} must be {first_type:s} instead', comment=self.comment) return Error(f'{ktyp:s} must be {first_type:s} instead', comment=self.comment)
if not placeholders: if not phft_list:
return None return None
for phft in phft_list: for phft in phft_list:
@ -177,10 +187,10 @@ class SameTypeArgumentConstraint(ConstraintBase):
tc_typ = self.tc_var.resolve_as tc_typ = self.tc_var.resolve_as
arg_typ = self.arg_var.resolve_as arg_typ = self.arg_var.resolve_as
if isinstance(tc_typ.application, types.TypeApplication_Nullary): if isinstance(tc_typ.application, TypeApplication_Nullary):
return Error(f'{tc_typ:s} must be a constructed type instead') return Error(f'{tc_typ:s} must be a constructed type instead')
if isinstance(tc_typ.application, types.TypeApplication_TypeStar): if isinstance(tc_typ.application, TypeApplication_TypeStar):
# Sure, it's a constructed type. But it's like a struct, # Sure, it's a constructed type. But it's like a struct,
# though without the way to implement type classes # though without the way to implement type classes
# Presumably, doing a naked `foo :: t a -> a` # Presumably, doing a naked `foo :: t a -> a`
@ -190,7 +200,7 @@ class SameTypeArgumentConstraint(ConstraintBase):
# FIXME: This feels sketchy. Shouldn't the type variable # FIXME: This feels sketchy. Shouldn't the type variable
# have the exact same number as arguments? # have the exact same number as arguments?
if isinstance(tc_typ.application, types.TypeApplication_TypeInt): if isinstance(tc_typ.application, TypeApplication_TypeInt):
if tc_typ.application.arguments[0] == arg_typ: if tc_typ.application.arguments[0] == arg_typ:
return None return None
@ -201,16 +211,16 @@ class SameTypeArgumentConstraint(ConstraintBase):
class TupleMatchConstraint(ConstraintBase): class TupleMatchConstraint(ConstraintBase):
__slots__ = ('exp_type', 'args', ) __slots__ = ('exp_type', 'args', )
exp_type: placeholders.Type3OrPlaceholder exp_type: Type3OrPlaceholder
args: list[placeholders.Type3OrPlaceholder] args: list[Type3OrPlaceholder]
def __init__(self, exp_type: placeholders.Type3OrPlaceholder, args: Iterable[placeholders.Type3OrPlaceholder], comment: str): def __init__(self, exp_type: Type3OrPlaceholder, args: Iterable[Type3OrPlaceholder], comment: str):
super().__init__(comment=comment) super().__init__(comment=comment)
self.exp_type = exp_type self.exp_type = exp_type
self.args = list(args) self.args = list(args)
def _generate_static_array(self, sa_args: tuple[types.Type3, types.IntType3]) -> CheckResult: def _generate_static_array(self, sa_args: tuple[Type3, IntType3]) -> CheckResult:
sa_type, sa_len = sa_args sa_type, sa_len = sa_args
if sa_len.value != len(self.args): if sa_len.value != len(self.args):
@ -221,7 +231,7 @@ class TupleMatchConstraint(ConstraintBase):
for arg in self.args for arg in self.args
] ]
def _generate_tuple(self, tp_args: tuple[types.Type3, ...]) -> CheckResult: def _generate_tuple(self, tp_args: tuple[Type3, ...]) -> CheckResult:
if len(tp_args) != len(self.args): if len(tp_args) != len(self.args):
return Error('Mismatch between applied types argument count', comment=self.comment) return Error('Mismatch between applied types argument count', comment=self.comment)
@ -236,7 +246,7 @@ class TupleMatchConstraint(ConstraintBase):
def check(self) -> CheckResult: def check(self) -> CheckResult:
exp_type = self.exp_type exp_type = self.exp_type
if isinstance(exp_type, placeholders.PlaceholderForType): if isinstance(exp_type, PlaceholderForType):
if exp_type.resolve_as is None: if exp_type.resolve_as is None:
return RequireTypeSubstitutes() return RequireTypeSubstitutes()
@ -254,34 +264,34 @@ class MustImplementTypeClassConstraint(ConstraintBase):
__slots__ = ('context', 'type_class3', 'types', ) __slots__ = ('context', 'type_class3', 'types', )
context: Context context: Context
type_class3: Union[str, typeclasses.Type3Class] type_class3: Union[str, Type3Class]
types: list[placeholders.Type3OrPlaceholder] types: list[Type3OrPlaceholder]
DATA = { DATA = {
'bytes': {'Foldable'}, 'bytes': {'Foldable'},
} }
def __init__(self, context: Context, type_class3: Union[str, typeclasses.Type3Class], types: list[placeholders.Type3OrPlaceholder], comment: Optional[str] = None) -> None: def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
super().__init__(comment=comment) super().__init__(comment=comment)
self.context = context self.context = context
self.type_class3 = type_class3 self.type_class3 = type_class3
self.types = types self.types = typ_list
def check(self) -> CheckResult: def check(self) -> CheckResult:
typ_list: list[types.Type3 | types.TypeConstructor_Base[Any] | types.TypeConstructor_Struct] = [] typ_list: list[Type3 | TypeConstructor_Base[Any] | TypeConstructor_Struct] = []
for typ in self.types: for typ in self.types:
if isinstance(typ, placeholders.PlaceholderForType) and typ.resolve_as is not None: if isinstance(typ, PlaceholderForType) and typ.resolve_as is not None:
typ = typ.resolve_as typ = typ.resolve_as
if isinstance(typ, placeholders.PlaceholderForType): if isinstance(typ, PlaceholderForType):
return RequireTypeSubstitutes() return RequireTypeSubstitutes()
if isinstance(typ.application, (types.TypeApplication_Nullary, types.TypeApplication_Struct, )): if isinstance(typ.application, (TypeApplication_Nullary, TypeApplication_Struct, )):
typ_list.append(typ) typ_list.append(typ)
continue continue
if isinstance(typ.application, (types.TypeApplication_TypeInt, types.TypeApplication_TypeStar)): if isinstance(typ.application, (TypeApplication_TypeInt, TypeApplication_TypeStar)):
typ_list.append(typ.application.constructor) typ_list.append(typ.application.constructor)
continue continue
@ -289,7 +299,7 @@ class MustImplementTypeClassConstraint(ConstraintBase):
assert len(typ_list) == len(self.types) assert len(typ_list) == len(self.types)
if isinstance(self.type_class3, typeclasses.Type3Class): if isinstance(self.type_class3, Type3Class):
key = (self.type_class3, tuple(typ_list), ) key = (self.type_class3, tuple(typ_list), )
if key in self.context.type_class_instances_existing: if key in self.context.type_class_instances_existing:
return None return None
@ -324,12 +334,12 @@ class LiteralFitsConstraint(ConstraintBase):
""" """
__slots__ = ('type3', 'literal', ) __slots__ = ('type3', 'literal', )
type3: placeholders.Type3OrPlaceholder type3: Type3OrPlaceholder
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantBytes, ourlang.ConstantTuple, ourlang.ConstantStruct] literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantBytes, ourlang.ConstantTuple, ourlang.ConstantStruct]
def __init__( def __init__(
self, self,
type3: placeholders.Type3OrPlaceholder, type3: Type3OrPlaceholder,
literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantBytes, ourlang.ConstantTuple, ourlang.ConstantStruct], literal: Union[ourlang.ConstantPrimitive, ourlang.ConstantBytes, ourlang.ConstantTuple, ourlang.ConstantStruct],
comment: Optional[str] = None, comment: Optional[str] = None,
) -> None: ) -> None:
@ -338,7 +348,7 @@ class LiteralFitsConstraint(ConstraintBase):
self.type3 = type3 self.type3 = type3
self.literal = literal self.literal = literal
def _generate_static_array(self, sa_args: tuple[types.Type3, types.IntType3]) -> CheckResult: def _generate_static_array(self, sa_args: tuple[Type3, IntType3]) -> CheckResult:
if not isinstance(self.literal, ourlang.ConstantTuple): if not isinstance(self.literal, ourlang.ConstantTuple):
return Error('Must be tuple', comment=self.comment) return Error('Must be tuple', comment=self.comment)
@ -364,7 +374,7 @@ class LiteralFitsConstraint(ConstraintBase):
return res return res
def _generate_struct(self, st_args: tuple[tuple[str, types.Type3], ...]) -> CheckResult: def _generate_struct(self, st_args: tuple[tuple[str, Type3], ...]) -> CheckResult:
if not isinstance(self.literal, ourlang.ConstantStruct): if not isinstance(self.literal, ourlang.ConstantStruct):
return Error('Must be struct') return Error('Must be struct')
@ -382,13 +392,19 @@ class LiteralFitsConstraint(ConstraintBase):
# gets updated when we figure out the type of the # gets updated when we figure out the type of the
# expression the literal is used in # expression the literal is used in
res.extend( res.extend(
SameTypeConstraint(x_t, PlaceholderForType([y]), comment=f'{self.literal.struct_name}.{x_n}') SameTypeConstraint(x_t, PlaceholderForType([y]), comment=f'{self.literal.struct_type3.name}.{x_n}')
for (x_n, x_t, ), y in zip(st_args, self.literal.value, strict=True) for (x_n, x_t, ), y in zip(st_args, self.literal.value, strict=True)
) )
res.append(SameTypeConstraint(
self.literal.struct_type3,
self.type3,
comment='Struct types must match',
))
return res return res
def _generate_tuple(self, tp_args: tuple[types.Type3, ...]) -> CheckResult: def _generate_tuple(self, tp_args: tuple[Type3, ...]) -> CheckResult:
if not isinstance(self.literal, ourlang.ConstantTuple): if not isinstance(self.literal, ourlang.ConstantTuple):
return Error('Must be tuple', comment=self.comment) return Error('Must be tuple', comment=self.comment)
@ -432,7 +448,7 @@ class LiteralFitsConstraint(ConstraintBase):
'f64': None, 'f64': None,
} }
if isinstance(self.type3, placeholders.PlaceholderForType): if isinstance(self.type3, PlaceholderForType):
if self.type3.resolve_as is None: if self.type3.resolve_as is None:
return RequireTypeSubstitutes() return RequireTypeSubstitutes()
@ -490,65 +506,59 @@ class CanBeSubscriptedConstraint(ConstraintBase):
""" """
A value that is subscipted, i.e. a[0] (tuple) or a[b] (static array) A value that is subscipted, i.e. a[0] (tuple) or a[b] (static array)
""" """
__slots__ = ('ret_type3', 'type3', 'index', 'index_phft', ) __slots__ = ('ret_type3', 'type3', 'index_type3', 'index_const', )
ret_type3: placeholders.Type3OrPlaceholder ret_type3: PlaceholderForType
type3: placeholders.Type3OrPlaceholder type3: PlaceholderForType
index: ourlang.Expression index_type3: PlaceholderForType
index_phft: placeholders.Type3OrPlaceholder index_const: int | None
def __init__( def __init__(
self, self,
ret_type3: placeholders.PlaceholderForType, ret_type3: PlaceholderForType,
type3: placeholders.PlaceholderForType, type3: PlaceholderForType,
index: ourlang.Expression, index_type3: PlaceholderForType,
index_phft: placeholders.PlaceholderForType, index_const: int | None,
comment: Optional[str] = None, comment: Optional[str] = None,
) -> None: ) -> None:
super().__init__(comment=comment) super().__init__(comment=comment)
self.ret_type3 = ret_type3 self.ret_type3 = ret_type3
self.type3 = type3 self.type3 = type3
self.index = index self.index_type3 = index_type3
self.index_phft = index_phft self.index_const = index_const
def _generate_bytes(self) -> CheckResult: def _generate_bytes(self) -> CheckResult:
return [ return [
SameTypeConstraint(prelude.u32, self.index_phft, comment='([]) :: bytes -> u32 -> u8'), SameTypeConstraint(prelude.u32, self.index_type3, comment='([]) :: bytes -> u32 -> u8'),
SameTypeConstraint(prelude.u8, self.ret_type3, comment='([]) :: bytes -> u32 -> u8'), SameTypeConstraint(prelude.u8, self.ret_type3, comment='([]) :: bytes -> u32 -> u8'),
] ]
def _generate_static_array(self, sa_args: tuple[types.Type3, types.IntType3]) -> CheckResult: def _generate_static_array(self, sa_args: tuple[Type3, IntType3]) -> CheckResult:
sa_type, sa_len = sa_args sa_type, sa_len = sa_args
if isinstance(self.index, ourlang.ConstantPrimitive): if self.index_const is not None and (self.index_const < 0 or sa_len.value <= self.index_const):
assert isinstance(self.index.value, int) return Error('Tuple index out of range')
if self.index.value < 0 or sa_len.value <= self.index.value:
return Error('Tuple index out of range')
return [ return [
SameTypeConstraint(prelude.u32, self.index_phft, comment='([]) :: Subscriptable a => a b -> u32 -> b'), SameTypeConstraint(prelude.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'),
SameTypeConstraint(sa_type, self.ret_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'), SameTypeConstraint(sa_type, self.ret_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'),
] ]
def _generate_tuple(self, tp_args: tuple[types.Type3, ...]) -> CheckResult: def _generate_tuple(self, tp_args: tuple[Type3, ...]) -> CheckResult:
# We special case tuples to allow for ease of use to the programmer # We special case tuples to allow for ease of use to the programmer
# e.g. rather than having to do `fst a` and `snd a` and only have to-sized tuples # e.g. rather than having to do `fst a` and `snd a` and only have to-sized tuples
# we use a[0] and a[1] and allow for a[2] and on. # we use a[0] and a[1] and allow for a[2] and on.
if not isinstance(self.index, ourlang.ConstantPrimitive): if self.index_const is None:
return Error('Must index with literal')
if not isinstance(self.index.value, int):
return Error('Must index with integer literal') return Error('Must index with integer literal')
if self.index.value < 0 or len(tp_args) <= self.index.value: if self.index_const < 0 or len(tp_args) <= self.index_const:
return Error('Tuple index out of range') return Error('Tuple index out of range')
return [ return [
SameTypeConstraint(prelude.u32, self.index_phft, comment=f'Tuple subscript index {self.index.value}'), SameTypeConstraint(prelude.u32, self.index_type3, comment='([]) :: Subscriptable a => a b -> u32 -> b'),
SameTypeConstraint(tp_args[self.index.value], self.ret_type3, comment=f'Tuple subscript index {self.index.value}'), SameTypeConstraint(tp_args[self.index_const], self.ret_type3, comment=f'Tuple subscript index {self.index_const}'),
] ]
GENERATE_ROUTER = TypeApplicationRouter['CanBeSubscriptedConstraint', CheckResult]() GENERATE_ROUTER = TypeApplicationRouter['CanBeSubscriptedConstraint', CheckResult]()
@ -557,12 +567,10 @@ class CanBeSubscriptedConstraint(ConstraintBase):
GENERATE_ROUTER.add(prelude.tuple_, _generate_tuple) GENERATE_ROUTER.add(prelude.tuple_, _generate_tuple)
def check(self) -> CheckResult: def check(self) -> CheckResult:
exp_type = self.type3 if self.type3.resolve_as is None:
if isinstance(exp_type, placeholders.PlaceholderForType): return RequireTypeSubstitutes()
if exp_type.resolve_as is None:
return RequireTypeSubstitutes()
exp_type = exp_type.resolve_as exp_type = self.type3.resolve_as
try: try:
return self.__class__.GENERATE_ROUTER(self, exp_type) return self.__class__.GENERATE_ROUTER(self, exp_type)
@ -574,9 +582,9 @@ class CanBeSubscriptedConstraint(ConstraintBase):
'{type3}[{index}]', '{type3}[{index}]',
{ {
'type3': self.type3, 'type3': self.type3,
'index': self.index, 'index': self.index_type3 if self.index_const is None else self.index_const,
}, },
) )
def __repr__(self) -> str: def __repr__(self) -> str:
return f'CanBeSubscriptedConstraint({repr(self.type3)}, {repr(self.index)}, comment={repr(self.comment)})' return f'CanBeSubscriptedConstraint({self.ret_type3!r}, {self.type3!r}, {self.index_type3!r}, {self.index_const!r}, comment={repr(self.comment)})'

View File

@ -6,10 +6,6 @@ The constraints solver can then try to resolve all constraints.
from typing import Generator, List from typing import Generator, List
from .. import ourlang, prelude from .. import ourlang, prelude
from . import functions as functions
from . import placeholders as placeholders
from . import typeclasses as typeclasses
from . import types as type3types
from .constraints import ( from .constraints import (
CanBeSubscriptedConstraint, CanBeSubscriptedConstraint,
ConstraintBase, ConstraintBase,
@ -20,7 +16,14 @@ from .constraints import (
SameTypeConstraint, SameTypeConstraint,
TupleMatchConstraint, TupleMatchConstraint,
) )
from .functions import (
Constraint_TypeClassInstanceExists,
FunctionSignature,
TypeVariable,
TypeVariableApplication_Unary,
)
from .placeholders import PlaceholderForType from .placeholders import PlaceholderForType
from .types import Type3, TypeApplication_Struct
ConstraintGenerator = Generator[ConstraintBase, None, None] ConstraintGenerator = Generator[ConstraintBase, None, None]
@ -30,7 +33,7 @@ def phasm_type3_generate_constraints(inp: ourlang.Module) -> List[ConstraintBase
return [*module(ctx, inp)] return [*module(ctx, inp)]
def constant(ctx: Context, inp: ourlang.Constant, phft: placeholders.PlaceholderForType) -> ConstraintGenerator: def constant(ctx: Context, inp: ourlang.Constant, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantBytes, ourlang.ConstantTuple, ourlang.ConstantStruct)): if isinstance(inp, (ourlang.ConstantPrimitive, ourlang.ConstantBytes, ourlang.ConstantTuple, ourlang.ConstantStruct)):
yield LiteralFitsConstraint( yield LiteralFitsConstraint(
phft, inp, phft, inp,
@ -63,7 +66,7 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Plac
def _expression_function_call( def _expression_function_call(
ctx: Context, ctx: Context,
func_name: str, func_name: str,
signature: functions.FunctionSignature, signature: FunctionSignature,
arguments: list[ourlang.Expression], arguments: list[ourlang.Expression],
return_expr: ourlang.Expression, return_expr: ourlang.Expression,
return_phft: PlaceholderForType, return_phft: PlaceholderForType,
@ -92,13 +95,13 @@ def _expression_function_call(
# placeholder here. These don't need to update anything once # placeholder here. These don't need to update anything once
# subsituted - that's done by arg_placeholders. # subsituted - that's done by arg_placeholders.
type_var_map = { type_var_map = {
x: placeholders.PlaceholderForType([]) x: PlaceholderForType([])
for x in signature.args for x in signature.args
if isinstance(x, functions.TypeVariable) if isinstance(x, TypeVariable)
} }
for constraint in signature.context.constraints: for constraint in signature.context.constraints:
if isinstance(constraint, functions.Constraint_TypeClassInstanceExists): if isinstance(constraint, Constraint_TypeClassInstanceExists):
yield MustImplementTypeClassConstraint( yield MustImplementTypeClassConstraint(
ctx, ctx,
constraint.type_class3, constraint.type_class3,
@ -113,7 +116,7 @@ def _expression_function_call(
# That is, given `foo :: t a -> a` we need to ensure # That is, given `foo :: t a -> a` we need to ensure
# that both a's are the same. # that both a's are the same.
for sig_arg in signature.args: for sig_arg in signature.args:
if isinstance(sig_arg, type3types.Type3): if isinstance(sig_arg, Type3):
# Not a type variable at all # Not a type variable at all
continue continue
@ -121,7 +124,7 @@ def _expression_function_call(
# Not a type variable for a type constructor # Not a type variable for a type constructor
continue continue
if not isinstance(sig_arg.application, functions.TypeVariableApplication_Unary): if not isinstance(sig_arg.application, TypeVariableApplication_Unary):
raise NotImplementedError(sig_arg.application) raise NotImplementedError(sig_arg.application)
assert sig_arg.application.arguments in type_var_map # When does this happen? assert sig_arg.application.arguments in type_var_map # When does this happen?
@ -139,18 +142,18 @@ def _expression_function_call(
else: else:
comment = f'The type of the value passed to argument {arg_no} of function {func_name} should match the type of that argument' comment = f'The type of the value passed to argument {arg_no} of function {func_name} should match the type of that argument'
if isinstance(sig_part, functions.TypeVariable): if isinstance(sig_part, TypeVariable):
yield SameTypeConstraint(type_var_map[sig_part], arg_placeholders[arg_expr], comment=comment) yield SameTypeConstraint(type_var_map[sig_part], arg_placeholders[arg_expr], comment=comment)
continue continue
if isinstance(sig_part, type3types.Type3): if isinstance(sig_part, Type3):
yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment) yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment)
continue continue
raise NotImplementedError(sig_part) raise NotImplementedError(sig_part)
return return
def expression(ctx: Context, inp: ourlang.Expression, phft: placeholders.PlaceholderForType) -> ConstraintGenerator: def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp, ourlang.Constant): if isinstance(inp, ourlang.Constant):
yield from constant(ctx, inp, phft) yield from constant(ctx, inp, phft)
return return
@ -190,11 +193,14 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: placeholders.Placeho
yield from expression(ctx, inp.varref, varref_phft) yield from expression(ctx, inp.varref, varref_phft)
yield from expression(ctx, inp.index, index_phft) yield from expression(ctx, inp.index, index_phft)
yield CanBeSubscriptedConstraint(phft, varref_phft, inp.index, index_phft) if isinstance(inp.index, ourlang.ConstantPrimitive) and isinstance(inp.index.value, int):
yield CanBeSubscriptedConstraint(phft, varref_phft, index_phft, inp.index.value)
else:
yield CanBeSubscriptedConstraint(phft, varref_phft, index_phft, None)
return return
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
assert isinstance(inp.struct_type3.application, type3types.TypeApplication_Struct) # FIXME: See test_struct.py::test_struct_not_accessible assert isinstance(inp.struct_type3.application, TypeApplication_Struct) # FIXME: See test_struct.py::test_struct_not_accessible
mem_typ = dict(inp.struct_type3.application.arguments)[inp.member] mem_typ = dict(inp.struct_type3.application.arguments)[inp.member]

View File

@ -64,6 +64,36 @@ def helper(shape1: Rectangle, shape2: Rectangle) -> i32:
assert 545 == result.returned_value assert 545 == result.returned_value
@pytest.mark.integration_test
def test_type_mismatch_struct_call_root():
code_py = """
class CheckedValueBlue:
value: i32
class CheckedValueRed:
value: i32
CONST: CheckedValueBlue = CheckedValueRed(1)
"""
with pytest.raises(Type3Exception, match='CheckedValueBlue must be CheckedValueRed instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_type_mismatch_struct_call_nested():
code_py = """
class CheckedValueBlue:
value: i32
class CheckedValueRed:
value: i32
CONST: (CheckedValueBlue, u32, ) = (CheckedValueRed(1), 16, )
"""
with pytest.raises(Type3Exception, match='CheckedValueBlue must be CheckedValueRed instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64']) @pytest.mark.parametrize('type_', ['i32', 'i64', 'f32', 'f64'])
def test_type_mismatch_struct_member(type_): def test_type_mismatch_struct_member(type_):

View File

@ -64,7 +64,7 @@ def testEntry(x: (u8, u32, u64), y: u8) -> u64:
return x[y] return x[y]
""" """
with pytest.raises(Type3Exception, match='Must index with literal'): with pytest.raises(Type3Exception, match='Must index with integer literal'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test