diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 940ecfa..ba26778 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -83,6 +83,9 @@ def expression(inp: ourlang.Expression) -> str: if isinstance(inp.function, ourlang.StructConstructor): return f'{inp.function.struct_type3.name}({args})' + if isinstance(inp.function, ourlang.FunctionCall): + return f'{expression(inp.function)}({args})' + return f'{inp.function.name}({args})' if isinstance(inp, ourlang.FunctionReference): diff --git a/phasm/compiler.py b/phasm/compiler.py index 4eaf1b2..f046f73 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -330,6 +330,20 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) wgn.add_statement(f'call_indirect (param {params_str}) (result {result})') return + if isinstance(inp.function, ourlang.FunctionCall): + assert inp.function.type3 is not None, TYPE3_ASSERTION_ERROR + + params = [ + type3(x).to_wat() + for x in inp.function.type3.application.arguments + ] + + result = params.pop() + + expression(wgn, mod, inp.function) + wgn.add_statement(f'call_indirect (param {params_str}) (result {result})') + return + wgn.add_statement('call', '${}'.format(inp.function.name)) return diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 2bd3cac..db3cf8b 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -151,10 +151,10 @@ class FunctionCall(Expression): """ __slots__ = ('function', 'arguments', ) - function: Union['Function', 'FunctionParam', Type3ClassMethod] + function: Union['Function', 'FunctionParam', Type3ClassMethod, 'FunctionCall'] arguments: List[Expression] - def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod]) -> None: + def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod, 'FunctionCall']) -> None: super().__init__() self.function = function diff --git a/phasm/parser.py b/phasm/parser.py index e72946d..b3f9ab4 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -470,6 +470,18 @@ class OurVisitor: if node.keywords: _raise_static_error(node, 'Keyword calling not supported') # Yet? + if isinstance(node.func, ast.Call): + # e.g. foo(0)(1, 2, 3) + + result = FunctionCall( + self.visit_Module_FunctionDef_Call(module, function, our_locals, node.func) + ) + result.arguments.extend( + self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) + for arg_expr in node.args + ) + return result + if not isinstance(node.func, ast.Name): raise NotImplementedError(f'Calling methods that are not a name {node.func}') if not isinstance(node.func.ctx, ast.Load): diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index f28b88f..9b0d4fe 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -57,7 +57,9 @@ def expression_binary_op(ctx: Context, inp: ourlang.BinaryOp, phft: PlaceholderF ) def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator: - if isinstance(inp.function, ourlang.FunctionParam): + if isinstance(inp.function, (ourlang.FunctionParam, ourlang.FunctionCall, )): + assert inp.function.type3 is not None + assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function) signature = FunctionSignature( TypeVariableContext(), @@ -66,9 +68,15 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Plac else: signature = inp.function.signature + if isinstance(inp.function, ourlang.FunctionCall): + yield from expression_function_call(ctx, inp.function, other_phft) + func_name = inp.function.function.name + '()' + else: + func_name = inp.function.name + return _expression_function_call( ctx, - inp.function.name, + func_name, signature, inp.arguments, inp, diff --git a/tests/integration/test_lang/test_second_order_functions.py b/tests/integration/test_lang/test_second_order_functions.py index 395d804..1cf7823 100644 --- a/tests/integration/test_lang/test_second_order_functions.py +++ b/tests/integration/test_lang/test_second_order_functions.py @@ -77,6 +77,33 @@ def testEntry() -> i32: assert 42 == result.returned_value +@pytest.mark.integration_test +def test_sof_in_return(): + code_py = """ +def add(left: i32, right: i32) -> i32: + return left + right + +def sub(left: i32, right: i32) -> i32: + return left - right + +def get_func(i: i32) -> Callable[i32, i32]: + if i == 0: + return add + return sub + +@exported +def testEntry(i: i32) -> i32: + return get_func(i)(10, 5) +""" + + suite = Suite(code_py) + + result = suite.run_code(0) + assert 15 == result.returned_value + + result = suite.run_code(1) + assert 5 == result.returned_value + @pytest.mark.integration_test def test_sof_wrong_argument_type(): code_py = """