Typeclasses

This commit is contained in:
Johan B.W. de Vries 2025-08-02 19:58:07 +02:00
parent 50066c3638
commit cb79918099
14 changed files with 320 additions and 189 deletions

View File

@ -16,7 +16,7 @@ from ..type3.routers import (
TypeClassArgsRouter, TypeClassArgsRouter,
TypeVariableLookup, TypeVariableLookup,
) )
from ..type3.typeclasses import Type3Class, Type3ClassMethod from ..typeclass import TypeClass
from ..type3.types import ( from ..type3.types import (
IntType3, IntType3,
Type3, Type3,
@ -33,6 +33,7 @@ from ..type5 import typeexpr as type5typeexpr
from ..wasm import WasmType, WasmTypeInt32, WasmTypeNone from ..wasm import WasmType, WasmTypeInt32, WasmTypeNone
from . import builtins from . import builtins
from .typerouter import TypeAllocSize, TypeName from .typerouter import TypeAllocSize, TypeName
from .typevariablerouter import TypeVariableRouter
TypeInfo = NamedTuple('TypeInfo', [ TypeInfo = NamedTuple('TypeInfo', [
# Name of the type # Name of the type
@ -179,27 +180,27 @@ class BuildBase[G]:
Types that are available without explicit import. Types that are available without explicit import.
""" """
type_classes: dict[str, Type3Class] type_classes: dict[str, TypeClass]
""" """
Type classes that are available without explicit import. Type classes that are available without explicit import.
""" """
type_class_instances: set[tuple[Type3Class, tuple[Type3 | TypeConstructor_Base[Any], ...]]] type_class_instances: dict[str, set[tuple[type5typeexpr.TypeExpr, ...]]]
""" """
Type class instances that are available without explicit import. Type class instances that are available without explicit import.
""" """
type_class_instance_methods: dict[Type3ClassMethod, TypeClassArgsRouter[G, None]] # type_class_instance_methods: dict[tuple[TypeClass, str], TypeClassArgsRouter[G, None]]
""" """
Methods (and operators) for type class instances that are available without explicit import. Methods (and operators) for type class instances that are available without explicit import.
""" """
methods: dict[str, Type3ClassMethod] methods: dict[str, tuple[type5typeexpr.TypeExpr, TypeVariableRouter[G]]]
""" """
Methods that are available without explicit import. Methods that are available without explicit import.
""" """
operators: dict[str, Type3ClassMethod] # operators: dict[str, Type3ClassMethod]
""" """
Operators that are available without explicit import. Operators that are available without explicit import.
""" """
@ -268,7 +269,7 @@ class BuildBase[G]:
'u32': self.u32_type5, 'u32': self.u32_type5,
} }
self.type_classes = {} self.type_classes = {}
self.type_class_instances = set() self.type_class_instances = {}
self.type_class_instance_methods = {} self.type_class_instance_methods = {}
self.methods = {} self.methods = {}
self.operators = {} self.operators = {}
@ -282,83 +283,112 @@ class BuildBase[G]:
self.type5_alloc_size_root = TypeAllocSize(self, is_member=False) self.type5_alloc_size_root = TypeAllocSize(self, is_member=False)
self.type5_alloc_size_member = TypeAllocSize(self, is_member=True) self.type5_alloc_size_member = TypeAllocSize(self, is_member=True)
def register_type_class(self, cls: Type3Class) -> None: # def register_type_class(self, cls: Type3Class) -> None:
""" # """
Register that the given type class exists # Register that the given type class exists
""" # """
old_len_methods = len(self.methods) # old_len_methods = len(self.methods)
old_len_operators = len(self.operators) # old_len_operators = len(self.operators)
# self.type_classes[cls.name] = cls
# self.methods.update(cls.methods)
# self.operators.update(cls.operators)
# assert len(self.methods) == old_len_methods + len(cls.methods), 'Duplicated method detected'
# assert len(self.operators) == old_len_operators + len(cls.operators), 'Duplicated operator detected'
# def instance_type_class(
# self,
# cls: Type3Class,
# *typ: Type3 | TypeConstructor_Base[Any],
# methods: dict[str, Callable[[G, TypeVariableLookup], None]] = {},
# operators: dict[str, Callable[[G, TypeVariableLookup], None]] = {},
# ) -> None:
# """
# Registered the given type class and its implementation
# """
# assert len(cls.args) == len(typ)
# for incls in cls.inherited_classes:
# if (incls, tuple(typ), ) not in self.type_class_instances:
# warn(MissingImplementationWarning(
# incls.name + ' ' + ' '.join(x.name for x in typ) + ' - required for ' + cls.name
# ))
# # First just register the type
# self.type_class_instances.add((cls, tuple(typ), ))
# # Then make the implementation findable
# # We route based on the type class arguments.
# tv_map: dict[TypeVariable, Type3] = {}
# tc_map: dict[TypeConstructorVariable, TypeConstructor_Base[Any]] = {}
# for arg_tv, arg_tp in zip(cls.args, typ, strict=True):
# if isinstance(arg_tv, TypeVariable):
# assert isinstance(arg_tp, Type3)
# tv_map[arg_tv] = arg_tp
# elif isinstance(arg_tv, TypeConstructorVariable):
# assert isinstance(arg_tp, TypeConstructor_Base)
# tc_map[arg_tv] = arg_tp
# else:
# raise NotImplementedError(arg_tv, arg_tp)
# for method_name, method in cls.methods.items():
# router = self.type_class_instance_methods.get(method)
# if router is None:
# router = TypeClassArgsRouter[G, None](cls.args)
# self.type_class_instance_methods[method] = router
# try:
# generator = methods[method_name]
# except KeyError:
# warn(MissingImplementationWarning(str(method), cls.name + ' ' + ' '.join(x.name for x in typ)))
# continue
# router.add(tv_map, tc_map, generator)
# for operator_name, operator in cls.operators.items():
# router = self.type_class_instance_methods.get(operator)
# if router is None:
# router = TypeClassArgsRouter[G, None](cls.args)
# self.type_class_instance_methods[operator] = router
# try:
# generator = operators[operator_name]
# except KeyError:
# warn(MissingImplementationWarning(str(operator), cls.name + ' ' + ' '.join(x.name for x in typ)))
# continue
# router.add(tv_map, tc_map, generator)
def register_type_class(self, cls: TypeClass) -> None:
assert cls.name not in self.type_classes, 'Duplicate typeclass name'
self.type_classes[cls.name] = cls self.type_classes[cls.name] = cls
self.methods.update(cls.methods) self.type_class_instances[cls.name] = set()
self.operators.update(cls.operators)
assert len(self.methods) == old_len_methods + len(cls.methods), 'Duplicated method detected' def register_type_class_method(
assert len(self.operators) == old_len_operators + len(cls.operators), 'Duplicated operator detected'
def instance_type_class(
self, self,
cls: Type3Class, cls: TypeClass,
*typ: Type3 | TypeConstructor_Base[Any], name: str,
methods: dict[str, Callable[[G, TypeVariableLookup], None]] = {}, type: type5typeexpr.TypeExpr,
operators: dict[str, Callable[[G, TypeVariableLookup], None]] = {},
) -> None: ) -> None:
""" assert name not in self.methods, 'Duplicate typeclass method name'
Registered the given type class and its implementation
"""
assert len(cls.args) == len(typ)
for incls in cls.inherited_classes: self.methods[name] = (type, TypeVariableRouter(), )
if (incls, tuple(typ), ) not in self.type_class_instances:
warn(MissingImplementationWarning(
incls.name + ' ' + ' '.join(x.name for x in typ) + ' - required for ' + cls.name
))
# First just register the type def register_type_class_instance(
self.type_class_instances.add((cls, tuple(typ), )) self,
cls: TypeClass,
*args: type5typeexpr.TypeExpr,
methods: dict[str, Callable[[G, Any], None]],
) -> None:
self.type_class_instances[cls.name].add(tuple(args))
# Then make the implementation findable assert len(cls.variables) == len(args)
# We route based on the type class arguments.
tv_map: dict[TypeVariable, Type3] = {}
tc_map: dict[TypeConstructorVariable, TypeConstructor_Base[Any]] = {}
for arg_tv, arg_tp in zip(cls.args, typ, strict=True):
if isinstance(arg_tv, TypeVariable):
assert isinstance(arg_tp, Type3)
tv_map[arg_tv] = arg_tp
elif isinstance(arg_tv, TypeConstructorVariable):
assert isinstance(arg_tp, TypeConstructor_Base)
tc_map[arg_tv] = arg_tp
else:
raise NotImplementedError(arg_tv, arg_tp)
for method_name, method in cls.methods.items():
router = self.type_class_instance_methods.get(method)
if router is None:
router = TypeClassArgsRouter[G, None](cls.args)
self.type_class_instance_methods[method] = router
try:
generator = methods[method_name]
except KeyError:
warn(MissingImplementationWarning(str(method), cls.name + ' ' + ' '.join(x.name for x in typ)))
continue
router.add(tv_map, tc_map, generator)
for operator_name, operator in cls.operators.items():
router = self.type_class_instance_methods.get(operator)
if router is None:
router = TypeClassArgsRouter[G, None](cls.args)
self.type_class_instance_methods[operator] = router
try:
generator = operators[operator_name]
except KeyError:
warn(MissingImplementationWarning(str(operator), cls.name + ' ' + ' '.join(x.name for x in typ)))
continue
router.add(tv_map, tc_map, generator)
for mtd_nam, mtd_imp in methods.items():
_, mtd_rtr = self.methods[mtd_nam]
mtd_rtr.register(cls.variables, args, mtd_imp)
def calculate_alloc_size_static_array(self, args: tuple[Type3, IntType3]) -> int: def calculate_alloc_size_static_array(self, args: tuple[Type3, IntType3]) -> int:
""" """

View File

@ -98,14 +98,15 @@ class BuildDefault(BuildBase[Generator]):
}) })
tc_list = [ tc_list = [
bits, floating,
eq, ord, # bits,
extendable, promotable, # eq, ord,
convertable, reinterpretable, # extendable, promotable,
natnum, intnum, fractional, floating, # convertable, reinterpretable,
integral, # natnum, intnum, fractional, floating,
foldable, subscriptable, # integral,
sized, # foldable, subscriptable,
# sized,
] ]
for tc in tc_list: for tc in tc_list:

