Removes the special casing for foldl

Had to implement both functions as arguments and type
place holders (variables) for type constructors.

Had to implement functions as a type as well.

Still have to figure out how to pass functions around.
This commit is contained in:
Johan B.W. de Vries 2025-04-27 12:54:34 +02:00
parent ac4b46bbe7
commit f19decf65c
13 changed files with 276 additions and 244 deletions

View File

@ -102,10 +102,6 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member}' return f'{expression(inp.varref)}.{inp.member}'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements: def statement(inp: ourlang.Statement) -> Statements:

View File

@ -4,7 +4,7 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
import struct import struct
from typing import List, Optional from typing import List, Optional
from . import codestyle, ourlang, prelude, wasm from . import ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types from .stdlib import types as stdlib_types
@ -376,90 +376,8 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
))) )))
return return
if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp)
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
"""
Compile: Fold expression
"""
assert inp.type3 is not None, TYPE3_ASSERTION_ERROR
if inp.iter.type3 is not prelude.bytes_:
raise NotImplementedError(expression_fold, inp, inp.iter.type3)
wgn.add_statement('nop', comment='acu :: u8')
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
wgn.add_statement('nop', comment='adr :: bytes*')
adr_var = wgn.temp_var_i32('fold_i32_adr')
wgn.add_statement('nop', comment='len :: i32')
len_var = wgn.temp_var_i32('fold_i32_len')
wgn.add_statement('nop', comment='acu = base')
expression(wgn, inp.base)
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='adr = adr(iter)')
expression(wgn, inp.iter)
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='len = len(iter)')
wgn.local.get(adr_var)
wgn.i32.load()
wgn.local.set(len_var)
wgn.add_statement('nop', comment='i = 0')
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
wgn.i32.const(0)
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i < len')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
with wgn.if_():
# From here on, adr_var is the address of byte we're referencing
# This is akin to calling stdlib_types.__subscript_bytes__
# But since we already know we are inside of bounds,
# can just bypass it and load the memory directly.
wgn.local.get(adr_var)
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
wgn.i32.add()
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='while True')
with wgn.loop():
wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
wgn.local.get(acu_var)
# Get the next byte, write back the address
wgn.local.get(adr_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.tee(adr_var)
wgn.i32.load8_u()
wgn.add_statement('call', f'${inp.func.name}')
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='i = i + 1')
wgn.local.get(idx_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i >= len: break')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
wgn.br_if(0)
# return acu
wgn.local.get(acu_var)
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None: def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
""" """
Compile: Return statement Compile: Return statement

View File

@ -1,7 +1,6 @@
""" """
Contains the syntax tree for ourlang Contains the syntax tree for ourlang
""" """
import enum
from typing import Dict, Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from . import prelude from . import prelude
@ -161,6 +160,18 @@ class FunctionCall(Expression):
self.function = function self.function = function
self.arguments = [] self.arguments = []
class FunctionReference(Expression):
"""
An function reference expression within a statement
"""
__slots__ = ('function', )
function: 'Function'
def __init__(self, function: 'Function') -> None:
super().__init__()
self.function = function
class TupleInstantiation(Expression): class TupleInstantiation(Expression):
""" """
Instantiation a tuple Instantiation a tuple
@ -207,36 +218,6 @@ class AccessStructMember(Expression):
self.struct_type3 = struct_type3 self.struct_type3 = struct_type3
self.member = member self.member = member
class Fold(Expression):
"""
A (left or right) fold
"""
class Direction(enum.Enum):
"""
Which direction to fold in
"""
LEFT = 0
RIGHT = 1
dir: Direction
func: 'Function'
base: Expression
iter: Expression
def __init__(
self,
dir_: Direction,
func: 'Function',
base: Expression,
iter_: Expression,
) -> None:
super().__init__()
self.dir = dir_
self.func = func
self.base = base
self.iter = iter_
class Statement: class Statement:
""" """
A statement within a function A statement within a function

View File

@ -14,10 +14,10 @@ from .ourlang import (
ConstantStruct, ConstantStruct,
ConstantTuple, ConstantTuple,
Expression, Expression,
Fold,
Function, Function,
FunctionCall, FunctionCall,
FunctionParam, FunctionParam,
FunctionReference,
Module, Module,
ModuleConstantDef, ModuleConstantDef,
ModuleDataBlock, ModuleDataBlock,
@ -446,6 +446,9 @@ class OurVisitor:
cdef = module.constant_defs[node.id] cdef = module.constant_defs[node.id]
return VariableReference(cdef) return VariableReference(cdef)
if node.id in module.functions:
return FunctionReference(module.functions[node.id])
_raise_static_error(node, f'Undefined variable {node.id}') _raise_static_error(node, f'Undefined variable {node.id}')
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):
@ -462,7 +465,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as expr in FunctionDef') raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]: def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
if node.keywords: if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet? _raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -475,28 +478,6 @@ class OurVisitor:
if node.func.id in PRELUDE_METHODS: if node.func.id in PRELUDE_METHODS:
func = PRELUDE_METHODS[node.func.id] func = PRELUDE_METHODS[node.func.id]
elif node.func.id == 'foldl':
if 3 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# TODO: This is not generic, you cannot return a function
subnode = node.args[0]
if not isinstance(subnode, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {subnode}')
if not isinstance(subnode.ctx, ast.Load):
_raise_static_error(subnode, 'Must be load context')
if subnode.id not in module.functions:
_raise_static_error(subnode, 'Reference to undefined function')
func = module.functions[subnode.id]
if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
return Fold(
Fold.Direction.LEFT,
func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
)
else: else:
if node.func.id not in module.functions: if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function') _raise_static_error(node, 'Call to undefined function')

View File

@ -20,6 +20,7 @@ from ..type3.types import (
Type3, Type3,
TypeApplication_Nullary, TypeApplication_Nullary,
TypeConstructor_Base, TypeConstructor_Base,
TypeConstructor_Function,
TypeConstructor_StaticArray, TypeConstructor_StaticArray,
TypeConstructor_Struct, TypeConstructor_Struct,
TypeConstructor_Tuple, TypeConstructor_Tuple,
@ -186,6 +187,16 @@ It should be applied with zero or more arguments. It has a compile time
determined length, and each argument can be different. determined length, and each argument can be different.
""" """
def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None:
pass # ? instance_type_class(InternalPassAsPointer, typ)
function = TypeConstructor_Function('function', on_create=fn_on_create)
"""
This is a function.
It should be applied with one or more arguments. The last argument is the 'return' type.
"""
def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None: def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None:
instance_type_class(InternalPassAsPointer, typ) instance_type_class(InternalPassAsPointer, typ)
@ -574,12 +585,16 @@ instance_type_class(Promotable, f32, f64, methods={
Foldable = Type3Class('Foldable', (t, ), methods={ Foldable = Type3Class('Foldable', (t, ), methods={
'sum': [t(a), a], 'sum': [t(a), a],
'foldl': [[a, b, b], b, t(a), b],
'foldr': [[a, b, b], b, t(a), b],
}, operators={}, additional_context={ }, operators={}, additional_context={
'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))], 'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))],
}) })
instance_type_class(Foldable, static_array, methods={ instance_type_class(Foldable, static_array, methods={
'sum': stdtypes.static_array_sum, 'sum': stdtypes.static_array_sum,
'foldl': stdtypes.static_array_foldl,
'foldr': stdtypes.static_array_foldr,
}) })
PRELUDE_TYPE_CLASSES = { PRELUDE_TYPE_CLASSES = {

View File

@ -1185,3 +1185,9 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]') g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]')
# End result: [sum] # End result: [sum]
def static_array_foldl(g: Generator, tv_map: TypeVariableLookup) -> None:
raise NotImplementedError(tv_map)
def static_array_foldr(g: Generator, tv_map: TypeVariableLookup) -> None:
raise NotImplementedError(tv_map)

View File

@ -6,6 +6,7 @@ These need to be resolved before the program can be compiled.
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from .. import ourlang, prelude from .. import ourlang, prelude
from .functions import FunctionArgument, TypeVariable
from .placeholders import PlaceholderForType, Type3OrPlaceholder from .placeholders import PlaceholderForType, Type3OrPlaceholder
from .routers import NoRouteForTypeException, TypeApplicationRouter from .routers import NoRouteForTypeException, TypeApplicationRouter
from .typeclasses import Type3Class from .typeclasses import Type3Class
@ -158,7 +159,7 @@ class SameTypeConstraint(ConstraintBase):
return ( return (
' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))), ' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))),
{ {
't' + str(idx): typ 't' + str(idx): typ.name if isinstance(typ, ourlang.Function) else typ
for idx, typ in enumerate(self.type_list) for idx, typ in enumerate(self.type_list)
}, },
) )
@ -181,7 +182,7 @@ class SameTypeArgumentConstraint(ConstraintBase):
self.arg_var = arg_var self.arg_var = arg_var
def check(self) -> CheckResult: def check(self) -> CheckResult:
if self.tc_var.resolve_as is None or self.arg_var.resolve_as is None: if self.tc_var.resolve_as is None:
return RequireTypeSubstitutes() return RequireTypeSubstitutes()
tc_typ = self.tc_var.resolve_as tc_typ = self.tc_var.resolve_as
@ -201,13 +202,84 @@ class SameTypeArgumentConstraint(ConstraintBase):
# FIXME: This feels sketchy. Shouldn't the type variable # FIXME: This feels sketchy. Shouldn't the type variable
# have the exact same number as arguments? # have the exact same number as arguments?
if isinstance(tc_typ.application, TypeApplication_TypeInt): if isinstance(tc_typ.application, TypeApplication_TypeInt):
if tc_typ.application.arguments[0] == arg_typ: return [SameTypeConstraint(
return None tc_typ.application.arguments[0],
self.arg_var,
return Error(f'{tc_typ.application.arguments[0]:s} must be {arg_typ:s} instead') comment=self.comment,
)]
raise NotImplementedError(tc_typ, arg_typ) raise NotImplementedError(tc_typ, arg_typ)
def human_readable(self) -> HumanReadableRet:
return (
'{tc_var}` == {arg_var}',
{
'tc_var': self.tc_var if self.tc_var.resolve_as is None else self.tc_var,
'arg_var': self.arg_var if self.arg_var.resolve_as is None else self.arg_var,
},
)
class SameFunctionArgumentConstraint(ConstraintBase):
__slots__ = ('type3', 'func_arg', 'type_var_map', )
type3: PlaceholderForType
func_arg: FunctionArgument
type_var_map: dict[TypeVariable, PlaceholderForType]
def __init__(self, type3: PlaceholderForType, func_arg: FunctionArgument, type_var_map: dict[TypeVariable, PlaceholderForType], *, comment: str) -> None:
super().__init__(comment=comment)
self.type3 = type3
self.func_arg = func_arg
self.type_var_map = type_var_map
def check(self) -> CheckResult:
if self.type3.resolve_as is None:
return RequireTypeSubstitutes()
typ = self.type3.resolve_as
if isinstance(typ.application, TypeApplication_Nullary):
return Error(f'{typ:s} must be a function instead')
if not isinstance(typ.application, TypeApplication_TypeStar):
return Error(f'{typ:s} must be a function instead')
type_var_map = {
x: y.resolve_as
for x, y in self.type_var_map.items()
if y.resolve_as is not None
}
exp_type_arg_list = [
tv if isinstance(tv, Type3) else type_var_map[tv]
for tv in self.func_arg.args
if isinstance(tv, Type3) or tv in type_var_map
]
print('self.func_arg.args', self.func_arg.args)
print('exp_type_arg_list', exp_type_arg_list)
if len(exp_type_arg_list) != len(self.func_arg.args):
return RequireTypeSubstitutes()
return [
SameTypeConstraint(
typ,
prelude.function(*exp_type_arg_list),
comment=self.comment,
)
]
def human_readable(self) -> HumanReadableRet:
return (
'{type3} == {func_arg}',
{
'type3': self.type3,
'func_arg': self.func_arg.name,
},
)
class TupleMatchConstraint(ConstraintBase): class TupleMatchConstraint(ConstraintBase):
__slots__ = ('exp_type', 'args', ) __slots__ = ('exp_type', 'args', )
@ -264,14 +336,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
__slots__ = ('context', 'type_class3', 'types', ) __slots__ = ('context', 'type_class3', 'types', )
context: Context context: Context
type_class3: Union[str, Type3Class] type_class3: Type3Class
types: list[Type3OrPlaceholder] types: list[Type3OrPlaceholder]
DATA = { def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
'bytes': {'Foldable'},
}
def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
super().__init__(comment=comment) super().__init__(comment=comment)
self.context = context self.context = context
@ -299,13 +367,9 @@ class MustImplementTypeClassConstraint(ConstraintBase):
assert len(typ_list) == len(self.types) assert len(typ_list) == len(self.types)
if isinstance(self.type_class3, Type3Class): key = (self.type_class3, tuple(typ_list), )
key = (self.type_class3, tuple(typ_list), ) if key in self.context.type_class_instances_existing:
if key in self.context.type_class_instances_existing: return None
return None
else:
if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()):
return None
typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name
typ_name_list = ' '.join(x.name for x in typ_list) typ_name_list = ' '.join(x.name for x in typ_list)

