Compare commits

..

3 Commits

Author SHA1 Message Date
Johan B.W. de Vries
b285eb9d05 Removes the special casing for foldl
Had to implement both functions as arguments and type
place holders (variables) for type constructors.

Had to implement functions as a type as well.

Still have to figure out how to pass functions around.
2025-05-17 19:50:59 +02:00
Johan B.W. de Vries
a72bd60de2 Adds functions as passable values 2025-05-17 19:43:52 +02:00
Johan B.W. de Vries
99d2b22336 Moved the typeclasse tests. Fix typeclass name. 2025-05-17 18:42:27 +02:00
22 changed files with 252 additions and 31 deletions

View File

@ -17,6 +17,7 @@ from .type3.types import (
TypeApplication_Struct, TypeApplication_Struct,
TypeApplication_TypeInt, TypeApplication_TypeInt,
TypeApplication_TypeStar, TypeApplication_TypeStar,
TypeConstructor_Function,
TypeConstructor_StaticArray, TypeConstructor_StaticArray,
TypeConstructor_Tuple, TypeConstructor_Tuple,
) )
@ -313,9 +314,9 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
type_var_map[type_var] = arg_expr.type3 type_var_map[type_var] = arg_expr.type3
continue continue
if isinstance(type_var, FunctionArgument): if isinstance(type_var, FunctionArgument):
# Fixed type, not part of the lookup requirements # Fixed type, not part of the lookup requirements
continue continue
raise NotImplementedError(type_var, arg_expr.type3) raise NotImplementedError(type_var, arg_expr.type3)
@ -355,6 +356,21 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
raise NotImplementedError(str(inp.function), type_var_map) raise NotImplementedError(str(inp.function), type_var_map)
return return
if isinstance(inp.function, ourlang.FunctionParam):
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
params = [
type3(x).to_wat()
for x in inp.function.type3.application.arguments
]
result = params.pop()
params_str = ' '.join(params)
wgn.add_statement('local.get', '${}'.format(inp.function.name))
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
return
wgn.add_statement('call', '${}'.format(inp.function.name)) wgn.add_statement('call', '${}'.format(inp.function.name))
return return

View File

@ -151,10 +151,10 @@ class FunctionCall(Expression):
""" """
__slots__ = ('function', 'arguments', ) __slots__ = ('function', 'arguments', )
function: Union['Function', Type3ClassMethod] function: Union['Function', 'FunctionParam', Type3ClassMethod]
arguments: List[Expression] arguments: List[Expression]
def __init__(self, function: Union['Function', Type3ClassMethod]) -> None: def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod]) -> None:
super().__init__() super().__init__()
self.function = function self.function = function
@ -378,15 +378,15 @@ class Module:
""" """
A module is a file and consists of functions A module is a file and consists of functions
""" """
__slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'functions_table', 'operators', ) __slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'operators', 'functions_table', )
data: ModuleData data: ModuleData
types: dict[str, Type3] types: dict[str, Type3]
struct_definitions: Dict[str, StructDefinition] struct_definitions: Dict[str, StructDefinition]
constant_defs: Dict[str, ModuleConstantDef] constant_defs: Dict[str, ModuleConstantDef]
functions: Dict[str, Function] functions: Dict[str, Function]
functions_table: dict[Function, int]
operators: Dict[str, Type3ClassMethod] operators: Dict[str, Type3ClassMethod]
functions_table: dict[Function, int]
def __init__(self) -> None: def __init__(self) -> None:
self.data = ModuleData() self.data = ModuleData()

View File