View File

@ -3,24 +3,30 @@ The Floating type class is defined for Real numbers.
""" """
from typing import Any from typing import Any
from ...type3.functions import make_typevar from ...type5.kindexpr import Star
from ...type5.typeexpr import TypeVariable
from ...type3.routers import TypeVariableLookup from ...type3.routers import TypeVariableLookup
from ...type3.typeclasses import Type3Class from ...typeclass import TypeClass
from ...wasmgenerator import Generator as WasmGenerator from ...wasmgenerator import Generator as WasmGenerator
from ..base import BuildBase from ..base import BuildBase
def load(build: BuildBase[Any]) -> None: def load(build: BuildBase[Any]) -> None:
a = make_typevar('a') a = TypeVariable(kind=Star(), name='a')
Fractional = build.type_classes['Fractional']
Floating = Type3Class('Floating', (a, ), methods={
'sqrt': [a, a],
}, operators={}, inherited_classes=[Fractional])
# FIXME: Do we want to expose copysign?
Floating = TypeClass('Floating', [a])
build.register_type_class(Floating) build.register_type_class(Floating)
build.register_type_class_method(Floating, 'sqrt', build.type5_make_function([a, a]))
# a = make_typevar('a')
# # Fractional = build.type_classes['Fractional'] # TODO
# Floating = Type3Class('Floating', (a, ), methods={
# 'sqrt': [a, a],
# }, operators={}, inherited_classes=[Fractional])
# # FIXME: Do we want to expose copysign?
def wasm_f32_sqrt(g: WasmGenerator, tv_map: TypeVariableLookup) -> None: def wasm_f32_sqrt(g: WasmGenerator, tv_map: TypeVariableLookup) -> None:
del tv_map del tv_map
g.add_statement('f32.sqrt') g.add_statement('f32.sqrt')
@ -32,9 +38,9 @@ def wasm_f64_sqrt(g: WasmGenerator, tv_map: TypeVariableLookup) -> None:
def wasm(build: BuildBase[WasmGenerator]) -> None: def wasm(build: BuildBase[WasmGenerator]) -> None:
Floating = build.type_classes['Floating'] Floating = build.type_classes['Floating']
build.instance_type_class(Floating, build.types['f32'], methods={ build.register_type_class_instance(Floating, build.type5s['f32'], methods={
'sqrt': wasm_f32_sqrt, 'sqrt': wasm_f32_sqrt,
}) })
build.instance_type_class(Floating, build.types['f64'], methods={ build.register_type_class_instance(Floating, build.type5s['f64'], methods={
'sqrt': wasm_f64_sqrt, 'sqrt': wasm_f64_sqrt,
}) })