View File

@ -12,12 +12,14 @@ from .constraints import (
Context, Context,
LiteralFitsConstraint, LiteralFitsConstraint,
MustImplementTypeClassConstraint, MustImplementTypeClassConstraint,
SameFunctionArgumentConstraint,
SameTypeArgumentConstraint, SameTypeArgumentConstraint,
SameTypeConstraint, SameTypeConstraint,
TupleMatchConstraint, TupleMatchConstraint,
) )
from .functions import ( from .functions import (
Constraint_TypeClassInstanceExists, Constraint_TypeClassInstanceExists,
FunctionArgument,
FunctionSignature, FunctionSignature,
TypeVariable, TypeVariable,
TypeVariableApplication_Unary, TypeVariableApplication_Unary,
@ -111,6 +113,33 @@ def _expression_function_call(
raise NotImplementedError(constraint) raise NotImplementedError(constraint)
func_var_map = {
x: PlaceholderForType([])
for x in signature.args
if isinstance(x, FunctionArgument)
}
# If some of the function arguments are functions,
# we need to deal with those separately.
for sig_arg in signature.args:
if not isinstance(sig_arg, FunctionArgument):
continue
# Ensure that for all type variables in the function
# there are also type variables available
for func_arg in sig_arg.args:
if isinstance(func_arg, Type3):
continue
type_var_map.setdefault(func_arg, PlaceholderForType([]))
yield SameFunctionArgumentConstraint(
func_var_map[sig_arg],
sig_arg,
type_var_map,
comment=f'Ensure `{sig_arg.name}` matches in {signature}',
)
# If some of the function arguments are type constructors, # If some of the function arguments are type constructors,
# we need to deal with those separately. # we need to deal with those separately.
# That is, given `foo :: t a -> a` we need to ensure # That is, given `foo :: t a -> a` we need to ensure
@ -120,6 +149,9 @@ def _expression_function_call(
# Not a type variable at all # Not a type variable at all
continue continue
if isinstance(sig_arg, FunctionArgument):
continue
if sig_arg.application.constructor is None: if sig_arg.application.constructor is None:
# Not a type variable for a type constructor # Not a type variable for a type constructor
continue continue
@ -150,9 +182,20 @@ def _expression_function_call(
yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment) yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment)
continue continue
if isinstance(sig_part, FunctionArgument):
yield SameTypeConstraint(func_var_map[sig_part], arg_placeholders[arg_expr], comment=comment)
continue
raise NotImplementedError(sig_part) raise NotImplementedError(sig_part)
return return
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator:
yield SameTypeConstraint(
prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3),
phft,
comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})',
)
def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator: def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp, ourlang.Constant): if isinstance(inp, ourlang.Constant):
yield from constant(ctx, inp, phft) yield from constant(ctx, inp, phft)
@ -171,6 +214,10 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
yield from expression_function_call(ctx, inp, phft) yield from expression_function_call(ctx, inp, phft)
return return
if isinstance(inp, ourlang.FunctionReference):
yield from expression_function_reference(ctx, inp, phft)
return
if isinstance(inp, ourlang.TupleInstantiation): if isinstance(inp, ourlang.TupleInstantiation):
r_type = [] r_type = []
for arg in inp.elements: for arg in inp.elements:
@ -209,19 +256,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}') comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
return return
if isinstance(inp, ourlang.Fold):
base_phft = PlaceholderForType([inp.base])
iter_phft = PlaceholderForType([inp.iter])
yield from expression(ctx, inp.base, base_phft)
yield from expression(ctx, inp.iter, iter_phft)
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft,
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft])
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator:

