Compare commits
1 Commits
master
...
allow-func
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ff532467a |
@ -83,6 +83,9 @@ def expression(inp: ourlang.Expression) -> str:
|
||||
if isinstance(inp.function, ourlang.StructConstructor):
|
||||
return f'{inp.function.struct_type3.name}({args})'
|
||||
|
||||
if isinstance(inp.function, ourlang.FunctionCall):
|
||||
return f'{expression(inp.function)}({args})'
|
||||
|
||||
return f'{inp.function.name}({args})'
|
||||
|
||||
if isinstance(inp, ourlang.FunctionReference):
|
||||
|
||||
@ -330,6 +330,20 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
|
||||
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
|
||||
return
|
||||
|
||||
if isinstance(inp.function, ourlang.FunctionCall):
|
||||
assert inp.function.type3 is not None, TYPE3_ASSERTION_ERROR
|
||||
|
||||
params = [
|
||||
type3(x).to_wat()
|
||||
for x in inp.function.type3.application.arguments
|
||||
]
|
||||
|
||||
result = params.pop()
|
||||
|
||||
expression(wgn, mod, inp.function)
|
||||
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
|
||||
return
|
||||
|
||||
wgn.add_statement('call', '${}'.format(inp.function.name))
|
||||
return
|
||||
|
||||
|
||||
@ -151,10 +151,10 @@ class FunctionCall(Expression):
|
||||
"""
|
||||
__slots__ = ('function', 'arguments', )
|
||||
|
||||
function: Union['Function', 'FunctionParam', Type3ClassMethod]
|
||||
function: Union['Function', 'FunctionParam', Type3ClassMethod, 'FunctionCall']
|
||||
arguments: List[Expression]
|
||||
|
||||
def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod]) -> None:
|
||||
def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod, 'FunctionCall']) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.function = function
|
||||
|
||||
@ -470,6 +470,18 @@ class OurVisitor:
|
||||
if node.keywords:
|
||||
_raise_static_error(node, 'Keyword calling not supported') # Yet?
|
||||
|
||||
if isinstance(node.func, ast.Call):
|
||||
# e.g. foo(0)(1, 2, 3)
|
||||
|
||||
result = FunctionCall(
|
||||
self.visit_Module_FunctionDef_Call(module, function, our_locals, node.func)
|
||||
)
|
||||
result.arguments.extend(
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
|
||||
for arg_expr in node.args
|
||||
)
|
||||
return result
|
||||
|
||||
if not isinstance(node.func, ast.Name):
|
||||
raise NotImplementedError(f'Calling methods that are not a name {node.func}')
|
||||
if not isinstance(node.func.ctx, ast.Load):
|
||||
|
||||
@ -57,7 +57,9 @@ def expression_binary_op(ctx: Context, inp: ourlang.BinaryOp, phft: PlaceholderF
|
||||
)
|
||||
|
||||
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator:
|
||||
if isinstance(inp.function, ourlang.FunctionParam):
|
||||
if isinstance(inp.function, (ourlang.FunctionParam, ourlang.FunctionCall, )):
|
||||
assert inp.function.type3 is not None
|
||||
|
||||
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
|
||||
signature = FunctionSignature(
|
||||
TypeVariableContext(),
|
||||
@ -66,9 +68,15 @@ def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: Plac
|
||||
else:
|
||||
signature = inp.function.signature
|
||||
|
||||
if isinstance(inp.function, ourlang.FunctionCall):
|
||||
yield from expression_function_call(ctx, inp.function, other_phft)
|
||||
func_name = inp.function.function.name + '()'
|
||||
else:
|
||||
func_name = inp.function.name
|
||||
|
||||
return _expression_function_call(
|
||||
ctx,
|
||||
inp.function.name,
|
||||
func_name,
|
||||
signature,
|
||||
inp.arguments,
|
||||
inp,
|
||||
|
||||
@ -77,6 +77,33 @@ def testEntry() -> i32:
|
||||
|
||||
assert 42 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_in_return():
|
||||
code_py = """
|
||||
def add(left: i32, right: i32) -> i32:
|
||||
return left + right
|
||||
|
||||
def sub(left: i32, right: i32) -> i32:
|
||||
return left - right
|
||||
|
||||
def get_func(i: i32) -> Callable[i32, i32]:
|
||||
if i == 0:
|
||||
return add
|
||||
return sub
|
||||
|
||||
@exported
|
||||
def testEntry(i: i32) -> i32:
|
||||
return get_func(i)(10, 5)
|
||||
"""
|
||||
|
||||
suite = Suite(code_py)
|
||||
|
||||
result = suite.run_code(0)
|
||||
assert 15 == result.returned_value
|
||||
|
||||
result = suite.run_code(1)
|
||||
assert 5 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_sof_wrong_argument_type():
|
||||
code_py = """
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user