@ -447,7 +447,8 @@ class OurVisitor:
return VariableReference(cdef) return VariableReference(cdef)
if node.id in module.functions: if node.id in module.functions:
return FunctionReference(module.functions[node.id]) fun = module.functions[node.id]
return FunctionReference(fun)
_raise_static_error(node, f'Undefined variable {node.id}') _raise_static_error(node, f'Undefined variable {node.id}')
@ -474,21 +475,18 @@ class OurVisitor:
if not isinstance(node.func.ctx, ast.Load): if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
func: Union[Function, Type3ClassMethod] func: Union[Function, FunctionParam, Type3ClassMethod]
if node.func.id in PRELUDE_METHODS: if node.func.id in PRELUDE_METHODS:
func = PRELUDE_METHODS[node.func.id] func = PRELUDE_METHODS[node.func.id]
elif node.func.id in our_locals:
func = our_locals[node.func.id]
else: else:
if node.func.id not in module.functions: if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function') _raise_static_error(node, 'Call to undefined function')
func = module.functions[node.func.id] func = module.functions[node.func.id]
exp_arg_count = len(func.signature.args) - 1
if exp_arg_count != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {exp_arg_count} arguments but {len(node.args)} are given')
result = FunctionCall(func) result = FunctionCall(func)
result.arguments.extend( result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
@ -616,6 +614,22 @@ class OurVisitor:
_raise_static_error(node, f'Unrecognized type {node.id}') _raise_static_error(node, f'Unrecognized type {node.id}')
if isinstance(node, ast.Subscript): if isinstance(node, ast.Subscript):
if isinstance(node.value, ast.Name) and node.value.id == 'Callable':
func_arg_types: list[ast.expr]
if isinstance(node.slice, ast.Name):
func_arg_types = [node.slice]
elif isinstance(node.slice, ast.Tuple):
func_arg_types = node.slice.elts
else:
_raise_static_error(node, 'Must subscript using a list of types')
# Function type
return prelude.function(*[
self.visit_type(module, e)
for e in func_arg_types
])
if isinstance(node.slice, ast.Slice): if isinstance(node.slice, ast.Slice):
_raise_static_error(node, 'Must subscript using an index') _raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.slice, ast.Constant): if not isinstance(node.slice, ast.Constant):

View File

@ -188,7 +188,9 @@ determined length, and each argument can be different.
""" """
def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None: def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None:
pass # ? instance_type_class(InternalPassAsPointer, typ) # Not really a pointer; but still a i32
# (It's actually a table lookup)
instance_type_class(InternalPassAsPointer, typ)
function = TypeConstructor_Function('function', on_create=fn_on_create) function = TypeConstructor_Function('function', on_create=fn_on_create)
""" """
@ -473,7 +475,7 @@ instance_type_class(IntNum, f64, methods={
'neg': stdtypes.f64_intnum_neg, 'neg': stdtypes.f64_intnum_neg,
}) })
Integral = Type3Class('Eq', (a, ), methods={ Integral = Type3Class('Integral', (a, ), methods={
}, operators={ }, operators={
'//': [a, a, a], '//': [a, a, a],
'%': [a, a, a], '%': [a, a, a],

View File

@ -159,7 +159,7 @@ class SameTypeConstraint(ConstraintBase):
return ( return (
' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))), ' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))),
{ {
't' + str(idx): typ.name if isinstance(typ, ourlang.Function) else typ 't' + str(idx): typ
for idx, typ in enumerate(self.type_list) for idx, typ in enumerate(self.type_list)
}, },
) )
@ -257,9 +257,6 @@ class SameFunctionArgumentConstraint(ConstraintBase):
if isinstance(tv, Type3) or tv in type_var_map if isinstance(tv, Type3) or tv in type_var_map
] ]
print('self.func_arg.args', self.func_arg.args)
print('exp_type_arg_list', exp_type_arg_list)
if len(exp_type_arg_list) != len(self.func_arg.args): if len(exp_type_arg_list) != len(self.func_arg.args):
return RequireTypeSubstitutes() return RequireTypeSubstitutes()

View File

@ -23,9 +23,10 @@ from .functions import (
FunctionSignature, FunctionSignature,
TypeVariable, TypeVariable,
TypeVariableApplication_Unary, TypeVariableApplication_Unary,
TypeVariableContext,
) )
from .placeholders import PlaceholderForType from .placeholders import PlaceholderForType
from .types import Type3, TypeApplication_Struct from .types import Type3, TypeApplication_Struct, TypeConstructor_Function
ConstraintGenerator = Generator[ConstraintBase, None, None] ConstraintGenerator = Generator[ConstraintBase, None, None]
@ -56,15 +57,31 @@ def expression_binary_op(ctx: Context, inp: ourlang.BinaryOp, phft: PlaceholderF
) )
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator: def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp.function, ourlang.FunctionParam):
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
signature = FunctionSignature(
TypeVariableContext(),
inp.function.type3.application.arguments,
)
else:
signature = inp.function.signature
return _expression_function_call( return _expression_function_call(
ctx, ctx,
inp.function.name, inp.function.name,
inp.function.signature, signature,
inp.arguments, inp.arguments,
inp, inp,
phft, phft,
) )
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator:
yield SameTypeConstraint(
prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3),
phft,
comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})',
)
def _expression_function_call( def _expression_function_call(
ctx: Context, ctx: Context,
func_name: str, func_name: str,
@ -189,13 +206,6 @@ def _expression_function_call(
raise NotImplementedError(sig_part) raise NotImplementedError(sig_part)
return return
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator:
yield SameTypeConstraint(
prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3),
phft,
comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})',
)
def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator: def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp, ourlang.Constant): if isinstance(inp, ourlang.Constant):
yield from constant(ctx, inp, phft) yield from constant(ctx, inp, phft)