View File

@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List, Union from __future__ import annotations
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List
if TYPE_CHECKING: if TYPE_CHECKING:
from .typeclasses import Type3Class from .typeclasses import Type3Class
@ -155,15 +157,29 @@ class TypeVariableContext:
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TypeVariableContext({self.constraints!r})' return f'TypeVariableContext({self.constraints!r})'
class FunctionArgument:
__slots__ = ('args', 'name', )
args: list[Type3 | TypeVariable]
name: str
def __init__(self, args: list[Type3 | TypeVariable]) -> None:
self.args = args
self.name = '(' + ' -> '.join(x.name for x in args) + ')'
class FunctionSignature: class FunctionSignature:
__slots__ = ('context', 'args', ) __slots__ = ('context', 'args', )
context: TypeVariableContext context: TypeVariableContext
args: List[Union['Type3', TypeVariable]] args: List[Type3 | TypeVariable | FunctionArgument]
def __init__(self, context: TypeVariableContext, args: Iterable[Union['Type3', TypeVariable]]) -> None: def __init__(self, context: TypeVariableContext, args: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]) -> None:
self.context = context.__copy__() self.context = context.__copy__()
self.args = list(args) self.args = list(
FunctionArgument(x) if isinstance(x, list) else x
for x in args
)
def __str__(self) -> str: def __str__(self) -> str:
return str(self.context) + ' -> '.join(x.name for x in self.args) return str(self.context) + ' -> '.join(x.name for x in self.args)

