Compare commits

...

1 Commits

Author SHA1 Message Date
Johan B.W. de Vries
2ff532467a Idea: Allow functions are return argument 2025-05-21 18:50:45 +02:00
6 changed files with 68 additions and 4 deletions

View File

@ -83,6 +83,9 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp.function, ourlang.StructConstructor): if isinstance(inp.function, ourlang.StructConstructor):
return f'{inp.function.struct_type3.name}({args})' return f'{inp.function.struct_type3.name}({args})'
if isinstance(inp.function, ourlang.FunctionCall):
return f'{expression(inp.function)}({args})'
return f'{inp.function.name}({args})' return f'{inp.function.name}({args})'
if isinstance(inp, ourlang.FunctionReference): if isinstance(inp, ourlang.FunctionReference):

View File

@ -330,6 +330,20 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})') wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
return return
if isinstance(inp.function, ourlang.FunctionCall):
assert inp.function.type3 is not None, TYPE3_ASSERTION_ERROR
params = [
type3(x).to_wat()
for x in inp.function.type3.application.arguments
]
result = params.pop()
expression(wgn, mod, inp.function)
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
return
wgn.add_statement('call', '${}'.format(inp.function.name)) wgn.add_statement('call', '${}'.format(inp.function.name))
return return

View File

@ -151,10 +151,10 @@ class FunctionCall(Expression):
""" """
__slots__ = ('function', 'arguments', ) __slots__ = ('function', 'arguments', )
function: Union['Function', 'FunctionParam', Type3ClassMethod] function: Union['Function', 'FunctionParam', Type3ClassMethod, 'FunctionCall']
arguments: List[Expression] arguments: List[Expression]
def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod]) -> None: def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod, 'FunctionCall']) -> None:
super().__init__() super().__init__()
self.function = function self.function = function

View File

@ -470,6 +470,18 @@ class OurVisitor:
if node.keywords: if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet? _raise_static_error(node, 'Keyword calling not supported') # Yet?
if isinstance(node.func, ast.Call):
# e.g. foo(0)(1, 2, 3)
result = FunctionCall(
self.visit_Module_FunctionDef_Call(module, function, our_locals, node.func)
)
result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
for arg_expr in node.args
)
return result
if not isinstance(node.func, ast.Name): if not isinstance(node.func, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {node.func}') raise NotImplementedError(f'Calling methods that are not a name {node.func}')
if not isinstance(node.func.ctx, ast.Load): if not isinstance(node.func.ctx, ast.Load):

View File

@ -57,7 +57,9 @@ def expression_binary_op(ctx: Context, inp: ourlang.BinaryOp, phft: PlaceholderF
) )
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator: def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp.function, ourlang.FunctionParam): if isinstance(inp.function, (ourlang.FunctionParam, ourlang.FunctionCall, )):
assert inp.function.type3 is not None
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function) assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
signature = FunctionSignature( signature = FunctionSignature(
TypeVariableContext(), TypeVariableContext(),
@ -66,9 +68,15 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Plac
else: else:
signature = inp.function.signature signature = inp.function.signature
if isinstance(inp.function, ourlang.FunctionCall):
yield from expression_function_call(ctx, inp.function, other_phft)
func_name = inp.function.function.name + '()'
else:
func_name = inp.function.name
return _expression_function_call( return _expression_function_call(
ctx, ctx,
inp.function.name, func_name,
signature, signature,
inp.arguments, inp.arguments,
inp, inp,

View File

@ -77,6 +77,33 @@ def testEntry() -> i32:
assert 42 == result.returned_value assert 42 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_return():
code_py = """
def add(left: i32, right: i32) -> i32:
return left + right
def sub(left: i32, right: i32) -> i32:
return left - right
def get_func(i: i32) -> Callable[i32, i32]:
if i == 0:
return add
return sub
@exported
def testEntry(i: i32) -> i32:
return get_func(i)(10, 5)
"""
suite = Suite(code_py)
result = suite.run_code(0)
assert 15 == result.returned_value
result = suite.run_code(1)
assert 5 == result.returned_value
@pytest.mark.integration_test @pytest.mark.integration_test
def test_sof_wrong_argument_type(): def test_sof_wrong_argument_type():
code_py = """ code_py = """