Compare commits

..

3 Commits

Author SHA1 Message Date
Johan B.W. de Vries
b285eb9d05 Removes the special casing for foldl
Had to implement both functions as arguments and type
place holders (variables) for type constructors.

Had to implement functions as a type as well.

Still have to figure out how to pass functions around.
2025-05-17 19:50:59 +02:00
Johan B.W. de Vries
a72bd60de2 Adds functions as passable values 2025-05-17 19:43:52 +02:00
Johan B.W. de Vries
99d2b22336 Moved the typeclasse tests. Fix typeclass name. 2025-05-17 18:42:27 +02:00
22 changed files with 252 additions and 31 deletions

View File

@ -17,6 +17,7 @@ from .type3.types import (
TypeApplication_Struct,
TypeApplication_TypeInt,
TypeApplication_TypeStar,
TypeConstructor_Function,
TypeConstructor_StaticArray,
TypeConstructor_Tuple,
)
@ -313,9 +314,9 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
type_var_map[type_var] = arg_expr.type3
continue
if isinstance(type_var, FunctionArgument):
# Fixed type, not part of the lookup requirements
continue
if isinstance(type_var, FunctionArgument):
# Fixed type, not part of the lookup requirements
continue
raise NotImplementedError(type_var, arg_expr.type3)
@ -355,6 +356,21 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
raise NotImplementedError(str(inp.function), type_var_map)
return
if isinstance(inp.function, ourlang.FunctionParam):
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
params = [
type3(x).to_wat()
for x in inp.function.type3.application.arguments
]
result = params.pop()
params_str = ' '.join(params)
wgn.add_statement('local.get', '${}'.format(inp.function.name))
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
return
wgn.add_statement('call', '${}'.format(inp.function.name))
return

View File

@ -151,10 +151,10 @@ class FunctionCall(Expression):
"""
__slots__ = ('function', 'arguments', )
function: Union['Function', Type3ClassMethod]
function: Union['Function', 'FunctionParam', Type3ClassMethod]
arguments: List[Expression]
def __init__(self, function: Union['Function', Type3ClassMethod]) -> None:
def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod]) -> None:
super().__init__()
self.function = function
@ -378,15 +378,15 @@ class Module:
"""
A module is a file and consists of functions
"""
__slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'functions_table', 'operators', )
__slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'operators', 'functions_table', )
data: ModuleData
types: dict[str, Type3]
struct_definitions: Dict[str, StructDefinition]
constant_defs: Dict[str, ModuleConstantDef]
functions: Dict[str, Function]
functions_table: dict[Function, int]
operators: Dict[str, Type3ClassMethod]
functions_table: dict[Function, int]
def __init__(self) -> None:
self.data = ModuleData()

View File

@ -447,7 +447,8 @@ class OurVisitor:
return VariableReference(cdef)
if node.id in module.functions:
return FunctionReference(module.functions[node.id])
fun = module.functions[node.id]
return FunctionReference(fun)
_raise_static_error(node, f'Undefined variable {node.id}')
@ -474,21 +475,18 @@ class OurVisitor:
if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context')
func: Union[Function, Type3ClassMethod]
func: Union[Function, FunctionParam, Type3ClassMethod]
if node.func.id in PRELUDE_METHODS:
func = PRELUDE_METHODS[node.func.id]
elif node.func.id in our_locals:
func = our_locals[node.func.id]
else:
if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function')
func = module.functions[node.func.id]
exp_arg_count = len(func.signature.args) - 1
if exp_arg_count != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {exp_arg_count} arguments but {len(node.args)} are given')
result = FunctionCall(func)
result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
@ -616,6 +614,22 @@ class OurVisitor:
_raise_static_error(node, f'Unrecognized type {node.id}')
if isinstance(node, ast.Subscript):
if isinstance(node.value, ast.Name) and node.value.id == 'Callable':
func_arg_types: list[ast.expr]
if isinstance(node.slice, ast.Name):
func_arg_types = [node.slice]
elif isinstance(node.slice, ast.Tuple):
func_arg_types = node.slice.elts
else:
_raise_static_error(node, 'Must subscript using a list of types')
# Function type
return prelude.function(*[
self.visit_type(module, e)
for e in func_arg_types
])
if isinstance(node.slice, ast.Slice):
_raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.slice, ast.Constant):

View File

@ -188,7 +188,9 @@ determined length, and each argument can be different.
"""
def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None:
pass # ? instance_type_class(InternalPassAsPointer, typ)
# Not really a pointer; but still a i32
# (It's actually a table lookup)
instance_type_class(InternalPassAsPointer, typ)
function = TypeConstructor_Function('function', on_create=fn_on_create)
"""
@ -473,7 +475,7 @@ instance_type_class(IntNum, f64, methods={
'neg': stdtypes.f64_intnum_neg,
})
Integral = Type3Class('Eq', (a, ), methods={
Integral = Type3Class('Integral', (a, ), methods={
}, operators={
'//': [a, a, a],
'%': [a, a, a],

View File

@ -159,7 +159,7 @@ class SameTypeConstraint(ConstraintBase):
return (
' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))),
{
't' + str(idx): typ.name if isinstance(typ, ourlang.Function) else typ
't' + str(idx): typ
for idx, typ in enumerate(self.type_list)
},
)
@ -257,9 +257,6 @@ class SameFunctionArgumentConstraint(ConstraintBase):
if isinstance(tv, Type3) or tv in type_var_map
]
print('self.func_arg.args', self.func_arg.args)
print('exp_type_arg_list', exp_type_arg_list)
if len(exp_type_arg_list) != len(self.func_arg.args):
return RequireTypeSubstitutes()

View File

@ -23,9 +23,10 @@ from .functions import (
FunctionSignature,
TypeVariable,
TypeVariableApplication_Unary,
TypeVariableContext,
)
from .placeholders import PlaceholderForType
from .types import Type3, TypeApplication_Struct
from .types import Type3, TypeApplication_Struct, TypeConstructor_Function
ConstraintGenerator = Generator[ConstraintBase, None, None]
@ -56,15 +57,31 @@ def expression_binary_op(ctx: Context, inp: ourlang.BinaryOp, phft: PlaceholderF
)
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp.function, ourlang.FunctionParam):
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
signature = FunctionSignature(
TypeVariableContext(),
inp.function.type3.application.arguments,
)
else:
signature = inp.function.signature
return _expression_function_call(
ctx,
inp.function.name,
inp.function.signature,
signature,
inp.arguments,
inp,
phft,
)
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator:
yield SameTypeConstraint(
prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3),
phft,
comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})',
)
def _expression_function_call(
ctx: Context,
func_name: str,
@ -189,13 +206,6 @@ def _expression_function_call(
raise NotImplementedError(sig_part)
return
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator:
yield SameTypeConstraint(
prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3),
phft,
comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})',
)
def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp, ourlang.Constant):
yield from constant(ctx, inp, phft)

View File

@ -241,7 +241,7 @@ class TypeConstructor_Tuple(TypeConstructor_TypeStar):
class TypeConstructor_Function(TypeConstructor_TypeStar):
def make_name(self, key: Tuple[Type3, ...]) -> str:
return '(' + ' -> '.join(x.name for x in key) + ')'
return 'Callable[' + ', '.join(x.name for x in key) + ']'
class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]):
"""

