Idea: Allow functions are return argument

This commit is contained in:
Johan B.W. de Vries 2025-05-21 18:50:45 +02:00
parent b48260ccfa
commit 2ff532467a
6 changed files with 68 additions and 4 deletions

View File

@ -83,6 +83,9 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp.function, ourlang.StructConstructor):
return f'{inp.function.struct_type3.name}({args})'
if isinstance(inp.function, ourlang.FunctionCall):
return f'{expression(inp.function)}({args})'
return f'{inp.function.name}({args})'
if isinstance(inp, ourlang.FunctionReference):

View File

@ -330,6 +330,20 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
return
if isinstance(inp.function, ourlang.FunctionCall):
assert inp.function.type3 is not None, TYPE3_ASSERTION_ERROR
params = [
type3(x).to_wat()
for x in inp.function.type3.application.arguments
]
result = params.pop()
expression(wgn, mod, inp.function)
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
return
wgn.add_statement('call', '${}'.format(inp.function.name))
return

View File

@ -151,10 +151,10 @@ class FunctionCall(Expression):
"""
__slots__ = ('function', 'arguments', )
function: Union['Function', 'FunctionParam', Type3ClassMethod]
function: Union['Function', 'FunctionParam', Type3ClassMethod, 'FunctionCall']
arguments: List[Expression]
def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod]) -> None:
def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod, 'FunctionCall']) -> None:
super().__init__()
self.function = function

View File

@ -470,6 +470,18 @@ class OurVisitor:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
if isinstance(node.func, ast.Call):
# e.g. foo(0)(1, 2, 3)
result = FunctionCall(
self.visit_Module_FunctionDef_Call(module, function, our_locals, node.func)
)
result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
for arg_expr in node.args
)
return result
if not isinstance(node.func, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {node.func}')
if not isinstance(node.func.ctx, ast.Load):

View File

@ -57,7 +57,9 @@ def expression_binary_op(ctx: Context, inp: ourlang.BinaryOp, phft: PlaceholderF
)
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp.function, ourlang.FunctionParam):
if isinstance(inp.function, (ourlang.FunctionParam, ourlang.FunctionCall, )):
assert inp.function.type3 is not None
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
signature = FunctionSignature(
TypeVariableContext(),
@ -66,9 +68,15 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Plac
else:
signature = inp.function.signature
if isinstance(inp.function, ourlang.FunctionCall):
yield from expression_function_call(ctx, inp.function, other_phft)
func_name = inp.function.function.name + '()'
else:
func_name = inp.function.name
return _expression_function_call(
ctx,
inp.function.name,
func_name,
signature,
inp.arguments,
inp,

View File

@ -77,6 +77,33 @@ def testEntry() -> i32:
assert 42 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_return():
code_py = """
def add(left: i32, right: i32) -> i32:
return left + right
def sub(left: i32, right: i32) -> i32:
return left - right
def get_func(i: i32) -> Callable[i32, i32]:
if i == 0:
return add
return sub
@exported
def testEntry(i: i32) -> i32:
return get_func(i)(10, 5)
"""
suite = Suite(code_py)
result = suite.run_code(0)
assert 15 == result.returned_value
result = suite.run_code(1)
assert 5 == result.returned_value
@pytest.mark.integration_test
def test_sof_wrong_argument_type():
code_py = """