View File

@ -0,0 +1,41 @@
from typing import Any, Callable, Iterable, TypeAlias
from ..type5 import typeexpr as type5typeexpr
class TypeVariableRouter[G]:
__slots__ = ('data', )
data: dict[
tuple[tuple[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr], ...],
Callable[[G, dict[type5typeexpr.TypeVariable, type5typeexpr.TypeExpr]], None],
]
def __init__(self) -> None:
self.data = {}
def register(
self,
variables: Iterable[type5typeexpr.TypeVariable],
types: Iterable[type5typeexpr.TypeExpr],
implementation: Callable[[G, Any], None],
) -> None:
variables = list(variables)
types = list(types)
assert len(variables) == len(set(variables))
key = tuple(sorted(tuple(zip(variables, types))))
print('key', key)
self.data[key] = implementation
def __call__(
self,
variables: Iterable[type5typeexpr.TypeVariable],
types: Iterable[type5typeexpr.TypeExpr],
) -> Callable[[G, Any], None]:
variables = list(variables)
types = list(types)
key = tuple(sorted(tuple(zip(variables, types))))
print('key', key)
return self.data[key]

View File

@ -84,10 +84,10 @@ def expression(inp: ourlang.Expression) -> str:
for arg in inp.arguments for arg in inp.arguments
) )
if isinstance(inp.function, ourlang.StructConstructor): if isinstance(inp.function_instance.function, ourlang.StructConstructor):
return f'{inp.function.struct_type3.name}({args})' return f'{inp.function.struct_type3.name}({args})'
return f'{inp.function.name}({args})' return f'{inp.function_instance.function.name}({args})'
if isinstance(inp, ourlang.FunctionReference): if isinstance(inp, ourlang.FunctionReference):
return str(inp.function.name) return str(inp.function.name)