View File

@ -186,8 +186,8 @@ class Module(WatSerializable):
"""
def __init__(self) -> None:
self.imports: List[Import] = []
self.functions: List[Function] = []
self.table: dict[int, str] = {}
self.functions: List[Function] = []
self.memory = ModuleMemory()
def to_wat(self) -> str:

View File

@ -0,0 +1,182 @@
import pytest
from phasm.type3.entry import Type3Exception
from ..helpers import Suite
@pytest.mark.integration_test
def test_sof_in_code_0_arg():
code_py = """
def thirteen() -> i32:
return 13
def action(applicable: Callable[i32]) -> i32:
return applicable()
@exported
def testEntry() -> i32:
return action(thirteen)
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_code_1_arg():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32, i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13)
"""
result = Suite(code_py).run_code()
assert 26 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_code_2_arg():
code_py = """
def add(left: i32, right: i32) -> i32:
return left + right
def action(applicable: Callable[i32, i32, i32], left: i32, right: i32) -> i32:
return applicable(left, right)
@exported
def testEntry() -> i32:
return action(add, 13, 14)
"""
result = Suite(code_py).run_code()
assert 27 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_code_3_arg():
code_py = """
def add(left: i32, mid: i32, right: i32) -> i32:
return left + mid + right
def action(applicable: Callable[i32, i32, i32, i32], left: i32, mid: i32, right: i32) -> i32:
return applicable(left, mid, right)
@exported
def testEntry() -> i32:
return action(add, 13, 14, 15)
"""
result = Suite(code_py).run_code()
assert 42 == result.returned_value
@pytest.mark.integration_test
def test_sof_wrong_argument_type():
code_py = """
def double(left: f32) -> f32:
return left * 2
def action(applicable: Callable[i32, i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13)
"""
with pytest.raises(Type3Exception, match=r'Callable\[f32, f32\] must be Callable\[i32, i32\] instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_wrong_return():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32, i32], left: i32) -> f32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13)
"""
with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.skip('FIXME: Probably have the remainder be the a function type')
def test_sof_wrong_not_enough_args_call():
code_py = """
def add(left: i32, right: i32) -> i32:
return left + right
def action(applicable: Callable[i32, i32, i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(add, 13)
"""
with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_wrong_not_enough_args_refere():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32, i32, i32], left: i32, right: i32) -> i32:
return applicable(left, right)
@exported
def testEntry() -> i32:
return action(double, 13, 14)
"""
with pytest.raises(Type3Exception, match=r'Callable\[i32, i32\] must be Callable\[i32, i32, i32\] instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.skip('FIXME: Probably have the remainder be the a function type')
def test_sof_wrong_too_many_args_call():
code_py = """
def thirteen() -> i32:
return 13
def action(applicable: Callable[i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(thirteen, 13)
"""
with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_wrong_too_many_args_refere():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32]) -> i32:
return applicable()
@exported
def testEntry() -> i32:
return action(double)
"""
with pytest.raises(Type3Exception, match=r'Callable\[i32, i32\] must be Callable\[i32\] instead'):
Suite(code_py).run_code()