Implements sum for Foldable types

Foldable take a TypeConstructor. The first argument must be a
NatNum.
This commit is contained in:
Johan B.W. de Vries 2025-05-05 14:09:38 +02:00
parent 6c627bca01
commit 22a10d5d92
9 changed files with 338 additions and 84 deletions

View File

@ -12,6 +12,7 @@
- Also, check the codes for FIXME and TODO
- Allocation is done using pointers for members, is this desired?
- See if we want to replace Fractional with Real, and add Rational, Irrationl, Algebraic, Transendental
- Implement q32? q64? Two i32/i64 divided?
- Does Subscript do what we want? It's a language feature rather a normal typed thing. How would you implement your own Subscript-able type?
- Clean up Subscript implementation - it's half implemented in the compiler. Makes more sense to move more parts to stdlib_types.
- Have a set of rules or guidelines for the constraint comments, they're messy.

View File

@ -2,7 +2,7 @@
This module contains the code to convert parsed Ourlang into WebAssembly code
"""
import struct
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Union
from . import codestyle, ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset
@ -292,7 +292,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
expression(wgn, inp.left)
expression(wgn, inp.right)
type_var_map: Dict[type3functions.TypeVariable, type3types.Type3] = {}
type_var_map: Dict[Union[type3functions.TypeVariable, type3functions.TypeConstructorVariable], type3types.Type3] = {}
for type_var, arg_expr in zip(inp.operator.signature.args, [inp.left, inp.right, inp], strict=True):
if not isinstance(type_var, type3functions.TypeVariable):
@ -315,12 +315,16 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
type_var_map = {}
for type_var, arg_expr in zip(inp.function.signature.args, inp.arguments + [inp], strict=True):
if not isinstance(type_var, type3functions.TypeVariable):
if isinstance(type_var, type3types.Type3):
# Fixed type, not part of the lookup requirements
continue
if isinstance(type_var, (type3functions.TypeVariable, type3functions.TypeConstructorVariable, )):
assert arg_expr.type3 is not None, TYPE3_ASSERTION_ERROR
type_var_map[type_var] = arg_expr.type3
continue
raise NotImplementedError
router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.function]
try:
@ -720,6 +724,7 @@ def module(inp: ourlang.Module) -> wasm.Module:
stdlib_types.__u32_pow2__,
stdlib_types.__u8_rotl__,
stdlib_types.__u8_rotr__,
stdlib_types.__sa_i32_sum__,
] + [
function(x)
for x in inp.functions.values()

View File

@ -1,25 +1,30 @@
"""
The prelude are all the builtin types, type classes and methods
"""
from typing import Callable
from typing import Any, Callable
from warnings import warn
from phasm.stdlib import types as stdtypes
from phasm.wasmgenerator import Generator
from ..type3.functions import TypeVariable
from ..type3.functions import (
Constraint_TypeClassInstanceExists,
TypeConstructorVariable,
TypeVariable,
)
from ..type3.routers import FunctionSignatureRouter
from ..type3.typeclasses import Type3Class, Type3ClassMethod
from ..type3.types import (
IntType3,
Type3,
TypeApplication_Nullary,
TypeConstructor_Base,
TypeConstructor_StaticArray,
TypeConstructor_Struct,
TypeConstructor_Tuple,
)
PRELUDE_TYPE_CLASS_INSTANCES_EXISTING: set[tuple[Type3Class, tuple[Type3, ...]]] = set()
PRELUDE_TYPE_CLASS_INSTANCES_EXISTING: set[tuple[Type3Class, tuple[Type3 | TypeConstructor_Base[Any], ...]]] = set()
PRELUDE_TYPE_CLASS_INSTANCE_METHODS: dict[Type3ClassMethod, FunctionSignatureRouter[Generator, None]] = {}
@ -198,8 +203,9 @@ PRELUDE_TYPES: dict[str, Type3] = {
a = TypeVariable('a')
b = TypeVariable('b')
t = TypeConstructorVariable('t')
InternalPassAsPointer = Type3Class('InternalPassAsPointer', [a], methods={}, operators={})
InternalPassAsPointer = Type3Class('InternalPassAsPointer', (a, ), methods={}, operators={})
"""
Internal type class to keep track which types we pass arounds as a pointer.
"""
@ -209,7 +215,7 @@ instance_type_class(InternalPassAsPointer, bytes_)
# instance_type_class(InternalPassAsPointer, tuple_)
# instance_type_class(InternalPassAsPointer, struct)
Eq = Type3Class('Eq', [a], methods={}, operators={
Eq = Type3Class('Eq', (a, ), methods={}, operators={
'==': [a, a, bool_],
'!=': [a, a, bool_],
# FIXME: Do we want to expose 'eqz'? Or is that a compiler optimization?
@ -248,7 +254,7 @@ instance_type_class(Eq, f64, operators={
'!=': stdtypes.f64_eq_not_equals,
})
Ord = Type3Class('Ord', [a], methods={
Ord = Type3Class('Ord', (a, ), methods={
'min': [a, a, a],
'max': [a, a, a],
}, operators={
@ -331,7 +337,7 @@ instance_type_class(Ord, f64, methods={
'>=': stdtypes.f64_ord_greater_than_or_equal,
})
Bits = Type3Class('Bits', [a], methods={
Bits = Type3Class('Bits', (a, ), methods={
'shl': [a, u32, a], # Logical shift left
'shr': [a, u32, a], # Logical shift right
'rotl': [a, u32, a], # Rotate bits left
@ -374,7 +380,7 @@ instance_type_class(Bits, u64, methods={
'^': stdtypes.u64_bits_bitwise_xor,
})
NatNum = Type3Class('NatNum', [a], methods={}, operators={
NatNum = Type3Class('NatNum', (a, ), methods={}, operators={
'+': [a, a, a],
'-': [a, a, a],
'*': [a, a, a],
@ -425,7 +431,7 @@ instance_type_class(NatNum, f64, operators={
'>>': stdtypes.f64_natnum_arithmic_shift_right,
})
IntNum = Type3Class('IntNum', [a], methods={
IntNum = Type3Class('IntNum', (a, ), methods={
'abs': [a, a],
'neg': [a, a],
}, operators={}, inherited_classes=[NatNum])
@ -447,7 +453,7 @@ instance_type_class(IntNum, f64, methods={
'neg': stdtypes.f64_intnum_neg,
})
Integral = Type3Class('Eq', [a], methods={
Integral = Type3Class('Eq', (a, ), methods={
}, operators={
'//': [a, a, a],
'%': [a, a, a],
@ -470,7 +476,7 @@ instance_type_class(Integral, i64, operators={
'%': stdtypes.i64_integral_rem,
})
Fractional = Type3Class('Fractional', [a], methods={
Fractional = Type3Class('Fractional', (a, ), methods={
'ceil': [a, a],
'floor': [a, a],
'trunc': [a, a],
@ -496,7 +502,7 @@ instance_type_class(Fractional, f64, methods={
'/': stdtypes.f64_fractional_div,
})
Floating = Type3Class('Floating', [a], methods={
Floating = Type3Class('Floating', (a, ), methods={
'sqrt': [a, a],
}, operators={}, inherited_classes=[Fractional])
@ -509,7 +515,7 @@ instance_type_class(Floating, f64, methods={
'sqrt': stdtypes.f64_floating_sqrt,
})
Sized_ = Type3Class('Sized', [a], methods={
Sized_ = Type3Class('Sized', (a, ), methods={
'len': [a, u32],
}, operators={}) # FIXME: Once we get type class families, add [] here
@ -517,7 +523,7 @@ instance_type_class(Sized_, bytes_, methods={
'len': stdtypes.bytes_sized_len,
})
Extendable = Type3Class('Extendable', [a, b], methods={
Extendable = Type3Class('Extendable', (a, b, ), methods={
'extend': [a, b],
'wrap': [b, a],
}, operators={})
@ -547,7 +553,7 @@ instance_type_class(Extendable, i32, i64, methods={
'wrap': stdtypes.i32_i64_wrap,
})
Promotable = Type3Class('Promotable', [a, b], methods={
Promotable = Type3Class('Promotable', (a, b, ), methods={
'promote': [a, b],
'demote': [b, a],
}, operators={})
@ -557,6 +563,16 @@ instance_type_class(Promotable, f32, f64, methods={
'demote': stdtypes.f32_f64_demote,
})
Foldable = Type3Class('Foldable', (t, ), methods={
'sum': [t(a), a],
}, operators={}, additional_context={
'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))],
})
instance_type_class(Foldable, static_array, methods={
'sum': stdtypes.static_array_i32_4_sum,
})
PRELUDE_TYPE_CLASSES = {
'Eq': Eq,
'Ord': Ord,
@ -592,4 +608,5 @@ PRELUDE_METHODS = {
**Sized_.methods,
**Extendable.methods,
**Promotable.methods,
**Foldable.methods,
}

View File

@ -384,6 +384,50 @@ def __u8_rotr__(g: Generator, x: i32, r: i32) -> i32:
return i32('return') # To satisfy mypy
@func_wrapper()
def __sa_i32_sum__(g: Generator, adr: i32, arlen: i32) -> i32:
i32_size = 4
s = i32('s')
stop = i32('stop')
# stop = adr + ar_len * i32_size
g.local.get(adr)
g.local.get(arlen)
g.i32.const(i32_size)
g.i32.mul()
g.i32.add()
g.local.set(stop)
# sum = 0
g.i32.const(0)
g.local.set(s)
with g.loop():
# sum = sum + *adr
g.local.get(adr)
g.i32.load()
g.local.get(s)
g.i32.add()
g.local.set(s)
# adr = adr + i32_size
g.local.get(adr)
g.i32.const(i32_size)
g.i32.add()
g.local.tee(adr)
# loop if adr < stop
g.local.get(stop)
g.i32.lt_u()
g.br_if(0)
# return sum
g.local.get(s)
g.return_()
return i32('return') # To satisfy mypy
## ###
## class Eq
@ -920,3 +964,11 @@ def f32_f64_promote(g: Generator) -> None:
def f32_f64_demote(g: Generator) -> None:
g.f32.demote_f64()
def static_array_i32_4_sum(g: Generator) -> None:
g.i32.const(4)
g.add_statement('call $stdlib.types.__sa_i32_sum__')
def static_array_i32_5_sum(g: Generator) -> None:
g.i32.const(5)
g.add_statement('call $stdlib.types.__sa_i32_sum__')

View File

@ -3,7 +3,7 @@ This module contains possible constraints generated based on the AST
These need to be resolved before the program can be compiled.
"""
from typing import Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from .. import ourlang, prelude
from . import placeholders, typeclasses, types
@ -50,7 +50,7 @@ class Context:
__slots__ = ('type_class_instances_existing', )
# Constraint_TypeClassInstanceExists
type_class_instances_existing: set[tuple[typeclasses.Type3Class, tuple[types.Type3, ...]]]
type_class_instances_existing: set[tuple[typeclasses.Type3Class, tuple[Union[types.Type3, types.TypeConstructor_Base[Any], types.TypeConstructor_Struct], ...]]]
def __init__(self) -> None:
self.type_class_instances_existing = set()
@ -229,7 +229,7 @@ class MustImplementTypeClassConstraint(ConstraintBase):
self.types = types
def check(self) -> CheckResult:
typ_list = []
typ_list: list[types.Type3 | types.TypeConstructor_Base[Any] | types.TypeConstructor_Struct] = []
for typ in self.types:
if isinstance(typ, placeholders.PlaceholderForType) and typ.resolve_as is not None:
typ = typ.resolve_as
@ -237,7 +237,15 @@ class MustImplementTypeClassConstraint(ConstraintBase):
if isinstance(typ, placeholders.PlaceholderForType):
return RequireTypeSubstitutes()
if isinstance(typ.application, (types.TypeApplication_Nullary, types.TypeApplication_Struct, )):
typ_list.append(typ)
continue
if isinstance(typ.application, (types.TypeApplication_TypeInt, types.TypeApplication_TypeStar)):
typ_list.append(typ.application.constructor)
continue
raise NotImplementedError(typ, typ.application)
assert len(typ_list) == len(self.types)