View File

@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, Mapping, Optional, Union from typing import Dict, Iterable, List, Mapping, Optional
from .functions import ( from .functions import (
Constraint_TypeClassInstanceExists, Constraint_TypeClassInstanceExists,
@ -42,8 +42,8 @@ class Type3Class:
self, self,
name: str, name: str,
args: Type3ClassArgs, args: Type3ClassArgs,
methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]], methods: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]], operators: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
inherited_classes: Optional[List['Type3Class']] = None, inherited_classes: Optional[List['Type3Class']] = None,
additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None, additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None,
) -> None: ) -> None:
@ -71,19 +71,23 @@ class Type3Class:
return self.name return self.name
def _create_signature( def _create_signature(
method_arg_list: Iterable[Type3 | TypeVariable], method_arg_list: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]],
type_class3: Type3Class, type_class3: Type3Class,
) -> FunctionSignature: ) -> FunctionSignature:
context = TypeVariableContext() context = TypeVariableContext()
if not isinstance(type_class3.args[0], TypeConstructorVariable): if not isinstance(type_class3.args[0], TypeConstructorVariable):
context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args)) context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args))
signature_args: list[Type3 | TypeVariable] = [] signature_args: list[Type3 | TypeVariable | list[Type3 | TypeVariable]] = []
for method_arg in method_arg_list: for method_arg in method_arg_list:
if isinstance(method_arg, Type3): if isinstance(method_arg, Type3):
signature_args.append(method_arg) signature_args.append(method_arg)
continue continue
if isinstance(method_arg, list):
signature_args.append(method_arg)
continue
if isinstance(method_arg, TypeVariable): if isinstance(method_arg, TypeVariable):
type_constructor = method_arg.application.constructor type_constructor = method_arg.application.constructor
if type_constructor is None: if type_constructor is None:

