Adds functions as passable values

This commit is contained in:
Johan B.W. de Vries 2025-05-17 19:43:52 +02:00
parent 99d2b22336
commit a72bd60de2
12 changed files with 464 additions and 55 deletions

View File

@ -85,6 +85,9 @@ def expression(inp: ourlang.Expression) -> str:
return f'{inp.function.name}({args})' return f'{inp.function.name}({args})'
if isinstance(inp, ourlang.FunctionReference):
return str(inp.function.name)
if isinstance(inp, ourlang.TupleInstantiation): if isinstance(inp, ourlang.TupleInstantiation):
args = ', '.join( args = ', '.join(
expression(arg) expression(arg)

View File

@ -17,6 +17,7 @@ from .type3.types import (
TypeApplication_Struct, TypeApplication_Struct,
TypeApplication_TypeInt, TypeApplication_TypeInt,
TypeApplication_TypeStar, TypeApplication_TypeStar,
TypeConstructor_Function,
TypeConstructor_StaticArray, TypeConstructor_StaticArray,
TypeConstructor_Tuple, TypeConstructor_Tuple,
) )
@ -100,7 +101,7 @@ def type3(inp: Type3) -> wasm.WasmType:
raise NotImplementedError(type3, inp) raise NotImplementedError(type3, inp)
def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> None: def tuple_instantiation(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.TupleInstantiation) -> None:
""" """
Compile: Instantiation (allocation) of a tuple Compile: Instantiation (allocation) of a tuple
""" """
@ -150,7 +151,7 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) ->
wgn.add_statement('nop', comment='PRE') wgn.add_statement('nop', comment='PRE')
wgn.local.get(tmp_var) wgn.local.get(tmp_var)
expression(wgn, element) expression(wgn, mod, element)
wgn.add_statement(f'{mtyp}.store', 'offset=' + str(offset)) wgn.add_statement(f'{mtyp}.store', 'offset=' + str(offset))
wgn.add_statement('nop', comment='POST') wgn.add_statement('nop', comment='POST')
@ -160,29 +161,29 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) ->
wgn.local.get(tmp_var) wgn.local.get(tmp_var)
def expression_subscript_bytes( def expression_subscript_bytes(
attrs: tuple[WasmGenerator, ourlang.Subscript], attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript],
) -> None: ) -> None:
wgn, inp = attrs wgn, mod, inp = attrs
expression(wgn, inp.varref) expression(wgn, mod, inp.varref)
expression(wgn, inp.index) expression(wgn, mod, inp.index)
wgn.call(stdlib_types.__subscript_bytes__) wgn.call(stdlib_types.__subscript_bytes__)
def expression_subscript_static_array( def expression_subscript_static_array(
attrs: tuple[WasmGenerator, ourlang.Subscript], attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript],
args: tuple[Type3, IntType3], args: tuple[Type3, IntType3],
) -> None: ) -> None:
wgn, inp = attrs wgn, mod, inp = attrs
el_type, el_len = args el_type, el_len = args
# OPTIMIZE: If index is a constant, we can use offset instead of multiply # OPTIMIZE: If index is a constant, we can use offset instead of multiply
# and we don't need to do the out of bounds check # and we don't need to do the out of bounds check
expression(wgn, inp.varref) expression(wgn, mod, inp.varref)
tmp_var = wgn.temp_var_i32('index') tmp_var = wgn.temp_var_i32('index')
expression(wgn, inp.index) expression(wgn, mod, inp.index)
wgn.local.tee(tmp_var) wgn.local.tee(tmp_var)
# Out of bounds check based on el_len.value # Out of bounds check based on el_len.value
@ -201,10 +202,10 @@ def expression_subscript_static_array(
wgn.add_statement(f'{mtyp}.load') wgn.add_statement(f'{mtyp}.load')
def expression_subscript_tuple( def expression_subscript_tuple(
attrs: tuple[WasmGenerator, ourlang.Subscript], attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript],
args: tuple[Type3, ...], args: tuple[Type3, ...],
) -> None: ) -> None:
wgn, inp = attrs wgn, mod, inp = attrs
assert isinstance(inp.index, ourlang.ConstantPrimitive) assert isinstance(inp.index, ourlang.ConstantPrimitive)
assert isinstance(inp.index.value, int) assert isinstance(inp.index.value, int)
@ -217,7 +218,7 @@ def expression_subscript_tuple(
el_type = args[inp.index.value] el_type = args[inp.index.value]
assert el_type is not None, TYPE3_ASSERTION_ERROR assert el_type is not None, TYPE3_ASSERTION_ERROR
expression(wgn, inp.varref) expression(wgn, mod, inp.varref)
if (prelude.InternalPassAsPointer, (el_type, )) in prelude.PRELUDE_TYPE_CLASS_INSTANCES_EXISTING: if (prelude.InternalPassAsPointer, (el_type, )) in prelude.PRELUDE_TYPE_CLASS_INSTANCES_EXISTING:
mtyp = 'i32' mtyp = 'i32'
@ -226,12 +227,12 @@ def expression_subscript_tuple(
wgn.add_statement(f'{mtyp}.load', f'offset={offset}') wgn.add_statement(f'{mtyp}.load', f'offset={offset}')
SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Subscript], None]() SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], None]()
SUBSCRIPT_ROUTER.add_n(prelude.bytes_, expression_subscript_bytes) SUBSCRIPT_ROUTER.add_n(prelude.bytes_, expression_subscript_bytes)
SUBSCRIPT_ROUTER.add(prelude.static_array, expression_subscript_static_array) SUBSCRIPT_ROUTER.add(prelude.static_array, expression_subscript_static_array)
SUBSCRIPT_ROUTER.add(prelude.tuple_, expression_subscript_tuple) SUBSCRIPT_ROUTER.add(prelude.tuple_, expression_subscript_tuple)
def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) -> None:
""" """
Compile: Any expression Compile: Any expression
""" """
@ -291,14 +292,14 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
wgn.i32.const(address) wgn.i32.const(address)
return return
expression(wgn, inp.variable.constant) expression(wgn, mod, inp.variable.constant)
return return
raise NotImplementedError(expression, inp.variable) raise NotImplementedError(expression, inp.variable)
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
expression(wgn, inp.left) expression(wgn, mod, inp.left)
expression(wgn, inp.right) expression(wgn, mod, inp.right)
type_var_map: dict[TypeVariable, Type3] = {} type_var_map: dict[TypeVariable, Type3] = {}
@ -321,7 +322,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if isinstance(inp, ourlang.FunctionCall): if isinstance(inp, ourlang.FunctionCall):
for arg in inp.arguments: for arg in inp.arguments:
expression(wgn, arg) expression(wgn, mod, arg)
if isinstance(inp.function, Type3ClassMethod): if isinstance(inp.function, Type3ClassMethod):
# FIXME: Duplicate code with BinaryOp # FIXME: Duplicate code with BinaryOp
@ -347,18 +348,42 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
raise NotImplementedError(str(inp.function), type_var_map) raise NotImplementedError(str(inp.function), type_var_map)
return return
if isinstance(inp.function, ourlang.FunctionParam):
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
params = [
type3(x).to_wat()
for x in inp.function.type3.application.arguments
]
result = params.pop()
params_str = ' '.join(params)
wgn.add_statement('local.get', '${}'.format(inp.function.name))
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
return
wgn.add_statement('call', '${}'.format(inp.function.name)) wgn.add_statement('call', '${}'.format(inp.function.name))
return return
if isinstance(inp, ourlang.FunctionReference):
idx = mod.functions_table.get(inp.function)
if idx is None:
idx = len(mod.functions_table)
mod.functions_table[inp.function] = idx
wgn.add_statement('i32.const', str(idx), comment=inp.function.name)
return
if isinstance(inp, ourlang.TupleInstantiation): if isinstance(inp, ourlang.TupleInstantiation):
tuple_instantiation(wgn, inp) tuple_instantiation(wgn, mod, inp)
return return
if isinstance(inp, ourlang.Subscript): if isinstance(inp, ourlang.Subscript):
assert inp.varref.type3 is not None, TYPE3_ASSERTION_ERROR assert inp.varref.type3 is not None, TYPE3_ASSERTION_ERROR
# Type checker guarantees we don't get routing errors # Type checker guarantees we don't get routing errors
SUBSCRIPT_ROUTER((wgn, inp, ), inp.varref.type3) SUBSCRIPT_ROUTER((wgn, mod, inp, ), inp.varref.type3)
return return
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
@ -370,19 +395,19 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
mtyp = LOAD_STORE_TYPE_MAP[member_type.name] mtyp = LOAD_STORE_TYPE_MAP[member_type.name]
expression(wgn, inp.varref) expression(wgn, mod, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(calculate_member_offset( wgn.add_statement(f'{mtyp}.load', 'offset=' + str(calculate_member_offset(
inp.struct_type3.name, inp.struct_type3.application.arguments, inp.member inp.struct_type3.name, inp.struct_type3.application.arguments, inp.member
))) )))
return return
if isinstance(inp, ourlang.Fold): if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp) expression_fold(wgn, mod, inp)
return return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: def expression_fold(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Fold) -> None:
""" """
Compile: Fold expression Compile: Fold expression
""" """
@ -399,11 +424,11 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
len_var = wgn.temp_var_i32('fold_i32_len') len_var = wgn.temp_var_i32('fold_i32_len')
wgn.add_statement('nop', comment='acu = base') wgn.add_statement('nop', comment='acu = base')
expression(wgn, inp.base) expression(wgn, mod, inp.base)
wgn.local.set(acu_var) wgn.local.set(acu_var)
wgn.add_statement('nop', comment='adr = adr(iter)') wgn.add_statement('nop', comment='adr = adr(iter)')
expression(wgn, inp.iter) expression(wgn, mod, inp.iter)
wgn.local.set(adr_var) wgn.local.set(adr_var)
wgn.add_statement('nop', comment='len = len(iter)') wgn.add_statement('nop', comment='len = len(iter)')
@ -460,21 +485,21 @@ def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
# return acu # return acu
wgn.local.get(acu_var) wgn.local.get(acu_var)
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None: def statement_return(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementReturn) -> None:
""" """
Compile: Return statement Compile: Return statement
""" """
expression(wgn, inp.value) expression(wgn, mod, inp.value)
wgn.return_() wgn.return_()
def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None: def statement_if(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementIf) -> None:
""" """
Compile: If statement Compile: If statement
""" """
expression(wgn, inp.test) expression(wgn, mod, inp.test)
with wgn.if_(): with wgn.if_():
for stat in inp.statements: for stat in inp.statements:
statement(wgn, stat) statement(wgn, mod, stat)
if inp.else_statements: if inp.else_statements:
raise NotImplementedError raise NotImplementedError
@ -482,16 +507,16 @@ def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None:
# for stat in inp.else_statements: # for stat in inp.else_statements:
# statement(wgn, stat) # statement(wgn, stat)
def statement(wgn: WasmGenerator, inp: ourlang.Statement) -> None: def statement(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Statement) -> None:
""" """
Compile: any statement Compile: any statement
""" """
if isinstance(inp, ourlang.StatementReturn): if isinstance(inp, ourlang.StatementReturn):
statement_return(wgn, inp) statement_return(wgn, mod, inp)
return return
if isinstance(inp, ourlang.StatementIf): if isinstance(inp, ourlang.StatementIf):
statement_if(wgn, inp) statement_if(wgn, mod, inp)
return return
if isinstance(inp, ourlang.StatementPass): if isinstance(inp, ourlang.StatementPass):
@ -522,7 +547,7 @@ def import_(inp: ourlang.Function) -> wasm.Import:
type3(inp.returns_type3) type3(inp.returns_type3)
) )
def function(inp: ourlang.Function) -> wasm.Function: def function(mod: ourlang.Module, inp: ourlang.Function) -> wasm.Function:
""" """
Compile: function Compile: function
""" """
@ -534,7 +559,7 @@ def function(inp: ourlang.Function) -> wasm.Function:
_generate_struct_constructor(wgn, inp) _generate_struct_constructor(wgn, inp)
else: else:
for stat in inp.statements: for stat in inp.statements:
statement(wgn, stat) statement(wgn, mod, stat)
return wasm.Function( return wasm.Function(
inp.name, inp.name,
@ -739,11 +764,17 @@ def module(inp: ourlang.Module) -> wasm.Module:
stdlib_types.__u8_rotl__, stdlib_types.__u8_rotl__,
stdlib_types.__u8_rotr__, stdlib_types.__u8_rotr__,
] + [ ] + [
function(x) function(inp, x)
for x in inp.functions.values() for x in inp.functions.values()
if not x.imported if not x.imported
] ]
# Do this after rendering the functions since that's what populates the tables
result.table = {
v: k.name
for k, v in inp.functions_table.items()
}
return result return result
def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None:

