From a72bd60de2876c53e9948807e2ffc7ce8eeb7bb7 Mon Sep 17 00:00:00 2001 From: "Johan B.W. de Vries" Date: Sat, 17 May 2025 19:43:52 +0200 Subject: [PATCH] Adds functions as passable values --- phasm/codestyle.py | 3 + phasm/compiler.py | 105 ++++++---- phasm/ourlang.py | 20 +- phasm/parser.py | 30 ++- phasm/prelude/__init__.py | 13 ++ phasm/type3/constraints.py | 74 ++++++- phasm/type3/constraintsgenerator.py | 57 +++++- phasm/type3/functions.py | 24 ++- phasm/type3/types.py | 4 + phasm/wasm.py | 5 +- .../test_lang/test_second_order_functions.py | 182 ++++++++++++++++++ .../test_typeclasses/test_foldable.py | 2 +- 12 files changed, 464 insertions(+), 55 deletions(-) create mode 100644 tests/integration/test_lang/test_second_order_functions.py diff --git a/phasm/codestyle.py b/phasm/codestyle.py index fc40868..3b7e6f4 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -85,6 +85,9 @@ def expression(inp: ourlang.Expression) -> str: return f'{inp.function.name}({args})' + if isinstance(inp, ourlang.FunctionReference): + return str(inp.function.name) + if isinstance(inp, ourlang.TupleInstantiation): args = ', '.join( expression(arg) diff --git a/phasm/compiler.py b/phasm/compiler.py index d098499..0d4293c 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -17,6 +17,7 @@ from .type3.types import ( TypeApplication_Struct, TypeApplication_TypeInt, TypeApplication_TypeStar, + TypeConstructor_Function, TypeConstructor_StaticArray, TypeConstructor_Tuple, ) @@ -100,7 +101,7 @@ def type3(inp: Type3) -> wasm.WasmType: raise NotImplementedError(type3, inp) -def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> None: +def tuple_instantiation(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.TupleInstantiation) -> None: """ Compile: Instantiation (allocation) of a tuple """ @@ -150,7 +151,7 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> wgn.add_statement('nop', comment='PRE') wgn.local.get(tmp_var) - expression(wgn, element) + expression(wgn, mod, element) wgn.add_statement(f'{mtyp}.store', 'offset=' + str(offset)) wgn.add_statement('nop', comment='POST') @@ -160,29 +161,29 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> wgn.local.get(tmp_var) def expression_subscript_bytes( - attrs: tuple[WasmGenerator, ourlang.Subscript], + attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], ) -> None: - wgn, inp = attrs + wgn, mod, inp = attrs - expression(wgn, inp.varref) - expression(wgn, inp.index) + expression(wgn, mod, inp.varref) + expression(wgn, mod, inp.index) wgn.call(stdlib_types.__subscript_bytes__) def expression_subscript_static_array( - attrs: tuple[WasmGenerator, ourlang.Subscript], + attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], args: tuple[Type3, IntType3], ) -> None: - wgn, inp = attrs + wgn, mod, inp = attrs el_type, el_len = args # OPTIMIZE: If index is a constant, we can use offset instead of multiply # and we don't need to do the out of bounds check - expression(wgn, inp.varref) + expression(wgn, mod, inp.varref) tmp_var = wgn.temp_var_i32('index') - expression(wgn, inp.index) + expression(wgn, mod, inp.index) wgn.local.tee(tmp_var) # Out of bounds check based on el_len.value @@ -201,10 +202,10 @@ def expression_subscript_static_array( wgn.add_statement(f'{mtyp}.load') def expression_subscript_tuple( - attrs: tuple[WasmGenerator, ourlang.Subscript], + attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], args: tuple[Type3, ...], ) -> None: - wgn, inp = attrs + wgn, mod, inp = attrs assert isinstance(inp.index, ourlang.ConstantPrimitive) assert isinstance(inp.index.value, int) @@ -217,7 +218,7 @@ def expression_subscript_tuple( el_type = args[inp.index.value] assert el_type is not None, TYPE3_ASSERTION_ERROR - expression(wgn, inp.varref) + expression(wgn, mod, inp.varref) if (prelude.InternalPassAsPointer, (el_type, )) in prelude.PRELUDE_TYPE_CLASS_INSTANCES_EXISTING: mtyp = 'i32' @@ -226,12 +227,12 @@ def expression_subscript_tuple( wgn.add_statement(f'{mtyp}.load', f'offset={offset}') -SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Subscript], None]() +SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], None]() SUBSCRIPT_ROUTER.add_n(prelude.bytes_, expression_subscript_bytes) SUBSCRIPT_ROUTER.add(prelude.static_array, expression_subscript_static_array) SUBSCRIPT_ROUTER.add(prelude.tuple_, expression_subscript_tuple) -def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: +def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) -> None: """ Compile: Any expression """ @@ -291,14 +292,14 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.i32.const(address) return - expression(wgn, inp.variable.constant) + expression(wgn, mod, inp.variable.constant) return raise NotImplementedError(expression, inp.variable) if isinstance(inp, ourlang.BinaryOp): - expression(wgn, inp.left) - expression(wgn, inp.right) + expression(wgn, mod, inp.left) + expression(wgn, mod, inp.right) type_var_map: dict[TypeVariable, Type3] = {} @@ -321,7 +322,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: if isinstance(inp, ourlang.FunctionCall): for arg in inp.arguments: - expression(wgn, arg) + expression(wgn, mod, arg) if isinstance(inp.function, Type3ClassMethod): # FIXME: Duplicate code with BinaryOp @@ -347,18 +348,42 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: raise NotImplementedError(str(inp.function), type_var_map) return + if isinstance(inp.function, ourlang.FunctionParam): + assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function) + + params = [ + type3(x).to_wat() + for x in inp.function.type3.application.arguments + ] + + result = params.pop() + + params_str = ' '.join(params) + wgn.add_statement('local.get', '${}'.format(inp.function.name)) + wgn.add_statement(f'call_indirect (param {params_str}) (result {result})') + return + wgn.add_statement('call', '${}'.format(inp.function.name)) return + if isinstance(inp, ourlang.FunctionReference): + idx = mod.functions_table.get(inp.function) + if idx is None: + idx = len(mod.functions_table) + mod.functions_table[inp.function] = idx + + wgn.add_statement('i32.const', str(idx), comment=inp.function.name) + return + if isinstance(inp, ourlang.TupleInstantiation): - tuple_instantiation(wgn, inp) + tuple_instantiation(wgn, mod, inp) return if isinstance(inp, ourlang.Subscript): assert inp.varref.type3 is not None, TYPE3_ASSERTION_ERROR # Type checker guarantees we don't get routing errors - SUBSCRIPT_ROUTER((wgn, inp, ), inp.varref.type3) + SUBSCRIPT_ROUTER((wgn, mod, inp, ), inp.varref.type3) return if isinstance(inp, ourlang.AccessStructMember): @@ -370,19 +395,19 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: mtyp = LOAD_STORE_TYPE_MAP[member_type.name] - expression(wgn, inp.varref) + expression(wgn, mod, inp.varref) wgn.add_statement(f'{mtyp}.load', 'offset=' + str(calculate_member_offset( inp.struct_type3.name, inp.struct_type3.application.arguments, inp.member ))) return if isinstance(inp, ourlang.Fold): - expression_fold(wgn, inp) + expression_fold(wgn, mod, inp) return raise NotImplementedError(expression, inp) -def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: +def expression_fold(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Fold) -> None: """ Compile: Fold expression """ @@ -399,11 +424,11 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: len_var = wgn.temp_var_i32('fold_i32_len') wgn.add_statement('nop', comment='acu = base') - expression(wgn, inp.base) + expression(wgn, mod, inp.base) wgn.local.set(acu_var) wgn.add_statement('nop', comment='adr = adr(iter)') - expression(wgn, inp.iter) + expression(wgn, mod, inp.iter) wgn.local.set(adr_var) wgn.add_statement('nop', comment='len = len(iter)') @@ -460,21 +485,21 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: # return acu wgn.local.get(acu_var) -def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None: +def statement_return(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementReturn) -> None: """ Compile: Return statement """ - expression(wgn, inp.value) + expression(wgn, mod, inp.value) wgn.return_() -def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None: +def statement_if(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementIf) -> None: """ Compile: If statement """ - expression(wgn, inp.test) + expression(wgn, mod, inp.test) with wgn.if_(): for stat in inp.statements: - statement(wgn, stat) + statement(wgn, mod, stat) if inp.else_statements: raise NotImplementedError @@ -482,16 +507,16 @@ def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None: # for stat in inp.else_statements: # statement(wgn, stat) -def statement(wgn: WasmGenerator, inp: ourlang.Statement) -> None: +def statement(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Statement) -> None: """ Compile: any statement """ if isinstance(inp, ourlang.StatementReturn): - statement_return(wgn, inp) + statement_return(wgn, mod, inp) return if isinstance(inp, ourlang.StatementIf): - statement_if(wgn, inp) + statement_if(wgn, mod, inp) return if isinstance(inp, ourlang.StatementPass): @@ -522,7 +547,7 @@ def import_(inp: ourlang.Function) -> wasm.Import: type3(inp.returns_type3) ) -def function(inp: ourlang.Function) -> wasm.Function: +def function(mod: ourlang.Module, inp: ourlang.Function) -> wasm.Function: """ Compile: function """ @@ -534,7 +559,7 @@ def function(inp: ourlang.Function) -> wasm.Function: _generate_struct_constructor(wgn, inp) else: for stat in inp.statements: - statement(wgn, stat) + statement(wgn, mod, stat) return wasm.Function( inp.name, @@ -739,11 +764,17 @@ def module(inp: ourlang.Module) -> wasm.Module: stdlib_types.__u8_rotl__, stdlib_types.__u8_rotr__, ] + [ - function(x) + function(inp, x) for x in inp.functions.values() if not x.imported ] + # Do this after rendering the functions since that's what populates the tables + result.table = { + v: k.name + for k, v in inp.functions_table.items() + } + return result def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: diff --git a/phasm/ourlang.py b/phasm/ourlang.py index df97b23..d25097d 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -152,15 +152,27 @@ class FunctionCall(Expression): """ __slots__ = ('function', 'arguments', ) - function: Union['Function', Type3ClassMethod] + function: Union['Function', 'FunctionParam', Type3ClassMethod] arguments: List[Expression] - def __init__(self, function: Union['Function', Type3ClassMethod]) -> None: + def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod]) -> None: super().__init__() self.function = function self.arguments = [] +class FunctionReference(Expression): + """ + An function reference expression within a statement + """ + __slots__ = ('function', ) + + function: 'Function' + + def __init__(self, function: 'Function') -> None: + super().__init__() + self.function = function + class TupleInstantiation(Expression): """ Instantiation a tuple @@ -397,7 +409,7 @@ class Module: """ A module is a file and consists of functions """ - __slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'operators', ) + __slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'operators', 'functions_table', ) data: ModuleData types: dict[str, Type3] @@ -405,6 +417,7 @@ class Module: constant_defs: Dict[str, ModuleConstantDef] functions: Dict[str, Function] operators: Dict[str, Type3ClassMethod] + functions_table: dict[Function, int] def __init__(self) -> None: self.data = ModuleData() @@ -413,3 +426,4 @@ class Module: self.constant_defs = {} self.functions = {} self.operators = {} + self.functions_table = {} diff --git a/phasm/parser.py b/phasm/parser.py index 9944dc2..1fc1139 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -18,6 +18,7 @@ from .ourlang import ( Function, FunctionCall, FunctionParam, + FunctionReference, Module, ModuleConstantDef, ModuleDataBlock, @@ -446,6 +447,10 @@ class OurVisitor: cdef = module.constant_defs[node.id] return VariableReference(cdef) + if node.id in module.functions: + fun = module.functions[node.id] + return FunctionReference(fun) + _raise_static_error(node, f'Undefined variable {node.id}') if isinstance(node, ast.Tuple): @@ -471,7 +476,7 @@ class OurVisitor: if not isinstance(node.func.ctx, ast.Load): _raise_static_error(node, 'Must be load context') - func: Union[Function, Type3ClassMethod] + func: Union[Function, FunctionParam, Type3ClassMethod] if node.func.id in PRELUDE_METHODS: func = PRELUDE_METHODS[node.func.id] @@ -497,17 +502,14 @@ class OurVisitor: self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]), ) + elif node.func.id in our_locals: + func = our_locals[node.func.id] else: if node.func.id not in module.functions: _raise_static_error(node, 'Call to undefined function') func = module.functions[node.func.id] - exp_arg_count = len(func.signature.args) - 1 - - if exp_arg_count != len(node.args): - _raise_static_error(node, f'Function {node.func.id} requires {exp_arg_count} arguments but {len(node.args)} are given') - result = FunctionCall(func) result.arguments.extend( self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) @@ -635,6 +637,22 @@ class OurVisitor: _raise_static_error(node, f'Unrecognized type {node.id}') if isinstance(node, ast.Subscript): + if isinstance(node.value, ast.Name) and node.value.id == 'Callable': + func_arg_types: list[ast.expr] + + if isinstance(node.slice, ast.Name): + func_arg_types = [node.slice] + elif isinstance(node.slice, ast.Tuple): + func_arg_types = node.slice.elts + else: + _raise_static_error(node, 'Must subscript using a list of types') + + # Function type + return prelude.function(*[ + self.visit_type(module, e) + for e in func_arg_types + ]) + if isinstance(node.slice, ast.Slice): _raise_static_error(node, 'Must subscript using an index') if not isinstance(node.slice, ast.Constant): diff --git a/phasm/prelude/__init__.py b/phasm/prelude/__init__.py index b91c31f..77eb142 100644 --- a/phasm/prelude/__init__.py +++ b/phasm/prelude/__init__.py @@ -20,6 +20,7 @@ from ..type3.types import ( Type3, TypeApplication_Nullary, TypeConstructor_Base, + TypeConstructor_Function, TypeConstructor_StaticArray, TypeConstructor_Struct, TypeConstructor_Tuple, @@ -186,6 +187,18 @@ It should be applied with zero or more arguments. It has a compile time determined length, and each argument can be different. """ +def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None: + # Not really a pointer; but still a i32 + # (It's actually a table lookup) + instance_type_class(InternalPassAsPointer, typ) + +function = TypeConstructor_Function('function', on_create=fn_on_create) +""" +This is a function. + +It should be applied with one or more arguments. The last argument is the 'return' type. +""" + def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None: instance_type_class(InternalPassAsPointer, typ) diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py index 4a9541b..0eeb6fb 100644 --- a/phasm/type3/constraints.py +++ b/phasm/type3/constraints.py @@ -6,6 +6,7 @@ These need to be resolved before the program can be compiled. from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from .. import ourlang, prelude +from .functions import FunctionArgument, TypeVariable from .placeholders import PlaceholderForType, Type3OrPlaceholder from .routers import NoRouteForTypeException, TypeApplicationRouter from .typeclasses import Type3Class @@ -204,10 +205,81 @@ class SameTypeArgumentConstraint(ConstraintBase): if tc_typ.application.arguments[0] == arg_typ: return None - return Error(f'{tc_typ.application.arguments[0]:s} must be {arg_typ:s} instead') + return [SameTypeConstraint( + tc_typ.application.arguments[0], + self.arg_var, + comment=self.comment, + )] raise NotImplementedError(tc_typ, arg_typ) + def human_readable(self) -> HumanReadableRet: + return ( + '{tc_var}` == {arg_var}', + { + 'tc_var': self.tc_var if self.tc_var.resolve_as is None else self.tc_var, + 'arg_var': self.arg_var if self.arg_var.resolve_as is None else self.arg_var, + }, + ) + +class SameFunctionArgumentConstraint(ConstraintBase): + __slots__ = ('type3', 'func_arg', 'type_var_map', ) + + type3: PlaceholderForType + func_arg: FunctionArgument + type_var_map: dict[TypeVariable, PlaceholderForType] + + def __init__(self, type3: PlaceholderForType, func_arg: FunctionArgument, type_var_map: dict[TypeVariable, PlaceholderForType], *, comment: str) -> None: + super().__init__(comment=comment) + + self.type3 = type3 + self.func_arg = func_arg + self.type_var_map = type_var_map + + def check(self) -> CheckResult: + if self.type3.resolve_as is None: + return RequireTypeSubstitutes() + + typ = self.type3.resolve_as + + if isinstance(typ.application, TypeApplication_Nullary): + return Error(f'{typ:s} must be a function instead') + + if not isinstance(typ.application, TypeApplication_TypeStar): + return Error(f'{typ:s} must be a function instead') + + type_var_map = { + x: y.resolve_as + for x, y in self.type_var_map.items() + if y.resolve_as is not None + } + + exp_type_arg_list = [ + tv if isinstance(tv, Type3) else type_var_map[tv] + for tv in self.func_arg.args + if isinstance(tv, Type3) or tv in type_var_map + ] + + if len(exp_type_arg_list) != len(self.func_arg.args): + return RequireTypeSubstitutes() + + return [ + SameTypeConstraint( + typ, + prelude.function(*exp_type_arg_list), + comment=self.comment, + ) + ] + + def human_readable(self) -> HumanReadableRet: + return ( + '{type3} == {func_arg}', + { + 'type3': self.type3, + 'func_arg': self.func_arg.name, + }, + ) + class TupleMatchConstraint(ConstraintBase): __slots__ = ('exp_type', 'args', ) diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index 01c1549..1a4771a 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -12,18 +12,21 @@ from .constraints import ( Context, LiteralFitsConstraint, MustImplementTypeClassConstraint, + SameFunctionArgumentConstraint, SameTypeArgumentConstraint, SameTypeConstraint, TupleMatchConstraint, ) from .functions import ( Constraint_TypeClassInstanceExists, + FunctionArgument, FunctionSignature, TypeVariable, TypeVariableApplication_Unary, + TypeVariableContext, ) from .placeholders import PlaceholderForType -from .types import Type3, TypeApplication_Struct +from .types import Type3, TypeApplication_Struct, TypeConstructor_Function ConstraintGenerator = Generator[ConstraintBase, None, None] @@ -54,15 +57,31 @@ def expression_binary_op(ctx: Context, inp: ourlang.BinaryOp, phft: PlaceholderF ) def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator: + if isinstance(inp.function, ourlang.FunctionParam): + assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function) + signature = FunctionSignature( + TypeVariableContext(), + inp.function.type3.application.arguments, + ) + else: + signature = inp.function.signature + return _expression_function_call( ctx, inp.function.name, - inp.function.signature, + signature, inp.arguments, inp, phft, ) +def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator: + yield SameTypeConstraint( + prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3), + phft, + comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})', + ) + def _expression_function_call( ctx: Context, func_name: str, @@ -111,6 +130,33 @@ def _expression_function_call( raise NotImplementedError(constraint) + func_var_map = { + x: PlaceholderForType([]) + for x in signature.args + if isinstance(x, FunctionArgument) + } + + # If some of the function arguments are functions, + # we need to deal with those separately. + for sig_arg in signature.args: + if not isinstance(sig_arg, FunctionArgument): + continue + + # Ensure that for all type variables in the function + # there are also type variables available + for func_arg in sig_arg.args: + if isinstance(func_arg, Type3): + continue + + type_var_map.setdefault(func_arg, PlaceholderForType([])) + + yield SameFunctionArgumentConstraint( + func_var_map[sig_arg], + sig_arg, + type_var_map, + comment=f'Ensure `{sig_arg.name}` matches in {signature}', + ) + # If some of the function arguments are type constructors, # we need to deal with those separately. # That is, given `foo :: t a -> a` we need to ensure @@ -120,6 +166,9 @@ def _expression_function_call( # Not a type variable at all continue + if isinstance(sig_arg, FunctionArgument): + continue + if sig_arg.application.constructor is None: # Not a type variable for a type constructor continue @@ -171,6 +220,10 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) yield from expression_function_call(ctx, inp, phft) return + if isinstance(inp, ourlang.FunctionReference): + yield from expression_function_reference(ctx, inp, phft) + return + if isinstance(inp, ourlang.TupleInstantiation): r_type = [] for arg in inp.elements: diff --git a/phasm/type3/functions.py b/phasm/type3/functions.py index 1b91948..a49f80b 100644 --- a/phasm/type3/functions.py +++ b/phasm/type3/functions.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Any, Hashable, Iterable, List, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Hashable, Iterable, List if TYPE_CHECKING: from .typeclasses import Type3Class @@ -155,15 +157,29 @@ class TypeVariableContext: def __repr__(self) -> str: return f'TypeVariableContext({self.constraints!r})' +class FunctionArgument: + __slots__ = ('args', 'name', ) + + args: list[Type3 | TypeVariable] + name: str + + def __init__(self, args: list[Type3 | TypeVariable]) -> None: + self.args = args + + self.name = '(' + ' -> '.join(x.name for x in args) + ')' + class FunctionSignature: __slots__ = ('context', 'args', ) context: TypeVariableContext - args: List[Union['Type3', TypeVariable]] + args: List[Type3 | TypeVariable | FunctionArgument] - def __init__(self, context: TypeVariableContext, args: Iterable[Union['Type3', TypeVariable]]) -> None: + def __init__(self, context: TypeVariableContext, args: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]) -> None: self.context = context.__copy__() - self.args = list(args) + self.args = list( + FunctionArgument(x) if isinstance(x, list) else x + for x in args + ) def __str__(self) -> str: return str(self.context) + ' -> '.join(x.name for x in self.args) diff --git a/phasm/type3/types.py b/phasm/type3/types.py index 7a1b835..489fd01 100644 --- a/phasm/type3/types.py +++ b/phasm/type3/types.py @@ -239,6 +239,10 @@ class TypeConstructor_Tuple(TypeConstructor_TypeStar): def make_name(self, key: Tuple[Type3, ...]) -> str: return '(' + ', '.join(x.name for x in key) + ', )' +class TypeConstructor_Function(TypeConstructor_TypeStar): + def make_name(self, key: Tuple[Type3, ...]) -> str: + return 'Callable[' + ', '.join(x.name for x in key) + ']' + class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]): """ Constructs struct types diff --git a/phasm/wasm.py b/phasm/wasm.py index c473e63..2750ee3 100644 --- a/phasm/wasm.py +++ b/phasm/wasm.py @@ -186,6 +186,7 @@ class Module(WatSerializable): """ def __init__(self) -> None: self.imports: List[Import] = [] + self.table: dict[int, str] = {} self.functions: List[Function] = [] self.memory = ModuleMemory() @@ -193,8 +194,10 @@ class Module(WatSerializable): """ Generates the text version """ - return '(module\n {}\n {}\n {})\n'.format( + return '(module\n {}\n {}\n {}\n {}\n {})\n'.format( '\n '.join(x.to_wat() for x in self.imports), + f'(table {len(self.table)} funcref)', + '\n '.join(f'(elem (i32.const {k}) ${v})' for k, v in self.table.items()), self.memory.to_wat(), '\n '.join(x.to_wat() for x in self.functions), ) diff --git a/tests/integration/test_lang/test_second_order_functions.py b/tests/integration/test_lang/test_second_order_functions.py new file mode 100644 index 0000000..395d804 --- /dev/null +++ b/tests/integration/test_lang/test_second_order_functions.py @@ -0,0 +1,182 @@ +import pytest + +from phasm.type3.entry import Type3Exception + +from ..helpers import Suite + + +@pytest.mark.integration_test +def test_sof_in_code_0_arg(): + code_py = """ +def thirteen() -> i32: + return 13 + +def action(applicable: Callable[i32]) -> i32: + return applicable() + +@exported +def testEntry() -> i32: + return action(thirteen) +""" + + result = Suite(code_py).run_code() + + assert 13 == result.returned_value + +@pytest.mark.integration_test +def test_sof_in_code_1_arg(): + code_py = """ +def double(left: i32) -> i32: + return left * 2 + +def action(applicable: Callable[i32, i32], left: i32) -> i32: + return applicable(left) + +@exported +def testEntry() -> i32: + return action(double, 13) +""" + + result = Suite(code_py).run_code() + + assert 26 == result.returned_value + +@pytest.mark.integration_test +def test_sof_in_code_2_arg(): + code_py = """ +def add(left: i32, right: i32) -> i32: + return left + right + +def action(applicable: Callable[i32, i32, i32], left: i32, right: i32) -> i32: + return applicable(left, right) + +@exported +def testEntry() -> i32: + return action(add, 13, 14) +""" + + result = Suite(code_py).run_code() + + assert 27 == result.returned_value + +@pytest.mark.integration_test +def test_sof_in_code_3_arg(): + code_py = """ +def add(left: i32, mid: i32, right: i32) -> i32: + return left + mid + right + +def action(applicable: Callable[i32, i32, i32, i32], left: i32, mid: i32, right: i32) -> i32: + return applicable(left, mid, right) + +@exported +def testEntry() -> i32: + return action(add, 13, 14, 15) +""" + + result = Suite(code_py).run_code() + + assert 42 == result.returned_value + +@pytest.mark.integration_test +def test_sof_wrong_argument_type(): + code_py = """ +def double(left: f32) -> f32: + return left * 2 + +def action(applicable: Callable[i32, i32], left: i32) -> i32: + return applicable(left) + +@exported +def testEntry() -> i32: + return action(double, 13) +""" + + with pytest.raises(Type3Exception, match=r'Callable\[f32, f32\] must be Callable\[i32, i32\] instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_sof_wrong_return(): + code_py = """ +def double(left: i32) -> i32: + return left * 2 + +def action(applicable: Callable[i32, i32], left: i32) -> f32: + return applicable(left) + +@exported +def testEntry() -> i32: + return action(double, 13) +""" + + with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.skip('FIXME: Probably have the remainder be the a function type') +def test_sof_wrong_not_enough_args_call(): + code_py = """ +def add(left: i32, right: i32) -> i32: + return left + right + +def action(applicable: Callable[i32, i32, i32], left: i32) -> i32: + return applicable(left) + +@exported +def testEntry() -> i32: + return action(add, 13) +""" + + with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_sof_wrong_not_enough_args_refere(): + code_py = """ +def double(left: i32) -> i32: + return left * 2 + +def action(applicable: Callable[i32, i32, i32], left: i32, right: i32) -> i32: + return applicable(left, right) + +@exported +def testEntry() -> i32: + return action(double, 13, 14) +""" + + with pytest.raises(Type3Exception, match=r'Callable\[i32, i32\] must be Callable\[i32, i32, i32\] instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +@pytest.mark.skip('FIXME: Probably have the remainder be the a function type') +def test_sof_wrong_too_many_args_call(): + code_py = """ +def thirteen() -> i32: + return 13 + +def action(applicable: Callable[i32], left: i32) -> i32: + return applicable(left) + +@exported +def testEntry() -> i32: + return action(thirteen, 13) +""" + + with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_sof_wrong_too_many_args_refere(): + code_py = """ +def double(left: i32) -> i32: + return left * 2 + +def action(applicable: Callable[i32]) -> i32: + return applicable() + +@exported +def testEntry() -> i32: + return action(double) +""" + + with pytest.raises(Type3Exception, match=r'Callable\[i32, i32\] must be Callable\[i32\] instead'): + Suite(code_py).run_code() diff --git a/tests/integration/test_typeclasses/test_foldable.py b/tests/integration/test_typeclasses/test_foldable.py index c03d6e7..34320b0 100644 --- a/tests/integration/test_typeclasses/test_foldable.py +++ b/tests/integration/test_typeclasses/test_foldable.py @@ -45,7 +45,7 @@ def testEntry(x: i32[5]) -> f64: return sum(x) """ - with pytest.raises(Type3Exception, match='i32 must be f64 instead'): + with pytest.raises(Type3Exception, match='f64 must be i32 instead'): Suite(code_py).run_code((4, 5, 6, 7, 8, )) @pytest.mark.integration_test