View File

@ -239,6 +239,10 @@ class TypeConstructor_Tuple(TypeConstructor_TypeStar):
def make_name(self, key: Tuple[Type3, ...]) -> str: def make_name(self, key: Tuple[Type3, ...]) -> str:
return '(' + ', '.join(x.name for x in key) + ', )' return '(' + ', '.join(x.name for x in key) + ', )'
class TypeConstructor_Function(TypeConstructor_TypeStar):
def make_name(self, key: Tuple[Type3, ...]) -> str:
return '(' + ' -> '.join(x.name for x in key) + ')'
class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]): class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]):
""" """
Constructs struct types Constructs struct types

View File

@ -1,61 +0,0 @@
import pytest
from ..helpers import Suite
@pytest.mark.integration_test
def test_foldl_1():
code_py = """
def u8_or(l: u8, r: u8) -> u8:
return l | r
@exported
def testEntry(b: bytes) -> u8:
return foldl(u8_or, 128, b)
"""
suite = Suite(code_py)
result = suite.run_code(b'')
assert 128 == result.returned_value
result = suite.run_code(b'\x80')
assert 128 == result.returned_value
result = suite.run_code(b'\x80\x40')
assert 192 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10')
assert 240 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
assert 255 == result.returned_value
@pytest.mark.integration_test
def test_foldl_2():
code_py = """
def xor(l: u8, r: u8) -> u8:
return l ^ r
@exported
def testEntry(a: bytes, b: bytes) -> u8:
return foldl(xor, 0, a) ^ foldl(xor, 0, b)
"""
suite = Suite(code_py)
result = suite.run_code(b'\x55\x0F', b'\x33\x80')
assert 233 == result.returned_value
@pytest.mark.integration_test
def test_foldl_3():
code_py = """
def xor(l: u32, r: u8) -> u32:
return l ^ extend(r)
@exported
def testEntry(a: bytes) -> u32:
return foldl(xor, 0, a)
"""
suite = Suite(code_py)
result = suite.run_code(b'\x55\x0F\x33\x80')
assert 233 == result.returned_value