View File

@ -152,15 +152,27 @@ class FunctionCall(Expression):
""" """
__slots__ = ('function', 'arguments', ) __slots__ = ('function', 'arguments', )
function: Union['Function', Type3ClassMethod] function: Union['Function', 'FunctionParam', Type3ClassMethod]
arguments: List[Expression] arguments: List[Expression]
def __init__(self, function: Union['Function', Type3ClassMethod]) -> None: def __init__(self, function: Union['Function', 'FunctionParam', Type3ClassMethod]) -> None:
super().__init__() super().__init__()
self.function = function self.function = function
self.arguments = [] self.arguments = []
class FunctionReference(Expression):
"""
An function reference expression within a statement
"""
__slots__ = ('function', )
function: 'Function'
def __init__(self, function: 'Function') -> None:
super().__init__()
self.function = function
class TupleInstantiation(Expression): class TupleInstantiation(Expression):
""" """
Instantiation a tuple Instantiation a tuple
@ -397,7 +409,7 @@ class Module:
""" """
A module is a file and consists of functions A module is a file and consists of functions
""" """
__slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'operators', ) __slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'operators', 'functions_table', )
data: ModuleData data: ModuleData
types: dict[str, Type3] types: dict[str, Type3]
@ -405,6 +417,7 @@ class Module:
constant_defs: Dict[str, ModuleConstantDef] constant_defs: Dict[str, ModuleConstantDef]
functions: Dict[str, Function] functions: Dict[str, Function]
operators: Dict[str, Type3ClassMethod] operators: Dict[str, Type3ClassMethod]
functions_table: dict[Function, int]
def __init__(self) -> None: def __init__(self) -> None:
self.data = ModuleData() self.data = ModuleData()
@ -413,3 +426,4 @@ class Module:
self.constant_defs = {} self.constant_defs = {}
self.functions = {} self.functions = {}
self.operators = {} self.operators = {}
self.functions_table = {}