View File

@ -241,7 +241,7 @@ class TypeConstructor_Tuple(TypeConstructor_TypeStar):
class TypeConstructor_Function(TypeConstructor_TypeStar): class TypeConstructor_Function(TypeConstructor_TypeStar):
def make_name(self, key: Tuple[Type3, ...]) -> str: def make_name(self, key: Tuple[Type3, ...]) -> str:
return '(' + ' -> '.join(x.name for x in key) + ')' return 'Callable[' + ', '.join(x.name for x in key) + ']'
class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]): class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]):
""" """

View File

@ -186,8 +186,8 @@ class Module(WatSerializable):
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.imports: List[Import] = [] self.imports: List[Import] = []
self.functions: List[Function] = []
self.table: dict[int, str] = {} self.table: dict[int, str] = {}
self.functions: List[Function] = []
self.memory = ModuleMemory() self.memory = ModuleMemory()
def to_wat(self) -> str: def to_wat(self) -> str:

View File

@ -0,0 +1,182 @@
import pytest
from phasm.type3.entry import Type3Exception
from ..helpers import Suite
@pytest.mark.integration_test
def test_sof_in_code_0_arg():
code_py = """
def thirteen() -> i32:
return 13
def action(applicable: Callable[i32]) -> i32:
return applicable()
@exported
def testEntry() -> i32:
return action(thirteen)
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_code_1_arg():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32, i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13)
"""
result = Suite(code_py).run_code()
assert 26 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_code_2_arg():
code_py = """
def add(left: i32, right: i32) -> i32:
return left + right
def action(applicable: Callable[i32, i32, i32], left: i32, right: i32) -> i32:
return applicable(left, right)
@exported
def testEntry() -> i32:
return action(add, 13, 14)
"""
result = Suite(code_py).run_code()
assert 27 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_code_3_arg():
code_py = """
def add(left: i32, mid: i32, right: i32) -> i32:
return left + mid + right
def action(applicable: Callable[i32, i32, i32, i32], left: i32, mid: i32, right: i32) -> i32:
return applicable(left, mid, right)
@exported
def testEntry() -> i32:
return action(add, 13, 14, 15)
"""
result = Suite(code_py).run_code()
assert 42 == result.returned_value
@pytest.mark.integration_test
def test_sof_wrong_argument_type():
code_py = """
def double(left: f32) -> f32:
return left * 2
def action(applicable: Callable[i32, i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13)
"""
with pytest.raises(Type3Exception, match=r'Callable\[f32, f32\] must be Callable\[i32, i32\] instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_wrong_return():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32, i32], left: i32) -> f32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13)
"""
with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.skip('FIXME: Probably have the remainder be the a function type')
def test_sof_wrong_not_enough_args_call():
code_py = """
def add(left: i32, right: i32) -> i32:
return left + right
def action(applicable: Callable[i32, i32, i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(add, 13)
"""
with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_wrong_not_enough_args_refere():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32, i32, i32], left: i32, right: i32) -> i32:
return applicable(left, right)
@exported
def testEntry() -> i32:
return action(double, 13, 14)
"""
with pytest.raises(Type3Exception, match=r'Callable\[i32, i32\] must be Callable\[i32, i32, i32\] instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.skip('FIXME: Probably have the remainder be the a function type')
def test_sof_wrong_too_many_args_call():
code_py = """
def thirteen() -> i32:
return 13
def action(applicable: Callable[i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(thirteen, 13)
"""
with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_wrong_too_many_args_refere():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32]) -> i32:
return applicable()
@exported
def testEntry() -> i32:
return action(double)
"""
with pytest.raises(Type3Exception, match=r'Callable\[i32, i32\] must be Callable\[i32\] instead'):
Suite(code_py).run_code()