View File

@ -39,21 +39,48 @@ def constant(ctx: Context, inp: ourlang.Constant, phft: placeholders.Placeholder
raise NotImplementedError(constant, inp)
def expression(ctx: Context, inp: ourlang.Expression, phft: placeholders.PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp, ourlang.Constant):
yield from constant(ctx, inp, phft)
return
def expression_binary_op(ctx: Context, inp: ourlang.BinaryOp, phft: PlaceholderForType) -> ConstraintGenerator:
return _expression_function_call(
ctx,
f'({inp.operator.name})',
inp.operator.signature,
[inp.left, inp.right],
inp,
phft,
)
if isinstance(inp, ourlang.VariableReference):
yield SameTypeConstraint(inp.variable.type3, phft,
comment=f'typeOf("{inp.variable.name}") == typeOf({inp.variable.name})')
return
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator:
return _expression_function_call(
ctx,
inp.function.name,
inp.function.signature,
inp.arguments,
inp,
phft,
)
if isinstance(inp, ourlang.BinaryOp) or isinstance(inp, ourlang.FunctionCall):
signature = inp.operator.signature if isinstance(inp, ourlang.BinaryOp) else inp.function.signature
arguments = [inp.left, inp.right] if isinstance(inp, ourlang.BinaryOp) else inp.arguments
def _expression_function_call(
ctx: Context,
func_name: str,
signature: functions.FunctionSignature,
arguments: list[ourlang.Expression],
return_expr: ourlang.Expression,
return_phft: PlaceholderForType,
) -> ConstraintGenerator:
"""
Generates all type-level constraints for a function call.
func_name = f'({inp.operator.name})' if isinstance(inp, ourlang.BinaryOp) else inp.function.name
A Binary operator functions pretty much the same as a function call
with two arguments - it's only a syntactic difference.
"""
arg_placeholders = {
arg_expr: PlaceholderForType([arg_expr])
for arg_expr in arguments
}
arg_placeholders[return_expr] = return_phft
for call_arg in arguments:
yield from expression(ctx, call_arg, arg_placeholders[call_arg])
type_var_map = {
x: placeholders.PlaceholderForType([])
@ -61,11 +88,7 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: placeholders.Placeho
if isinstance(x, functions.TypeVariable)
}
arg_placeholders = {
arg_expr: PlaceholderForType([arg_expr])
for arg_expr in arguments
}
arg_placeholders[inp] = phft
print('type_var_map', type_var_map)
for arg_expr in arguments:
yield from expression(ctx, arg_expr, arg_placeholders[arg_expr])
@ -81,7 +104,7 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: placeholders.Placeho
raise NotImplementedError(constraint)
for arg_no, (sig_part, arg_expr) in enumerate(zip(signature.args, arguments + [inp], strict=True)):
for arg_no, (sig_part, arg_expr) in enumerate(zip(signature.args, arguments + [return_expr], strict=True)):
if arg_no == len(arguments):
comment = f'The type of a function call to {func_name} is the same as the type that the function returns'
else:
@ -98,6 +121,24 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: placeholders.Placeho
raise NotImplementedError(sig_part)
return
def expression(ctx: Context, inp: ourlang.Expression, phft: placeholders.PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp, ourlang.Constant):
yield from constant(ctx, inp, phft)
return
if isinstance(inp, ourlang.VariableReference):
yield SameTypeConstraint(inp.variable.type3, phft,
comment=f'typeOf("{inp.variable.name}") == typeOf({inp.variable.name})')
return
if isinstance(inp, ourlang.BinaryOp):
yield from expression_binary_op(ctx, inp, phft)
return
if isinstance(inp, ourlang.FunctionCall):
yield from expression_function_call(ctx, inp, phft)
return
if isinstance(inp, ourlang.TupleInstantiation):
r_type = []
for arg in inp.elements:

View File

@ -19,9 +19,18 @@ class TypeVariable:
letter: str
def __init__(self, letter: str) -> None:
assert len(letter) == 1, f'{letter} is not a valid type variable'
self.letter = letter
def deconstruct(self) -> 'TypeConstructorVariable | None':
letter_list = self.letter.split(' ')
if len(letter_list) == 1:
return None
if len(letter_list) == 2:
return TypeConstructorVariable(letter_list[0])
raise NotImplementedError(letter_list)
def __hash__(self) -> int:
return hash(self.letter)
@ -34,6 +43,38 @@ class TypeVariable:
def __repr__(self) -> str:
return f'TypeVariable({repr(self.letter)})'
class TypeConstructorVariable:
"""
Types constructor variable are used in function definition.
They are a lot like TypeVariable, except that they represent a
type constructor rather than a type directly.
For now, we only have type constructor variables for kind
* -> *.
"""
__slots__ = ('letter', )
letter: str
def __init__(self, letter: str) -> None:
self.letter = letter
def __hash__(self) -> int:
return hash((self.letter, ))
def __eq__(self, other: Any) -> bool:
if not isinstance(other, TypeConstructorVariable):
raise NotImplementedError
return (self.letter == other.letter)
def __call__(self, tvar: TypeVariable) -> 'TypeVariable':
return TypeVariable(self.letter + ' ' + tvar.letter)
def __repr__(self) -> str:
return f'TypeConstructorVariable({self.letter!r})'
class ConstraintBase:
__slots__ = ()
@ -62,13 +103,11 @@ class TypeVariableContext:
constraints: list[ConstraintBase]
def __init__(self) -> None:
self.constraints = []
def __init__(self, constraints: Iterable[ConstraintBase] = ()) -> None:
self.constraints = list(constraints)
def __copy__(self) -> 'TypeVariableContext':
result = TypeVariableContext()
result.constraints.extend(self.constraints)
return result
return TypeVariableContext(self.constraints)
def __str__(self) -> str:
if not self.constraints:

View File

@ -2,7 +2,9 @@ from typing import Dict, Iterable, List, Mapping, Optional, Union
from .functions import (
Constraint_TypeClassInstanceExists,
ConstraintBase,
FunctionSignature,
TypeConstructorVariable,
TypeVariable,
TypeVariableContext,
)
@ -29,7 +31,7 @@ class Type3Class:
__slots__ = ('name', 'args', 'methods', 'operators', 'inherited_classes', )
name: str
args: List[TypeVariable]
args: tuple[TypeVariable] | tuple[TypeVariable, TypeVariable] | tuple[TypeConstructorVariable]
methods: Dict[str, Type3ClassMethod]
operators: Dict[str, Type3ClassMethod]
inherited_classes: List['Type3Class']
@ -37,26 +39,60 @@ class Type3Class:
def __init__(
self,
name: str,
args: Iterable[TypeVariable],
args: tuple[TypeVariable] | tuple[TypeVariable, TypeVariable] | tuple[TypeConstructorVariable],
methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]],
operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]],
inherited_classes: Optional[List['Type3Class']] = None,
additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None,
) -> None:
self.name = name
self.args = list(args)
context = TypeVariableContext()
context.constraints.append(Constraint_TypeClassInstanceExists(self, args))
self.args = args
self.methods = {
k: Type3ClassMethod(k, FunctionSignature(context, v))
k: Type3ClassMethod(k, _create_signature(v, self))
for k, v in methods.items()
}
self.operators = {
k: Type3ClassMethod(k, FunctionSignature(context, v))
k: Type3ClassMethod(k, _create_signature(v, self))
for k, v in operators.items()
}
self.inherited_classes = inherited_classes or []
if additional_context:
for func_name, constraint_list in additional_context.items():
func = self.methods.get(func_name) or self.operators.get(func_name)
assert func is not None # type hint
func.signature.context.constraints.extend(constraint_list)
def __repr__(self) -> str:
return self.name
def _create_signature(
method_arg_list: Iterable[Type3 | TypeVariable],
type_class3: Type3Class,
) -> FunctionSignature:
context = TypeVariableContext()
if not isinstance(type_class3.args[0], TypeConstructorVariable):
context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args))
signature_args: list[Type3 | TypeVariable] = []
for method_arg in method_arg_list:
if isinstance(method_arg, Type3):
signature_args.append(method_arg)
continue
if isinstance(method_arg, TypeVariable):
type_constructor = method_arg.deconstruct()
if type_constructor is None:
signature_args.append(method_arg)
continue
if (type_constructor, ) == type_class3.args:
context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, [method_arg]))
signature_args.append(method_arg)
continue
raise NotImplementedError(method_arg)
return FunctionSignature(context, signature_args)

View File

@ -0,0 +1,55 @@
import pytest
from phasm.type3.entry import Type3Exception
from ..helpers import Suite
@pytest.mark.integration_test
def test_foldable_sum():
code_py = """
@exported
def testEntry(x: i32[5]) -> i32:
return sum(x)
"""
result = Suite(code_py).run_code((4, 5, 6, 7, 8, ))
assert 30 == result.returned_value
@pytest.mark.integration_test
def test_foldable_sum_not_natnum():
code_py = """
class Foo:
bar: i32
@exported
def testEntry(x: Foo[4]) -> Foo:
return sum(x)
"""
with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_foldable_invalid_return_type():
code_py = """
@exported
def testEntry(x: i32[5]) -> f64:
return sum(x)
"""
with pytest.raises(Type3Exception, match='f64 must be i32 instead'):
Suite(code_py).run_code((4, 5, 6, 7, 8, ))
@pytest.mark.integration_test
def test_foldable_not_foldable():
code_py = """
@exported
def testEntry(x: i32) -> i32:
return sum(x)
"""
with pytest.raises(Type3Exception, match='Missing type class instantation: Foldable i32'):
Suite(code_py).run_code()