View File

@ -18,6 +18,7 @@ from .ourlang import (
Function, Function,
FunctionCall, FunctionCall,
FunctionParam, FunctionParam,
FunctionReference,
Module, Module,
ModuleConstantDef, ModuleConstantDef,
ModuleDataBlock, ModuleDataBlock,
@ -446,6 +447,10 @@ class OurVisitor:
cdef = module.constant_defs[node.id] cdef = module.constant_defs[node.id]
return VariableReference(cdef) return VariableReference(cdef)
if node.id in module.functions:
fun = module.functions[node.id]
return FunctionReference(fun)
_raise_static_error(node, f'Undefined variable {node.id}') _raise_static_error(node, f'Undefined variable {node.id}')
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):
@ -471,7 +476,7 @@ class OurVisitor:
if not isinstance(node.func.ctx, ast.Load): if not isinstance(node.func.ctx, ast.Load):
_raise_static_error(node, 'Must be load context') _raise_static_error(node, 'Must be load context')
func: Union[Function, Type3ClassMethod] func: Union[Function, FunctionParam, Type3ClassMethod]
if node.func.id in PRELUDE_METHODS: if node.func.id in PRELUDE_METHODS:
func = PRELUDE_METHODS[node.func.id] func = PRELUDE_METHODS[node.func.id]
@ -497,17 +502,14 @@ class OurVisitor:
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]), self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
) )
elif node.func.id in our_locals:
func = our_locals[node.func.id]
else: else:
if node.func.id not in module.functions: if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function') _raise_static_error(node, 'Call to undefined function')
func = module.functions[node.func.id] func = module.functions[node.func.id]
exp_arg_count = len(func.signature.args) - 1
if exp_arg_count != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires {exp_arg_count} arguments but {len(node.args)} are given')
result = FunctionCall(func) result = FunctionCall(func)
result.arguments.extend( result.arguments.extend(
self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr) self.visit_Module_FunctionDef_expr(module, function, our_locals, arg_expr)
@ -635,6 +637,22 @@ class OurVisitor:
_raise_static_error(node, f'Unrecognized type {node.id}') _raise_static_error(node, f'Unrecognized type {node.id}')
if isinstance(node, ast.Subscript): if isinstance(node, ast.Subscript):
if isinstance(node.value, ast.Name) and node.value.id == 'Callable':
func_arg_types: list[ast.expr]
if isinstance(node.slice, ast.Name):
func_arg_types = [node.slice]
elif isinstance(node.slice, ast.Tuple):
func_arg_types = node.slice.elts
else:
_raise_static_error(node, 'Must subscript using a list of types')
# Function type
return prelude.function(*[
self.visit_type(module, e)
for e in func_arg_types
])
if isinstance(node.slice, ast.Slice): if isinstance(node.slice, ast.Slice):
_raise_static_error(node, 'Must subscript using an index') _raise_static_error(node, 'Must subscript using an index')
if not isinstance(node.slice, ast.Constant): if not isinstance(node.slice, ast.Constant):

View File

@ -20,6 +20,7 @@ from ..type3.types import (
Type3, Type3,
TypeApplication_Nullary, TypeApplication_Nullary,
TypeConstructor_Base, TypeConstructor_Base,
TypeConstructor_Function,
TypeConstructor_StaticArray, TypeConstructor_StaticArray,
TypeConstructor_Struct, TypeConstructor_Struct,
TypeConstructor_Tuple, TypeConstructor_Tuple,
@ -186,6 +187,18 @@ It should be applied with zero or more arguments. It has a compile time
determined length, and each argument can be different. determined length, and each argument can be different.
""" """
def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None:
# Not really a pointer; but still a i32
# (It's actually a table lookup)
instance_type_class(InternalPassAsPointer, typ)
function = TypeConstructor_Function('function', on_create=fn_on_create)
"""
This is a function.
It should be applied with one or more arguments. The last argument is the 'return' type.
"""
def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None: def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None:
instance_type_class(InternalPassAsPointer, typ) instance_type_class(InternalPassAsPointer, typ)

View File

@ -6,6 +6,7 @@ These need to be resolved before the program can be compiled.
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from .. import ourlang, prelude from .. import ourlang, prelude
from .functions import FunctionArgument, TypeVariable
from .placeholders import PlaceholderForType, Type3OrPlaceholder from .placeholders import PlaceholderForType, Type3OrPlaceholder
from .routers import NoRouteForTypeException, TypeApplicationRouter from .routers import NoRouteForTypeException, TypeApplicationRouter
from .typeclasses import Type3Class from .typeclasses import Type3Class
@ -204,10 +205,81 @@ class SameTypeArgumentConstraint(ConstraintBase):
if tc_typ.application.arguments[0] == arg_typ: if tc_typ.application.arguments[0] == arg_typ:
return None return None
return Error(f'{tc_typ.application.arguments[0]:s} must be {arg_typ:s} instead') return [SameTypeConstraint(
tc_typ.application.arguments[0],
self.arg_var,
comment=self.comment,
)]
raise NotImplementedError(tc_typ, arg_typ) raise NotImplementedError(tc_typ, arg_typ)
def human_readable(self) -> HumanReadableRet:
return (
'{tc_var}` == {arg_var}',
{
'tc_var': self.tc_var if self.tc_var.resolve_as is None else self.tc_var,
'arg_var': self.arg_var if self.arg_var.resolve_as is None else self.arg_var,
},
)
class SameFunctionArgumentConstraint(ConstraintBase):
__slots__ = ('type3', 'func_arg', 'type_var_map', )
type3: PlaceholderForType
func_arg: FunctionArgument
type_var_map: dict[TypeVariable, PlaceholderForType]
def __init__(self, type3: PlaceholderForType, func_arg: FunctionArgument, type_var_map: dict[TypeVariable, PlaceholderForType], *, comment: str) -> None:
super().__init__(comment=comment)
self.type3 = type3
self.func_arg = func_arg
self.type_var_map = type_var_map
def check(self) -> CheckResult:
if self.type3.resolve_as is None:
return RequireTypeSubstitutes()
typ = self.type3.resolve_as
if isinstance(typ.application, TypeApplication_Nullary):
return Error(f'{typ:s} must be a function instead')
if not isinstance(typ.application, TypeApplication_TypeStar):
return Error(f'{typ:s} must be a function instead')
type_var_map = {
x: y.resolve_as
for x, y in self.type_var_map.items()
if y.resolve_as is not None
}
exp_type_arg_list = [
tv if isinstance(tv, Type3) else type_var_map[tv]
for tv in self.func_arg.args
if isinstance(tv, Type3) or tv in type_var_map
]
if len(exp_type_arg_list) != len(self.func_arg.args):
return RequireTypeSubstitutes()
return [
SameTypeConstraint(
typ,
prelude.function(*exp_type_arg_list),
comment=self.comment,
)
]
def human_readable(self) -> HumanReadableRet:
return (
'{type3} == {func_arg}',
{
'type3': self.type3,
'func_arg': self.func_arg.name,
},
)
class TupleMatchConstraint(ConstraintBase): class TupleMatchConstraint(ConstraintBase):
__slots__ = ('exp_type', 'args', ) __slots__ = ('exp_type', 'args', )

View File

@ -12,18 +12,21 @@ from .constraints import (
Context, Context,
LiteralFitsConstraint, LiteralFitsConstraint,
MustImplementTypeClassConstraint, MustImplementTypeClassConstraint,
SameFunctionArgumentConstraint,
SameTypeArgumentConstraint, SameTypeArgumentConstraint,
SameTypeConstraint, SameTypeConstraint,
TupleMatchConstraint, TupleMatchConstraint,
) )
from .functions import ( from .functions import (
Constraint_TypeClassInstanceExists, Constraint_TypeClassInstanceExists,
FunctionArgument,
FunctionSignature, FunctionSignature,
TypeVariable, TypeVariable,
TypeVariableApplication_Unary, TypeVariableApplication_Unary,
TypeVariableContext,
) )
from .placeholders import PlaceholderForType from .placeholders import PlaceholderForType
from .types import Type3, TypeApplication_Struct from .types import Type3, TypeApplication_Struct, TypeConstructor_Function
ConstraintGenerator = Generator[ConstraintBase, None, None] ConstraintGenerator = Generator[ConstraintBase, None, None]
@ -54,15 +57,31 @@ def expression_binary_op(ctx: Context, inp: ourlang.BinaryOp, phft: PlaceholderF
) )
def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator: def expression_function_call(ctx: Context, inp: ourlang.FunctionCall, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp.function, ourlang.FunctionParam):
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
signature = FunctionSignature(
TypeVariableContext(),
inp.function.type3.application.arguments,
)
else:
signature = inp.function.signature
return _expression_function_call( return _expression_function_call(
ctx, ctx,
inp.function.name, inp.function.name,
inp.function.signature, signature,
inp.arguments, inp.arguments,
inp, inp,
phft, phft,
) )
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator:
yield SameTypeConstraint(
prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3),
phft,
comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})',
)
def _expression_function_call( def _expression_function_call(
ctx: Context, ctx: Context,
func_name: str, func_name: str,
@ -111,6 +130,33 @@ def _expression_function_call(
raise NotImplementedError(constraint) raise NotImplementedError(constraint)
func_var_map = {
x: PlaceholderForType([])
for x in signature.args
if isinstance(x, FunctionArgument)
}
# If some of the function arguments are functions,
# we need to deal with those separately.
for sig_arg in signature.args:
if not isinstance(sig_arg, FunctionArgument):
continue
# Ensure that for all type variables in the function
# there are also type variables available
for func_arg in sig_arg.args:
if isinstance(func_arg, Type3):
continue
type_var_map.setdefault(func_arg, PlaceholderForType([]))
yield SameFunctionArgumentConstraint(
func_var_map[sig_arg],
sig_arg,
type_var_map,
comment=f'Ensure `{sig_arg.name}` matches in {signature}',
)
# If some of the function arguments are type constructors, # If some of the function arguments are type constructors,
# we need to deal with those separately. # we need to deal with those separately.
# That is, given `foo :: t a -> a` we need to ensure # That is, given `foo :: t a -> a` we need to ensure
@ -120,6 +166,9 @@ def _expression_function_call(
# Not a type variable at all # Not a type variable at all
continue continue
if isinstance(sig_arg, FunctionArgument):
continue
if sig_arg.application.constructor is None: if sig_arg.application.constructor is None:
# Not a type variable for a type constructor # Not a type variable for a type constructor
continue continue
@ -171,6 +220,10 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
yield from expression_function_call(ctx, inp, phft) yield from expression_function_call(ctx, inp, phft)
return return
if isinstance(inp, ourlang.FunctionReference):
yield from expression_function_reference(ctx, inp, phft)
return
if isinstance(inp, ourlang.TupleInstantiation): if isinstance(inp, ourlang.TupleInstantiation):
r_type = [] r_type = []
for arg in inp.elements: for arg in inp.elements:

View File

@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List, Union from __future__ import annotations
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List
if TYPE_CHECKING: if TYPE_CHECKING:
from .typeclasses import Type3Class from .typeclasses import Type3Class
@ -155,15 +157,29 @@ class TypeVariableContext:
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TypeVariableContext({self.constraints!r})' return f'TypeVariableContext({self.constraints!r})'
class FunctionArgument:
__slots__ = ('args', 'name', )
args: list[Type3 | TypeVariable]
name: str
def __init__(self, args: list[Type3 | TypeVariable]) -> None:
self.args = args
self.name = '(' + ' -> '.join(x.name for x in args) + ')'
class FunctionSignature: class FunctionSignature:
__slots__ = ('context', 'args', ) __slots__ = ('context', 'args', )
context: TypeVariableContext context: TypeVariableContext
args: List[Union['Type3', TypeVariable]] args: List[Type3 | TypeVariable | FunctionArgument]
def __init__(self, context: TypeVariableContext, args: Iterable[Union['Type3', TypeVariable]]) -> None: def __init__(self, context: TypeVariableContext, args: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]) -> None:
self.context = context.__copy__() self.context = context.__copy__()
self.args = list(args) self.args = list(
FunctionArgument(x) if isinstance(x, list) else x
for x in args
)
def __str__(self) -> str: def __str__(self) -> str:
return str(self.context) + ' -> '.join(x.name for x in self.args) return str(self.context) + ' -> '.join(x.name for x in self.args)

View File

@ -239,6 +239,10 @@ class TypeConstructor_Tuple(TypeConstructor_TypeStar):
def make_name(self, key: Tuple[Type3, ...]) -> str: def make_name(self, key: Tuple[Type3, ...]) -> str:
return '(' + ', '.join(x.name for x in key) + ', )' return '(' + ', '.join(x.name for x in key) + ', )'
class TypeConstructor_Function(TypeConstructor_TypeStar):
def make_name(self, key: Tuple[Type3, ...]) -> str:
return 'Callable[' + ', '.join(x.name for x in key) + ']'
class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]): class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]):
""" """
Constructs struct types Constructs struct types