View File

@ -36,6 +36,80 @@ def testEntry(x: Foo[4]) -> Foo:
with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'): with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test
def test_foldable_foldl_size():
code_py = """
def u8_or(l: u8, r: u8) -> u8:
return l | r
@exported
def testEntry(b: bytes) -> u8:
return foldl(u8_or, 128, b)
"""
suite = Suite(code_py)
result = suite.run_code(b'')
assert 128 == result.returned_value
result = suite.run_code(b'\x80')
assert 128 == result.returned_value
result = suite.run_code(b'\x80\x40')
assert 192 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10')
assert 240 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
assert 255 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('direction, exp_result', [
('foldl', -55, ),
('foldr', -5, ),
])
def test_foldable_foldl_foldr(direction, exp_result):
# See https://stackoverflow.com/a/13280185
code_py = f"""
def i32_sub(l: i32, r: i32) -> i32:
return l - r
@exported
def testEntry(b: i32[10]) -> i32:
return {direction}(i32_sub, 0, b)
"""
suite = Suite(code_py)
result = suite.run_code(tuple(range(1, 10)))
assert exp_result == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('in_typ', ['i8', 'i8[3]'])
def test_foldable_argument_must_be_a_function(in_typ):
code_py = f"""
@exported
def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32:
return foldl(x, y, z)
"""
r_in_typ = in_typ.replace('[', '\\[').replace(']', '\\]')
with pytest.raises(Type3Exception, match=f'{r_in_typ} must be a function instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_foldable_argument_must_be_right_function():
code_py = """
def foo(l: i32, r: i64) -> i64:
return extend(l) + r
@exported
def testEntry(i: i64, l: i64[3]) -> i64:
return foldr(foo, i, l)
"""
with pytest.raises(Type3Exception, match=r'\(i64 -> i64 -> i64\) must be \(i32 -> i64 -> i64\) instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
def test_foldable_invalid_return_type(): def test_foldable_invalid_return_type():