View File

@ -15,10 +15,9 @@ from .type3.routers import NoRouteForTypeException
from .type3.typeclasses import Type3ClassMethod from .type3.typeclasses import Type3ClassMethod
from .type3.types import ( from .type3.types import (
Type3, Type3,
TypeApplication_Struct,
TypeConstructor_Function,
) )
from .type5.typeexpr import TypeExpr, is_concrete from .type5.typeexpr import TypeExpr, is_concrete
from .type5.unify import ReplaceVariable, unify
from .wasm import ( from .wasm import (
WasmTypeFloat32, WasmTypeFloat32,
WasmTypeFloat64, WasmTypeFloat64,
@ -156,6 +155,45 @@ def expression_subscript_tuple(wgn: WasmGenerator, mod: ourlang.Module[WasmGener
expression(wgn, mod, inp.varref) expression(wgn, mod, inp.varref)
wgn.add_statement(el_type_info.wasm_load_func, f'offset={offset}') wgn.add_statement(el_type_info.wasm_load_func, f'offset={offset}')
def expression_function_call(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.FunctionCall) -> None:
for arg in inp.arguments:
expression(wgn, mod, arg)
if isinstance(inp.function_instance.function, ourlang.BuiltinFunction):
assert _is_concrete(inp.function_instance.type5), TYPE5_ASSERTION_ERROR
method_type, method_router = mod.build.methods[inp.function_instance.function.name]
instance_type = inp.function_instance.type5
actions = unify(method_type, instance_type)
tv_map = {}
for action in actions:
if isinstance(action, ReplaceVariable):
tv_map[action.var] = action.typ
continue
raise NotImplementedError
method_router(tv_map.keys(), tv_map.values())(wgn, None)
return
if isinstance(inp.function, ourlang.FunctionParam):
fn_args = mod.build.type5_is_function(inp.function.type5)
assert fn_args is not None
params = [
type5(mod, x)
for x in fn_args
]
result = params.pop()
wgn.add_statement('local.get', '${}'.format(inp.function.name))
wgn.call_indirect(params=params, result=result)
return
wgn.call(inp.function.name)
def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Expression) -> None: def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourlang.Expression) -> None:
""" """
Compile: Any expression Compile: Any expression
@ -244,53 +282,7 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourl
return return
if isinstance(inp, ourlang.FunctionCall): if isinstance(inp, ourlang.FunctionCall):
for arg in inp.arguments: expression_function_call(wgn, mod, inp)
expression(wgn, mod, arg)
if isinstance(inp.function, Type3ClassMethod):
# FIXME: Duplicate code with BinaryOp
type_var_map = {}
for type_var, arg_expr in zip(inp.function.signature.args, inp.arguments + [inp], strict=True):
assert arg_expr.type3 is not None, TYPE3_ASSERTION_ERROR
if isinstance(type_var, Type3):
# Fixed type, not part of the lookup requirements
continue
if isinstance(type_var, TypeVariable):
type_var_map[type_var] = arg_expr.type3
continue
if isinstance(type_var, FunctionArgument):
# Fixed type, not part of the lookup requirements
continue
raise NotImplementedError(type_var, arg_expr.type3)
router = mod.build.type_class_instance_methods[inp.function]
try:
router(wgn, type_var_map)
except NoRouteForTypeException:
raise NotImplementedError(str(inp.function), type_var_map)
return
if isinstance(inp.function, ourlang.FunctionParam):
fn_args = mod.build.type5_is_function(inp.function.type5)
assert fn_args is not None
params = [
type5(mod, x)
for x in fn_args
]
result = params.pop()
wgn.add_statement('local.get', '${}'.format(inp.function.name))
wgn.call_indirect(params=params, result=result)
return
wgn.call(inp.function.name)
return return
if isinstance(inp, ourlang.FunctionReference): if isinstance(inp, ourlang.FunctionReference):
@ -314,6 +306,8 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], inp: ourl
expression_subscript_tuple(wgn, mod, inp) expression_subscript_tuple(wgn, mod, inp)
return return
raise NotImplementedError
inp_as_fc = ourlang.FunctionCall(mod.build.type_classes['Subscriptable'].operators['[]'], inp.sourceref) inp_as_fc = ourlang.FunctionCall(mod.build.type_classes['Subscriptable'].operators['[]'], inp.sourceref)
inp_as_fc.type3 = inp.type3 inp_as_fc.type3 = inp.type3
inp_as_fc.type5 = inp.type5 inp_as_fc.type5 = inp.type5
@ -349,7 +343,7 @@ def statement_return(wgn: WasmGenerator, mod: ourlang.Module[WasmGenerator], fun
# Support tail calls # Support tail calls
# https://github.com/WebAssembly/tail-call # https://github.com/WebAssembly/tail-call
# These help a lot with some functional programming techniques # These help a lot with some functional programming techniques
if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function is fun: if isinstance(inp.value, ourlang.FunctionCall) and inp.value.function_instance.function.name is fun:
for arg in inp.value.arguments: for arg in inp.value.arguments:
expression(wgn, mod, arg) expression(wgn, mod, arg)

View File

@ -1,11 +1,12 @@
""" """
Contains the syntax tree for ourlang Contains the syntax tree for ourlang
""" """
from __future__ import annotations
from typing import Dict, Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from .build.base import BuildBase from .build.base import BuildBase
from .type3.functions import FunctionSignature, TypeVariableContext from .type3.functions import FunctionSignature, TypeVariableContext
from .type3.typeclasses import Type3ClassMethod
from .type3.types import Type3, TypeApplication_Struct from .type3.types import Type3, TypeApplication_Struct
from .type5 import record as type5record from .type5 import record as type5record
from .type5 import typeexpr as type5typeexpr from .type5 import typeexpr as type5typeexpr
@ -164,11 +165,11 @@ class BinaryOp(Expression):
""" """
__slots__ = ('operator', 'left', 'right', ) __slots__ = ('operator', 'left', 'right', )
operator: Type3ClassMethod operator: FunctionInstance
left: Expression left: Expression
right: Expression right: Expression
def __init__(self, operator: Type3ClassMethod, left: Expression, right: Expression, sourceref: SourceRef) -> None: def __init__(self, operator: FunctionInstance, left: Expression, right: Expression, sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref) super().__init__(sourceref=sourceref)
self.operator = operator self.operator = operator
@ -178,19 +179,33 @@ class BinaryOp(Expression):
def __repr__(self) -> str: def __repr__(self) -> str:
return f'BinaryOp({repr(self.operator)}, {repr(self.left)}, {repr(self.right)})' return f'BinaryOp({repr(self.operator)}, {repr(self.left)}, {repr(self.right)})'
class FunctionInstance(Expression):
"""
When calling a polymorphic function with concrete arguments, we can generate
code for that specific instance of the function.
"""
__slots__ = ('function', )
function: Union['Function', 'FunctionParam']
def __init__(self, function: Union['Function', 'FunctionParam'], sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref)
self.function = function
class FunctionCall(Expression): class FunctionCall(Expression):
""" """
A function call expression within a statement A function call expression within a statement
""" """
__slots__ = ('function', 'arguments', ) __slots__ = ('function_instance', 'arguments', )
function: Union['Function', 'FunctionParam', Type3ClassMethod] function_instance: FunctionInstance
arguments: List[Expression] arguments: List[Expression]
def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod], sourceref: SourceRef) -> None: def __init__(self, function_instance: FunctionInstance, sourceref: SourceRef) -> None:
super().__init__(sourceref=sourceref) super().__init__(sourceref=sourceref)
self.function = function self.function_instance = function_instance
self.arguments = [] self.arguments = []
class FunctionReference(Expression): class FunctionReference(Expression):
@ -349,6 +364,11 @@ class Function:
self.posonlyargs = [] self.posonlyargs = []
self.arg_names = [] self.arg_names = []
class BuiltinFunction(Function):
def __init__(self, name: str, type5: type5typeexpr.TypeExpr) -> None:
super().__init__(name, SourceRef("/", 0, 0), None)
self.type5 = type5
class StructDefinition: class StructDefinition:
""" """
The definition for a struct The definition for a struct
@ -453,8 +473,8 @@ class Module[G]:
struct_definitions: Dict[str, StructDefinition] struct_definitions: Dict[str, StructDefinition]
constant_defs: Dict[str, ModuleConstantDef] constant_defs: Dict[str, ModuleConstantDef]
functions: Dict[str, Function] functions: Dict[str, Function]
methods: Dict[str, Type3ClassMethod] methods: Dict[str, type5typeexpr.TypeExpr]
operators: Dict[str, Type3ClassMethod] operators: Dict[str, type5typeexpr.TypeExpr]
functions_table: dict[Function, int] functions_table: dict[Function, int]
def __init__(self, build: BuildBase[G], filename: str) -> None: def __init__(self, build: BuildBase[G], filename: str) -> None:

View File

@ -10,6 +10,7 @@ from .exceptions import StaticError
from .ourlang import ( from .ourlang import (
AccessStructMember, AccessStructMember,
BinaryOp, BinaryOp,
BuiltinFunction,
ConstantBytes, ConstantBytes,
ConstantPrimitive, ConstantPrimitive,
ConstantStruct, ConstantStruct,
@ -17,6 +18,7 @@ from .ourlang import (
Expression, Expression,
Function, Function,
FunctionCall, FunctionCall,
FunctionInstance,
FunctionParam, FunctionParam,
FunctionReference, FunctionReference,
Module, Module,
@ -100,7 +102,7 @@ class OurVisitor[G]:
def visit_Module(self, node: ast.Module) -> Module[G]: def visit_Module(self, node: ast.Module) -> Module[G]:
module = Module(self.build, "-") module = Module(self.build, "-")
module.methods.update(self.build.methods) module.methods.update({k: v[0] for k, v in self.build.methods.items()})
module.operators.update(self.build.operators) module.operators.update(self.build.operators)
module.types.update(self.build.types) module.types.update(self.build.types)
module.type5s.update(self.build.type5s) module.type5s.update(self.build.type5s)
@ -415,7 +417,7 @@ class OurVisitor[G]:
raise NotImplementedError(f'Operator {operator}') raise NotImplementedError(f'Operator {operator}')
return BinaryOp( return BinaryOp(
module.operators[operator], FunctionInstance(module.operators[operator], srf(module, node)),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.right),
srf(module, node), srf(module, node),
@ -444,7 +446,7 @@ class OurVisitor[G]:
raise NotImplementedError(f'Operator {operator}') raise NotImplementedError(f'Operator {operator}')
return BinaryOp( return BinaryOp(
module.operators[operator], FunctionInstance(module.operators[operator], srf(module, node)),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.left),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.comparators[0]),
srf(module, node), srf(module, node),
@ -512,7 +514,7 @@ class OurVisitor[G]:
func: Union[Function, FunctionParam, Type3ClassMethod] func: Union[Function, FunctionParam, Type3ClassMethod]
if node.func.id in module.methods: if node.func.id in module.methods:
func = module.methods[node.func.id] func = BuiltinFunction(node.func.id, module.methods[node.func.id])
elif node.func.id in our_locals: elif node.func.id in our_locals:
func = our_locals[node.func.id] func = our_locals[node.func.id]
else: else:
@ -521,7 +523,7 @@ class OurVisitor[G]:
func = module.functions[node.func.id] func = module.functions[node.func.id]
result = FunctionCall(func, sourceref=srf(module, node)) result = FunctionCall(FunctionInstance(func, srf(module, node)), sourceref=srf(module, node))
result.arguments.extend( result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
for arg_expr in node.args for arg_expr in node.args

View File

@ -39,9 +39,10 @@ class Context:
self.build = build self.build = build
self.placeholder_update = {} self.placeholder_update = {}
def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star()) -> TypeVariable: def make_placeholder(self, arg: ExpressionProtocol | None = None, kind: KindExpr = Star(), prefix: str = 'p') -> TypeVariable:
res = TypeVariable(kind, f"p_{len(self.placeholder_update)}") res = TypeVariable(kind, f"{prefix}_{len(self.placeholder_update)}")
self.placeholder_update[res] = arg self.placeholder_update[res] = arg
print('placeholder_update', res, arg)
return res return res
@dataclasses.dataclass @dataclasses.dataclass
@ -440,18 +441,9 @@ class TypeClassInstanceExistsConstraint(ConstraintBase):
if len(c_arg_list) != len(self.arg_list): if len(c_arg_list) != len(self.arg_list):
return skip_for_now() return skip_for_now()
tcls = self.ctx.build.type_classes[self.typeclass] key = tuple(c_arg_list)
existing_instances = self.ctx.build.type_class_instances[self.typeclass]
# Temporary hack while we are converting from type3 to type5 if key in existing_instances:
try:
targs = tuple(
_type5_to_type3_or_type3_const(self.ctx.build, x)
for x in self.arg_list
)
except RecordFoundException:
return fail('Missing type class instance')
if (tcls, targs, ) in self.ctx.build.type_class_instances:
return ok() return ok()
return fail('Missing type class instance') return fail('Missing type class instance')

View File

@ -17,7 +17,7 @@ from .constraints import (
UnifyTypesConstraint, UnifyTypesConstraint,
) )
from .kindexpr import Star from .kindexpr import Star
from .typeexpr import TypeApplication, TypeExpr, TypeVariable from .typeexpr import TypeApplication, TypeExpr, TypeVariable, instantiate
ConstraintGenerator = Generator[ConstraintBase, None, None] ConstraintGenerator = Generator[ConstraintBase, None, None]
@ -107,17 +107,15 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Type
yield from expression(ctx, arg, arg_tv) yield from expression(ctx, arg, arg_tv)
arg_typ_list.append(arg_tv) arg_typ_list.append(arg_tv)
if isinstance(inp.function, type3classes.Type3ClassMethod): assert isinstance(inp.function_instance.function.type5, TypeExpr)
func_type, constraints = _signature_to_type5(ctx, inp.sourceref, inp.function.signature) inp.function_instance.type5 = ctx.make_placeholder(inp.function_instance)
else: yield UnifyTypesConstraint(ctx, inp.sourceref, instantiate(inp.function_instance.function.type5, {}, lambda x, p: ctx.make_placeholder(kind=x, prefix=p)), inp.function_instance.type5)
assert isinstance(inp.function.type5, TypeExpr) # constraints = []
func_type = inp.function.type5
constraints = []
expr_type = ctx.build.type5_make_function(arg_typ_list + [phft]) expr_type = ctx.build.type5_make_function(arg_typ_list + [phft])
yield UnifyTypesConstraint(ctx, inp.sourceref, func_type, expr_type) yield UnifyTypesConstraint(ctx, inp.sourceref, inp.function_instance.type5, expr_type)
yield from constraints # yield from constraints
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: TypeVariable) -> ConstraintGenerator: def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: TypeVariable) -> ConstraintGenerator:
assert inp.function.type5 is not None # Todo: Make not nullable assert inp.function.type5 is not None # Todo: Make not nullable

View File

@ -3,7 +3,7 @@ from typing import Any
from ..ourlang import Module from ..ourlang import Module
from .constraints import ConstraintBase, Context from .constraints import ConstraintBase, Context
from .fromast import phasm_type5_generate_constraints from .fromast import phasm_type5_generate_constraints
from .typeexpr import TypeExpr, TypeVariable from .typeexpr import TypeExpr, TypeVariable, replace_variable
from .unify import ReplaceVariable from .unify import ReplaceVariable
MAX_RESTACK_COUNT = 10 # 100 MAX_RESTACK_COUNT = 10 # 100
@ -47,10 +47,12 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
if isinstance(action, ReplaceVariable): if isinstance(action, ReplaceVariable):
action_var: TypeExpr = action.var action_var: TypeExpr = action.var
while isinstance(action_var, TypeVariable) and action_var in placeholder_types: while isinstance(action_var, TypeVariable) and action_var in placeholder_types:
# TODO: Does this still happen?
action_var = placeholder_types[action_var] action_var = placeholder_types[action_var]
action_typ: TypeExpr = action.typ action_typ: TypeExpr = action.typ
while isinstance(action_typ, TypeVariable) and action_typ in placeholder_types: while isinstance(action_typ, TypeVariable) and action_typ in placeholder_types:
# TODO: Does this still happen?
action_typ = placeholder_types[action_typ] action_typ = placeholder_types[action_typ]
# print(inp.build.type5_name(action_var), ':=', inp.build.type5_name(action_typ)) # print(inp.build.type5_name(action_var), ':=', inp.build.type5_name(action_typ))
@ -62,6 +64,11 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
action_typ, action_var = action_var, action_typ action_typ, action_var = action_var, action_typ
if isinstance(action_var, TypeVariable): if isinstance(action_var, TypeVariable):
# Ensure all existing found types are updated
placeholder_types = {
k: replace_variable(v, action_var, action_typ)
for k, v in placeholder_types.items()
}
placeholder_types[action_var] = action_typ placeholder_types[action_var] = action_typ
for oth_const in new_constraint_list + constraint_list: for oth_const in new_constraint_list + constraint_list:
@ -118,4 +125,6 @@ def phasm_type5(inp: Module[Any], verbose: bool = False) -> None:
while isinstance(new_type5, TypeVariable): while isinstance(new_type5, TypeVariable):
new_type5 = placeholder_types[new_type5] new_type5 = placeholder_types[new_type5]
print('expression', expression)
print('new_type5', new_type5)
expression.type5 = new_type5 expression.type5 = new_type5

View File

@ -19,6 +19,9 @@ class AtomicType(TypeExpr):
def __init__(self, name: str) -> None: def __init__(self, name: str) -> None:
super().__init__(Star(), name) super().__init__(Star(), name)
def __hash__(self) -> int:
return hash((self.kind, self.name))
@dataclass @dataclass
class TypeLevelNat(TypeExpr): class TypeLevelNat(TypeExpr):
value: int value: int
@ -42,8 +45,6 @@ class TypeVariable(TypeExpr):
@dataclass @dataclass
class TypeConstructor(TypeExpr): class TypeConstructor(TypeExpr):
name: str
def __init__(self, kind: Arrow, name: str) -> None: def __init__(self, kind: Arrow, name: str) -> None:
super().__init__(kind, name) super().__init__(kind, name)
@ -145,3 +146,33 @@ def replace_variable(expr: TypeExpr, var: TypeVariable, rep_expr: TypeExpr) -> T
) )
raise NotImplementedError raise NotImplementedError
def instantiate(
expr: TypeExpr,
known_map: dict[TypeVariable, TypeVariable],
make_variable: Callable[[KindExpr, str], TypeVariable],
) -> TypeExpr:
if isinstance(expr, AtomicType):
return expr
if isinstance(expr, TypeLevelNat):
return expr
if isinstance(expr, TypeVariable):
known_map.setdefault(expr, make_variable(expr.kind, expr.name))
return known_map[expr]
if isinstance(expr, TypeConstructor):
return expr
if isinstance(expr, TypeApplication):
new_constructor = instantiate(expr.constructor, known_map, make_variable)
assert isinstance(new_constructor, TypeConstructor | TypeApplication | TypeVariable) # type hint
return TypeApplication(
constructor=new_constructor,
argument=instantiate(expr.argument, known_map, make_variable),
)
raise NotImplementedError(expr)

View File

@ -0,0 +1,8 @@
from dataclasses import dataclass
from ..type5.typeexpr import TypeVariable
@dataclass
class TypeClass:
name: str
variables: list[TypeVariable]

View File

@ -41,7 +41,6 @@ class RunnerBase:
""" """
self.phasm_ast = phasm_parse(self.phasm_code) self.phasm_ast = phasm_parse(self.phasm_code)
phasm_type5(self.phasm_ast, verbose=verbose) phasm_type5(self.phasm_ast, verbose=verbose)
phasm_type3(self.phasm_ast, verbose=False)
def compile_ast(self) -> None: def compile_ast(self) -> None:
""" """