View File

@ -186,6 +186,7 @@ class Module(WatSerializable):
""" """
def __init__(self) -> None: def __init__(self) -> None:
self.imports: List[Import] = [] self.imports: List[Import] = []
self.table: dict[int, str] = {}
self.functions: List[Function] = [] self.functions: List[Function] = []
self.memory = ModuleMemory() self.memory = ModuleMemory()
@ -193,8 +194,10 @@ class Module(WatSerializable):
""" """
Generates the text version Generates the text version
""" """
return '(module\n {}\n {}\n {})\n'.format( return '(module\n {}\n {}\n {}\n {}\n {})\n'.format(
'\n '.join(x.to_wat() for x in self.imports), '\n '.join(x.to_wat() for x in self.imports),
f'(table {len(self.table)} funcref)',
'\n '.join(f'(elem (i32.const {k}) ${v})' for k, v in self.table.items()),
self.memory.to_wat(), self.memory.to_wat(),
'\n '.join(x.to_wat() for x in self.functions), '\n '.join(x.to_wat() for x in self.functions),
) )

View File

@ -0,0 +1,182 @@
import pytest
from phasm.type3.entry import Type3Exception
from ..helpers import Suite
@pytest.mark.integration_test
def test_sof_in_code_0_arg():
code_py = """
def thirteen() -> i32:
return 13
def action(applicable: Callable[i32]) -> i32:
return applicable()
@exported
def testEntry() -> i32:
return action(thirteen)
"""
result = Suite(code_py).run_code()
assert 13 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_code_1_arg():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32, i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13)
"""
result = Suite(code_py).run_code()
assert 26 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_code_2_arg():
code_py = """
def add(left: i32, right: i32) -> i32:
return left + right
def action(applicable: Callable[i32, i32, i32], left: i32, right: i32) -> i32:
return applicable(left, right)
@exported
def testEntry() -> i32:
return action(add, 13, 14)
"""
result = Suite(code_py).run_code()
assert 27 == result.returned_value
@pytest.mark.integration_test
def test_sof_in_code_3_arg():
code_py = """
def add(left: i32, mid: i32, right: i32) -> i32:
return left + mid + right
def action(applicable: Callable[i32, i32, i32, i32], left: i32, mid: i32, right: i32) -> i32:
return applicable(left, mid, right)
@exported
def testEntry() -> i32:
return action(add, 13, 14, 15)
"""
result = Suite(code_py).run_code()
assert 42 == result.returned_value
@pytest.mark.integration_test
def test_sof_wrong_argument_type():
code_py = """
def double(left: f32) -> f32:
return left * 2
def action(applicable: Callable[i32, i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13)
"""
with pytest.raises(Type3Exception, match=r'Callable\[f32, f32\] must be Callable\[i32, i32\] instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_wrong_return():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32, i32], left: i32) -> f32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(double, 13)
"""
with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.skip('FIXME: Probably have the remainder be the a function type')
def test_sof_wrong_not_enough_args_call():
code_py = """
def add(left: i32, right: i32) -> i32:
return left + right
def action(applicable: Callable[i32, i32, i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(add, 13)
"""
with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_wrong_not_enough_args_refere():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32, i32, i32], left: i32, right: i32) -> i32:
return applicable(left, right)
@exported
def testEntry() -> i32:
return action(double, 13, 14)
"""
with pytest.raises(Type3Exception, match=r'Callable\[i32, i32\] must be Callable\[i32, i32, i32\] instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.skip('FIXME: Probably have the remainder be the a function type')
def test_sof_wrong_too_many_args_call():
code_py = """
def thirteen() -> i32:
return 13
def action(applicable: Callable[i32], left: i32) -> i32:
return applicable(left)
@exported
def testEntry() -> i32:
return action(thirteen, 13)
"""
with pytest.raises(Type3Exception, match=r'f32 must be i32 instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_sof_wrong_too_many_args_refere():
code_py = """
def double(left: i32) -> i32:
return left * 2
def action(applicable: Callable[i32]) -> i32:
return applicable()
@exported
def testEntry() -> i32:
return action(double)
"""
with pytest.raises(Type3Exception, match=r'Callable\[i32, i32\] must be Callable\[i32\] instead'):
Suite(code_py).run_code()

View File

@ -45,7 +45,7 @@ def testEntry(x: i32[5]) -> f64:
return sum(x) return sum(x)
""" """
with pytest.raises(Type3Exception, match='i32 must be f64 instead'): with pytest.raises(Type3Exception, match='f64 must be i32 instead'):
Suite(code_py).run_code((4, 5, 6, 7, 8, )) Suite(code_py).run_code((4, 5, 6, 7, 8, ))
@pytest.mark.integration_test @pytest.mark.integration_test