diff --git a/phasm/codestyle.py b/phasm/codestyle.py index fc40868..10fd276 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -102,10 +102,6 @@ def expression(inp: ourlang.Expression) -> str: if isinstance(inp, ourlang.AccessStructMember): return f'{expression(inp.varref)}.{inp.member}' - if isinstance(inp, ourlang.Fold): - fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' - return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' - raise NotImplementedError(expression, inp) def statement(inp: ourlang.Statement) -> Statements: diff --git a/phasm/compiler.py b/phasm/compiler.py index d098499..656b269 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -4,7 +4,7 @@ This module contains the code to convert parsed Ourlang into WebAssembly code import struct from typing import List, Optional -from . import codestyle, ourlang, prelude, wasm +from . import ourlang, prelude, wasm from .runtime import calculate_alloc_size, calculate_member_offset from .stdlib import alloc as stdlib_alloc from .stdlib import types as stdlib_types @@ -376,90 +376,8 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: ))) return - if isinstance(inp, ourlang.Fold): - expression_fold(wgn, inp) - return - raise NotImplementedError(expression, inp) -def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: - """ - Compile: Fold expression - """ - assert inp.type3 is not None, TYPE3_ASSERTION_ERROR - - if inp.iter.type3 is not prelude.bytes_: - raise NotImplementedError(expression_fold, inp, inp.iter.type3) - - wgn.add_statement('nop', comment='acu :: u8') - acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu') - wgn.add_statement('nop', comment='adr :: bytes*') - adr_var = wgn.temp_var_i32('fold_i32_adr') - wgn.add_statement('nop', comment='len :: i32') - len_var = wgn.temp_var_i32('fold_i32_len') - - wgn.add_statement('nop', comment='acu = base') - expression(wgn, inp.base) - wgn.local.set(acu_var) - - wgn.add_statement('nop', comment='adr = adr(iter)') - expression(wgn, inp.iter) - wgn.local.set(adr_var) - - wgn.add_statement('nop', comment='len = len(iter)') - wgn.local.get(adr_var) - wgn.i32.load() - wgn.local.set(len_var) - - wgn.add_statement('nop', comment='i = 0') - idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx') - wgn.i32.const(0) - wgn.local.set(idx_var) - - wgn.add_statement('nop', comment='if i < len') - wgn.local.get(idx_var) - wgn.local.get(len_var) - wgn.i32.lt_u() - with wgn.if_(): - # From here on, adr_var is the address of byte we're referencing - # This is akin to calling stdlib_types.__subscript_bytes__ - # But since we already know we are inside of bounds, - # can just bypass it and load the memory directly. - wgn.local.get(adr_var) - wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop - wgn.i32.add() - wgn.local.set(adr_var) - - wgn.add_statement('nop', comment='while True') - with wgn.loop(): - wgn.add_statement('nop', comment='acu = func(acu, iter[i])') - wgn.local.get(acu_var) - - # Get the next byte, write back the address - wgn.local.get(adr_var) - wgn.i32.const(1) - wgn.i32.add() - wgn.local.tee(adr_var) - wgn.i32.load8_u() - - wgn.add_statement('call', f'${inp.func.name}') - wgn.local.set(acu_var) - - wgn.add_statement('nop', comment='i = i + 1') - wgn.local.get(idx_var) - wgn.i32.const(1) - wgn.i32.add() - wgn.local.set(idx_var) - - wgn.add_statement('nop', comment='if i >= len: break') - wgn.local.get(idx_var) - wgn.local.get(len_var) - wgn.i32.lt_u() - wgn.br_if(0) - - # return acu - wgn.local.get(acu_var) - def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None: """ Compile: Return statement diff --git a/phasm/ourlang.py b/phasm/ourlang.py index df97b23..29c1ce4 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -1,7 +1,6 @@ """ Contains the syntax tree for ourlang """ -import enum from typing import Dict, Iterable, List, Optional, Union from . import prelude @@ -161,6 +160,18 @@ class FunctionCall(Expression): self.function = function self.arguments = [] +class FunctionReference(Expression): + """ + An function reference expression within a statement + """ + __slots__ = ('function', ) + + function: 'Function' + + def __init__(self, function: 'Function') -> None: + super().__init__() + self.function = function + class TupleInstantiation(Expression): """ Instantiation a tuple @@ -207,36 +218,6 @@ class AccessStructMember(Expression): self.struct_type3 = struct_type3 self.member = member -class Fold(Expression): - """ - A (left or right) fold - """ - class Direction(enum.Enum): - """ - Which direction to fold in - """ - LEFT = 0 - RIGHT = 1 - - dir: Direction - func: 'Function' - base: Expression - iter: Expression - - def __init__( - self, - dir_: Direction, - func: 'Function', - base: Expression, - iter_: Expression, - ) -> None: - super().__init__() - - self.dir = dir_ - self.func = func - self.base = base - self.iter = iter_ - class Statement: """ A statement within a function diff --git a/phasm/parser.py b/phasm/parser.py index 9944dc2..2346a71 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -14,10 +14,10 @@ from .ourlang import ( ConstantStruct, ConstantTuple, Expression, - Fold, Function, FunctionCall, FunctionParam, + FunctionReference, Module, ModuleConstantDef, ModuleDataBlock, @@ -446,6 +446,9 @@ class OurVisitor: cdef = module.constant_defs[node.id] return VariableReference(cdef) + if node.id in module.functions: + return FunctionReference(module.functions[node.id]) + _raise_static_error(node, f'Undefined variable {node.id}') if isinstance(node, ast.Tuple): @@ -462,7 +465,7 @@ class OurVisitor: raise NotImplementedError(f'{node} as expr in FunctionDef') - def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]: + def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]: if node.keywords: _raise_static_error(node, 'Keyword calling not supported') # Yet? @@ -475,28 +478,6 @@ class OurVisitor: if node.func.id in PRELUDE_METHODS: func = PRELUDE_METHODS[node.func.id] - elif node.func.id == 'foldl': - if 3 != len(node.args): - _raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given') - - # TODO: This is not generic, you cannot return a function - subnode = node.args[0] - if not isinstance(subnode, ast.Name): - raise NotImplementedError(f'Calling methods that are not a name {subnode}') - if not isinstance(subnode.ctx, ast.Load): - _raise_static_error(subnode, 'Must be load context') - if subnode.id not in module.functions: - _raise_static_error(subnode, 'Reference to undefined function') - func = module.functions[subnode.id] - if 2 != len(func.posonlyargs): - _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') - - return Fold( - Fold.Direction.LEFT, - func, - self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]), - self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]), - ) else: if node.func.id not in module.functions: _raise_static_error(node, 'Call to undefined function') diff --git a/phasm/prelude/__init__.py b/phasm/prelude/__init__.py index e269590..3e9672d 100644 --- a/phasm/prelude/__init__.py +++ b/phasm/prelude/__init__.py @@ -20,6 +20,7 @@ from ..type3.types import ( Type3, TypeApplication_Nullary, TypeConstructor_Base, + TypeConstructor_Function, TypeConstructor_StaticArray, TypeConstructor_Struct, TypeConstructor_Tuple, @@ -186,6 +187,16 @@ It should be applied with zero or more arguments. It has a compile time determined length, and each argument can be different. """ +def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None: + pass # ? instance_type_class(InternalPassAsPointer, typ) + +function = TypeConstructor_Function('function', on_create=fn_on_create) +""" +This is a function. + +It should be applied with one or more arguments. The last argument is the 'return' type. +""" + def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None: instance_type_class(InternalPassAsPointer, typ) @@ -574,12 +585,16 @@ instance_type_class(Promotable, f32, f64, methods={ Foldable = Type3Class('Foldable', (t, ), methods={ 'sum': [t(a), a], + 'foldl': [[a, b, b], b, t(a), b], + 'foldr': [[a, b, b], b, t(a), b], }, operators={}, additional_context={ 'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))], }) instance_type_class(Foldable, static_array, methods={ 'sum': stdtypes.static_array_sum, + 'foldl': stdtypes.static_array_foldl, + 'foldr': stdtypes.static_array_foldr, }) PRELUDE_TYPE_CLASSES = { diff --git a/phasm/stdlib/types.py b/phasm/stdlib/types.py index 425dfa0..3e2eb06 100644 --- a/phasm/stdlib/types.py +++ b/phasm/stdlib/types.py @@ -1185,3 +1185,9 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]') # End result: [sum] + +def static_array_foldl(g: Generator, tv_map: TypeVariableLookup) -> None: + raise NotImplementedError(tv_map) + +def static_array_foldr(g: Generator, tv_map: TypeVariableLookup) -> None: + raise NotImplementedError(tv_map) diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py index 4a9541b..4372d50 100644 --- a/phasm/type3/constraints.py +++ b/phasm/type3/constraints.py @@ -6,6 +6,7 @@ These need to be resolved before the program can be compiled. from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from .. import ourlang, prelude +from .functions import FunctionArgument, TypeVariable from .placeholders import PlaceholderForType, Type3OrPlaceholder from .routers import NoRouteForTypeException, TypeApplicationRouter from .typeclasses import Type3Class @@ -158,7 +159,7 @@ class SameTypeConstraint(ConstraintBase): return ( ' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))), { - 't' + str(idx): typ + 't' + str(idx): typ.name if isinstance(typ, ourlang.Function) else typ for idx, typ in enumerate(self.type_list) }, ) @@ -181,7 +182,7 @@ class SameTypeArgumentConstraint(ConstraintBase): self.arg_var = arg_var def check(self) -> CheckResult: - if self.tc_var.resolve_as is None or self.arg_var.resolve_as is None: + if self.tc_var.resolve_as is None: return RequireTypeSubstitutes() tc_typ = self.tc_var.resolve_as @@ -201,13 +202,84 @@ class SameTypeArgumentConstraint(ConstraintBase): # FIXME: This feels sketchy. Shouldn't the type variable # have the exact same number as arguments? if isinstance(tc_typ.application, TypeApplication_TypeInt): - if tc_typ.application.arguments[0] == arg_typ: - return None - - return Error(f'{tc_typ.application.arguments[0]:s} must be {arg_typ:s} instead') + return [SameTypeConstraint( + tc_typ.application.arguments[0], + self.arg_var, + comment=self.comment, + )] raise NotImplementedError(tc_typ, arg_typ) + def human_readable(self) -> HumanReadableRet: + return ( + '{tc_var}` == {arg_var}', + { + 'tc_var': self.tc_var if self.tc_var.resolve_as is None else self.tc_var, + 'arg_var': self.arg_var if self.arg_var.resolve_as is None else self.arg_var, + }, + ) + +class SameFunctionArgumentConstraint(ConstraintBase): + __slots__ = ('type3', 'func_arg', 'type_var_map', ) + + type3: PlaceholderForType + func_arg: FunctionArgument + type_var_map: dict[TypeVariable, PlaceholderForType] + + def __init__(self, type3: PlaceholderForType, func_arg: FunctionArgument, type_var_map: dict[TypeVariable, PlaceholderForType], *, comment: str) -> None: + super().__init__(comment=comment) + + self.type3 = type3 + self.func_arg = func_arg + self.type_var_map = type_var_map + + def check(self) -> CheckResult: + if self.type3.resolve_as is None: + return RequireTypeSubstitutes() + + typ = self.type3.resolve_as + + if isinstance(typ.application, TypeApplication_Nullary): + return Error(f'{typ:s} must be a function instead') + + if not isinstance(typ.application, TypeApplication_TypeStar): + return Error(f'{typ:s} must be a function instead') + + type_var_map = { + x: y.resolve_as + for x, y in self.type_var_map.items() + if y.resolve_as is not None + } + + exp_type_arg_list = [ + tv if isinstance(tv, Type3) else type_var_map[tv] + for tv in self.func_arg.args + if isinstance(tv, Type3) or tv in type_var_map + ] + + print('self.func_arg.args', self.func_arg.args) + print('exp_type_arg_list', exp_type_arg_list) + + if len(exp_type_arg_list) != len(self.func_arg.args): + return RequireTypeSubstitutes() + + return [ + SameTypeConstraint( + typ, + prelude.function(*exp_type_arg_list), + comment=self.comment, + ) + ] + + def human_readable(self) -> HumanReadableRet: + return ( + '{type3} == {func_arg}', + { + 'type3': self.type3, + 'func_arg': self.func_arg.name, + }, + ) + class TupleMatchConstraint(ConstraintBase): __slots__ = ('exp_type', 'args', ) @@ -264,14 +336,10 @@ class MustImplementTypeClassConstraint(ConstraintBase): __slots__ = ('context', 'type_class3', 'types', ) context: Context - type_class3: Union[str, Type3Class] + type_class3: Type3Class types: list[Type3OrPlaceholder] - DATA = { - 'bytes': {'Foldable'}, - } - - def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None: + def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None: super().__init__(comment=comment) self.context = context @@ -299,13 +367,9 @@ class MustImplementTypeClassConstraint(ConstraintBase): assert len(typ_list) == len(self.types) - if isinstance(self.type_class3, Type3Class): - key = (self.type_class3, tuple(typ_list), ) - if key in self.context.type_class_instances_existing: - return None - else: - if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()): - return None + key = (self.type_class3, tuple(typ_list), ) + if key in self.context.type_class_instances_existing: + return None typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name typ_name_list = ' '.join(x.name for x in typ_list) diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index 01c1549..7f174d5 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -12,12 +12,14 @@ from .constraints import ( Context, LiteralFitsConstraint, MustImplementTypeClassConstraint, + SameFunctionArgumentConstraint, SameTypeArgumentConstraint, SameTypeConstraint, TupleMatchConstraint, ) from .functions import ( Constraint_TypeClassInstanceExists, + FunctionArgument, FunctionSignature, TypeVariable, TypeVariableApplication_Unary, @@ -111,6 +113,33 @@ def _expression_function_call( raise NotImplementedError(constraint) + func_var_map = { + x: PlaceholderForType([]) + for x in signature.args + if isinstance(x, FunctionArgument) + } + + # If some of the function arguments are functions, + # we need to deal with those separately. + for sig_arg in signature.args: + if not isinstance(sig_arg, FunctionArgument): + continue + + # Ensure that for all type variables in the function + # there are also type variables available + for func_arg in sig_arg.args: + if isinstance(func_arg, Type3): + continue + + type_var_map.setdefault(func_arg, PlaceholderForType([])) + + yield SameFunctionArgumentConstraint( + func_var_map[sig_arg], + sig_arg, + type_var_map, + comment=f'Ensure `{sig_arg.name}` matches in {signature}', + ) + # If some of the function arguments are type constructors, # we need to deal with those separately. # That is, given `foo :: t a -> a` we need to ensure @@ -120,6 +149,9 @@ def _expression_function_call( # Not a type variable at all continue + if isinstance(sig_arg, FunctionArgument): + continue + if sig_arg.application.constructor is None: # Not a type variable for a type constructor continue @@ -150,9 +182,20 @@ def _expression_function_call( yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment) continue + if isinstance(sig_part, FunctionArgument): + yield SameTypeConstraint(func_var_map[sig_part], arg_placeholders[arg_expr], comment=comment) + continue + raise NotImplementedError(sig_part) return +def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator: + yield SameTypeConstraint( + prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3), + phft, + comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})', + ) + def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator: if isinstance(inp, ourlang.Constant): yield from constant(ctx, inp, phft) @@ -171,6 +214,10 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) yield from expression_function_call(ctx, inp, phft) return + if isinstance(inp, ourlang.FunctionReference): + yield from expression_function_reference(ctx, inp, phft) + return + if isinstance(inp, ourlang.TupleInstantiation): r_type = [] for arg in inp.elements: @@ -209,19 +256,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}') return - if isinstance(inp, ourlang.Fold): - base_phft = PlaceholderForType([inp.base]) - iter_phft = PlaceholderForType([inp.iter]) - - yield from expression(ctx, inp.base, base_phft) - yield from expression(ctx, inp.iter, iter_phft) - - yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft, - comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b') - yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft]) - - return - raise NotImplementedError(expression, inp) def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: diff --git a/phasm/type3/functions.py b/phasm/type3/functions.py index 1b91948..a49f80b 100644 --- a/phasm/type3/functions.py +++ b/phasm/type3/functions.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Any, Hashable, Iterable, List, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Hashable, Iterable, List if TYPE_CHECKING: from .typeclasses import Type3Class @@ -155,15 +157,29 @@ class TypeVariableContext: def __repr__(self) -> str: return f'TypeVariableContext({self.constraints!r})' +class FunctionArgument: + __slots__ = ('args', 'name', ) + + args: list[Type3 | TypeVariable] + name: str + + def __init__(self, args: list[Type3 | TypeVariable]) -> None: + self.args = args + + self.name = '(' + ' -> '.join(x.name for x in args) + ')' + class FunctionSignature: __slots__ = ('context', 'args', ) context: TypeVariableContext - args: List[Union['Type3', TypeVariable]] + args: List[Type3 | TypeVariable | FunctionArgument] - def __init__(self, context: TypeVariableContext, args: Iterable[Union['Type3', TypeVariable]]) -> None: + def __init__(self, context: TypeVariableContext, args: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]) -> None: self.context = context.__copy__() - self.args = list(args) + self.args = list( + FunctionArgument(x) if isinstance(x, list) else x + for x in args + ) def __str__(self) -> str: return str(self.context) + ' -> '.join(x.name for x in self.args) diff --git a/phasm/type3/typeclasses.py b/phasm/type3/typeclasses.py index 83d87be..02c6ad0 100644 --- a/phasm/type3/typeclasses.py +++ b/phasm/type3/typeclasses.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, List, Mapping, Optional, Union +from typing import Dict, Iterable, List, Mapping, Optional from .functions import ( Constraint_TypeClassInstanceExists, @@ -42,8 +42,8 @@ class Type3Class: self, name: str, args: Type3ClassArgs, - methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]], - operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]], + methods: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]], + operators: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]], inherited_classes: Optional[List['Type3Class']] = None, additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None, ) -> None: @@ -71,19 +71,23 @@ class Type3Class: return self.name def _create_signature( - method_arg_list: Iterable[Type3 | TypeVariable], + method_arg_list: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]], type_class3: Type3Class, ) -> FunctionSignature: context = TypeVariableContext() if not isinstance(type_class3.args[0], TypeConstructorVariable): context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args)) - signature_args: list[Type3 | TypeVariable] = [] + signature_args: list[Type3 | TypeVariable | list[Type3 | TypeVariable]] = [] for method_arg in method_arg_list: if isinstance(method_arg, Type3): signature_args.append(method_arg) continue + if isinstance(method_arg, list): + signature_args.append(method_arg) + continue + if isinstance(method_arg, TypeVariable): type_constructor = method_arg.application.constructor if type_constructor is None: diff --git a/phasm/type3/types.py b/phasm/type3/types.py index 7a1b835..a949fb5 100644 --- a/phasm/type3/types.py +++ b/phasm/type3/types.py @@ -239,6 +239,10 @@ class TypeConstructor_Tuple(TypeConstructor_TypeStar): def make_name(self, key: Tuple[Type3, ...]) -> str: return '(' + ', '.join(x.name for x in key) + ', )' +class TypeConstructor_Function(TypeConstructor_TypeStar): + def make_name(self, key: Tuple[Type3, ...]) -> str: + return '(' + ' -> '.join(x.name for x in key) + ')' + class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]): """ Constructs struct types diff --git a/tests/integration/test_lang/test_builtins.py b/tests/integration/test_lang/test_builtins.py deleted file mode 100644 index 13400fd..0000000 --- a/tests/integration/test_lang/test_builtins.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -from ..helpers import Suite - - -@pytest.mark.integration_test -def test_foldl_1(): - code_py = """ -def u8_or(l: u8, r: u8) -> u8: - return l | r - -@exported -def testEntry(b: bytes) -> u8: - return foldl(u8_or, 128, b) -""" - suite = Suite(code_py) - - result = suite.run_code(b'') - assert 128 == result.returned_value - - result = suite.run_code(b'\x80') - assert 128 == result.returned_value - - result = suite.run_code(b'\x80\x40') - assert 192 == result.returned_value - - result = suite.run_code(b'\x80\x40\x20\x10') - assert 240 == result.returned_value - - result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01') - assert 255 == result.returned_value - -@pytest.mark.integration_test -def test_foldl_2(): - code_py = """ -def xor(l: u8, r: u8) -> u8: - return l ^ r - -@exported -def testEntry(a: bytes, b: bytes) -> u8: - return foldl(xor, 0, a) ^ foldl(xor, 0, b) -""" - suite = Suite(code_py) - - result = suite.run_code(b'\x55\x0F', b'\x33\x80') - assert 233 == result.returned_value - -@pytest.mark.integration_test -def test_foldl_3(): - code_py = """ -def xor(l: u32, r: u8) -> u32: - return l ^ extend(r) - -@exported -def testEntry(a: bytes) -> u32: - return foldl(xor, 0, a) -""" - suite = Suite(code_py) - - result = suite.run_code(b'\x55\x0F\x33\x80') - assert 233 == result.returned_value diff --git a/tests/integration/test_lang/test_foldable.py b/tests/integration/test_lang/test_foldable.py index c03d6e7..38e08ce 100644 --- a/tests/integration/test_lang/test_foldable.py +++ b/tests/integration/test_lang/test_foldable.py @@ -36,6 +36,80 @@ def testEntry(x: Foo[4]) -> Foo: with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'): Suite(code_py).run_code() +@pytest.mark.integration_test +def test_foldable_foldl_size(): + code_py = """ +def u8_or(l: u8, r: u8) -> u8: + return l | r + +@exported +def testEntry(b: bytes) -> u8: + return foldl(u8_or, 128, b) +""" + suite = Suite(code_py) + + result = suite.run_code(b'') + assert 128 == result.returned_value + + result = suite.run_code(b'\x80') + assert 128 == result.returned_value + + result = suite.run_code(b'\x80\x40') + assert 192 == result.returned_value + + result = suite.run_code(b'\x80\x40\x20\x10') + assert 240 == result.returned_value + + result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01') + assert 255 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('direction, exp_result', [ + ('foldl', -55, ), + ('foldr', -5, ), +]) +def test_foldable_foldl_foldr(direction, exp_result): + # See https://stackoverflow.com/a/13280185 + code_py = f""" +def i32_sub(l: i32, r: i32) -> i32: + return l - r + +@exported +def testEntry(b: i32[10]) -> i32: + return {direction}(i32_sub, 0, b) +""" + suite = Suite(code_py) + + result = suite.run_code(tuple(range(1, 10))) + assert exp_result == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('in_typ', ['i8', 'i8[3]']) +def test_foldable_argument_must_be_a_function(in_typ): + code_py = f""" +@exported +def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32: + return foldl(x, y, z) +""" + + r_in_typ = in_typ.replace('[', '\\[').replace(']', '\\]') + + with pytest.raises(Type3Exception, match=f'{r_in_typ} must be a function instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_foldable_argument_must_be_right_function(): + code_py = """ +def foo(l: i32, r: i64) -> i64: + return extend(l) + r + +@exported +def testEntry(i: i64, l: i64[3]) -> i64: + return foldr(foo, i, l) +""" + + with pytest.raises(Type3Exception, match=r'\(i64 -> i64 -> i64\) must be \(i32 -> i64 -> i64\) instead'): + Suite(code_py).run_code() @pytest.mark.integration_test def test_foldable_invalid_return_type():