Removes the special casing for foldl

Had to implement both functions as arguments and type
place holders (variables) for type constructors.

Had to implement functions as a type as well.

Still have to figure out how to pass functions around.
This commit is contained in:
Johan B.W. de Vries 2025-04-27 12:54:34 +02:00
parent ac4b46bbe7
commit bfb3d2b3a0
17 changed files with 675 additions and 306 deletions

View File

@ -30,3 +30,5 @@
- Functions don't seem to be a thing on typing level yet? - Functions don't seem to be a thing on typing level yet?
- Related to the FIXME in phasm_type3? - Related to the FIXME in phasm_type3?
- Type constuctor should also be able to constuct placeholders - somehow. - Type constuctor should also be able to constuct placeholders - somehow.
- Read https://bytecodealliance.org/articles/multi-value-all-the-wasm

View File

@ -85,6 +85,9 @@ def expression(inp: ourlang.Expression) -> str:
return f'{inp.function.name}({args})' return f'{inp.function.name}({args})'
if isinstance(inp, ourlang.FunctionReference):
return str(inp.function.name)
if isinstance(inp, ourlang.TupleInstantiation): if isinstance(inp, ourlang.TupleInstantiation):
args = ', '.join( args = ', '.join(
expression(arg) expression(arg)
@ -102,10 +105,6 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member}' return f'{expression(inp.varref)}.{inp.member}'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements: def statement(inp: ourlang.Statement) -> Statements:

View File

@ -4,11 +4,11 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
import struct import struct
from typing import List, Optional from typing import List, Optional
from . import codestyle, ourlang, prelude, wasm from . import ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types from .stdlib import types as stdlib_types
from .type3.functions import TypeVariable from .type3.functions import FunctionArgument, TypeVariable
from .type3.routers import NoRouteForTypeException, TypeApplicationRouter from .type3.routers import NoRouteForTypeException, TypeApplicationRouter
from .type3.typeclasses import Type3ClassMethod from .type3.typeclasses import Type3ClassMethod
from .type3.types import ( from .type3.types import (
@ -100,7 +100,7 @@ def type3(inp: Type3) -> wasm.WasmType:
raise NotImplementedError(type3, inp) raise NotImplementedError(type3, inp)
def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> None: def tuple_instantiation(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.TupleInstantiation) -> None:
""" """
Compile: Instantiation (allocation) of a tuple Compile: Instantiation (allocation) of a tuple
""" """
@ -150,7 +150,7 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) ->
wgn.add_statement('nop', comment='PRE') wgn.add_statement('nop', comment='PRE')
wgn.local.get(tmp_var) wgn.local.get(tmp_var)
expression(wgn, element) expression(wgn, mod, element)
wgn.add_statement(f'{mtyp}.store', 'offset=' + str(offset)) wgn.add_statement(f'{mtyp}.store', 'offset=' + str(offset))
wgn.add_statement('nop', comment='POST') wgn.add_statement('nop', comment='POST')
@ -160,29 +160,29 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) ->
wgn.local.get(tmp_var) wgn.local.get(tmp_var)
def expression_subscript_bytes( def expression_subscript_bytes(
attrs: tuple[WasmGenerator, ourlang.Subscript], attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript],
) -> None: ) -> None:
wgn, inp = attrs wgn, mod, inp = attrs
expression(wgn, inp.varref) expression(wgn, mod, inp.varref)
expression(wgn, inp.index) expression(wgn, mod, inp.index)
wgn.call(stdlib_types.__subscript_bytes__) wgn.call(stdlib_types.__subscript_bytes__)
def expression_subscript_static_array( def expression_subscript_static_array(
attrs: tuple[WasmGenerator, ourlang.Subscript], attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript],
args: tuple[Type3, IntType3], args: tuple[Type3, IntType3],
) -> None: ) -> None:
wgn, inp = attrs wgn, mod, inp = attrs
el_type, el_len = args el_type, el_len = args
# OPTIMIZE: If index is a constant, we can use offset instead of multiply # OPTIMIZE: If index is a constant, we can use offset instead of multiply
# and we don't need to do the out of bounds check # and we don't need to do the out of bounds check
expression(wgn, inp.varref) expression(wgn, mod, inp.varref)
tmp_var = wgn.temp_var_i32('index') tmp_var = wgn.temp_var_i32('index')
expression(wgn, inp.index) expression(wgn, mod, inp.index)
wgn.local.tee(tmp_var) wgn.local.tee(tmp_var)
# Out of bounds check based on el_len.value # Out of bounds check based on el_len.value
@ -201,10 +201,10 @@ def expression_subscript_static_array(
wgn.add_statement(f'{mtyp}.load') wgn.add_statement(f'{mtyp}.load')
def expression_subscript_tuple( def expression_subscript_tuple(
attrs: tuple[WasmGenerator, ourlang.Subscript], attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript],
args: tuple[Type3, ...], args: tuple[Type3, ...],
) -> None: ) -> None:
wgn, inp = attrs wgn, mod, inp = attrs
assert isinstance(inp.index, ourlang.ConstantPrimitive) assert isinstance(inp.index, ourlang.ConstantPrimitive)
assert isinstance(inp.index.value, int) assert isinstance(inp.index.value, int)
@ -217,7 +217,7 @@ def expression_subscript_tuple(
el_type = args[inp.index.value] el_type = args[inp.index.value]
assert el_type is not None, TYPE3_ASSERTION_ERROR assert el_type is not None, TYPE3_ASSERTION_ERROR
expression(wgn, inp.varref) expression(wgn, mod, inp.varref)
if (prelude.InternalPassAsPointer, (el_type, )) in prelude.PRELUDE_TYPE_CLASS_INSTANCES_EXISTING: if (prelude.InternalPassAsPointer, (el_type, )) in prelude.PRELUDE_TYPE_CLASS_INSTANCES_EXISTING:
mtyp = 'i32' mtyp = 'i32'
@ -226,12 +226,12 @@ def expression_subscript_tuple(
wgn.add_statement(f'{mtyp}.load', f'offset={offset}') wgn.add_statement(f'{mtyp}.load', f'offset={offset}')
SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Subscript], None]() SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], None]()
SUBSCRIPT_ROUTER.add_n(prelude.bytes_, expression_subscript_bytes) SUBSCRIPT_ROUTER.add_n(prelude.bytes_, expression_subscript_bytes)
SUBSCRIPT_ROUTER.add(prelude.static_array, expression_subscript_static_array) SUBSCRIPT_ROUTER.add(prelude.static_array, expression_subscript_static_array)
SUBSCRIPT_ROUTER.add(prelude.tuple_, expression_subscript_tuple) SUBSCRIPT_ROUTER.add(prelude.tuple_, expression_subscript_tuple)
def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) -> None:
""" """
Compile: Any expression Compile: Any expression
""" """
@ -291,14 +291,14 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
wgn.i32.const(address) wgn.i32.const(address)
return return
expression(wgn, inp.variable.constant) expression(wgn, mod, inp.variable.constant)
return return
raise NotImplementedError(expression, inp.variable) raise NotImplementedError(expression, inp.variable)
if isinstance(inp, ourlang.BinaryOp): if isinstance(inp, ourlang.BinaryOp):
expression(wgn, inp.left) expression(wgn, mod, inp.left)
expression(wgn, inp.right) expression(wgn, mod, inp.right)
type_var_map: dict[TypeVariable, Type3] = {} type_var_map: dict[TypeVariable, Type3] = {}
@ -313,6 +313,10 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
type_var_map[type_var] = arg_expr.type3 type_var_map[type_var] = arg_expr.type3
continue continue
if isinstance(type_var, FunctionArgument):
# Fixed type, not part of the lookup requirements
continue
raise NotImplementedError(type_var, arg_expr.type3) raise NotImplementedError(type_var, arg_expr.type3)
router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.operator] router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.operator]
@ -321,7 +325,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
if isinstance(inp, ourlang.FunctionCall): if isinstance(inp, ourlang.FunctionCall):
for arg in inp.arguments: for arg in inp.arguments:
expression(wgn, arg) expression(wgn, mod, arg)
if isinstance(inp.function, Type3ClassMethod): if isinstance(inp.function, Type3ClassMethod):
# FIXME: Duplicate code with BinaryOp # FIXME: Duplicate code with BinaryOp
@ -338,6 +342,10 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
type_var_map[type_var] = arg_expr.type3 type_var_map[type_var] = arg_expr.type3
continue continue
if isinstance(type_var, FunctionArgument):
# Fixed type, not part of the lookup requirements
continue
raise NotImplementedError(type_var, arg_expr.type3) raise NotImplementedError(type_var, arg_expr.type3)
router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.function] router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.function]
@ -350,15 +358,24 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
wgn.add_statement('call', '${}'.format(inp.function.name)) wgn.add_statement('call', '${}'.format(inp.function.name))
return return
if isinstance(inp, ourlang.FunctionReference):
idx = mod.functions_table.get(inp.function)
if idx is None:
idx = len(mod.functions_table)
mod.functions_table[inp.function] = idx
wgn.add_statement('i32.const', str(idx), comment=inp.function.name)
return
if isinstance(inp, ourlang.TupleInstantiation): if isinstance(inp, ourlang.TupleInstantiation):
tuple_instantiation(wgn, inp) tuple_instantiation(wgn, mod, inp)
return return
if isinstance(inp, ourlang.Subscript): if isinstance(inp, ourlang.Subscript):
assert inp.varref.type3 is not None, TYPE3_ASSERTION_ERROR assert inp.varref.type3 is not None, TYPE3_ASSERTION_ERROR
# Type checker guarantees we don't get routing errors # Type checker guarantees we don't get routing errors
SUBSCRIPT_ROUTER((wgn, inp, ), inp.varref.type3) SUBSCRIPT_ROUTER((wgn, mod, inp, ), inp.varref.type3)
return return
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
@ -370,111 +387,29 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
mtyp = LOAD_STORE_TYPE_MAP[member_type.name] mtyp = LOAD_STORE_TYPE_MAP[member_type.name]
expression(wgn, inp.varref) expression(wgn, mod, inp.varref)
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(calculate_member_offset( wgn.add_statement(f'{mtyp}.load', 'offset=' + str(calculate_member_offset(
inp.struct_type3.name, inp.struct_type3.application.arguments, inp.member inp.struct_type3.name, inp.struct_type3.application.arguments, inp.member
))) )))
return return
if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp)
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: def statement_return(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementReturn) -> None:
"""
Compile: Fold expression
"""
assert inp.type3 is not None, TYPE3_ASSERTION_ERROR
if inp.iter.type3 is not prelude.bytes_:
raise NotImplementedError(expression_fold, inp, inp.iter.type3)
wgn.add_statement('nop', comment='acu :: u8')
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
wgn.add_statement('nop', comment='adr :: bytes*')
adr_var = wgn.temp_var_i32('fold_i32_adr')
wgn.add_statement('nop', comment='len :: i32')
len_var = wgn.temp_var_i32('fold_i32_len')
wgn.add_statement('nop', comment='acu = base')
expression(wgn, inp.base)
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='adr = adr(iter)')
expression(wgn, inp.iter)
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='len = len(iter)')
wgn.local.get(adr_var)
wgn.i32.load()
wgn.local.set(len_var)
wgn.add_statement('nop', comment='i = 0')
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
wgn.i32.const(0)
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i < len')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
with wgn.if_():
# From here on, adr_var is the address of byte we're referencing
# This is akin to calling stdlib_types.__subscript_bytes__
# But since we already know we are inside of bounds,
# can just bypass it and load the memory directly.
wgn.local.get(adr_var)
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
wgn.i32.add()
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='while True')
with wgn.loop():
wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
wgn.local.get(acu_var)
# Get the next byte, write back the address
wgn.local.get(adr_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.tee(adr_var)
wgn.i32.load8_u()
wgn.add_statement('call', f'${inp.func.name}')
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='i = i + 1')
wgn.local.get(idx_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i >= len: break')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
wgn.br_if(0)
# return acu
wgn.local.get(acu_var)
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
""" """
Compile: Return statement Compile: Return statement
""" """
expression(wgn, inp.value) expression(wgn, mod, inp.value)
wgn.return_() wgn.return_()
def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None: def statement_if(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementIf) -> None:
""" """
Compile: If statement Compile: If statement
""" """
expression(wgn, inp.test) expression(wgn, mod, inp.test)
with wgn.if_(): with wgn.if_():
for stat in inp.statements: for stat in inp.statements:
statement(wgn, stat) statement(wgn, mod, stat)
if inp.else_statements: if inp.else_statements:
raise NotImplementedError raise NotImplementedError
@ -482,16 +417,16 @@ def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None:
# for stat in inp.else_statements: # for stat in inp.else_statements:
# statement(wgn, stat) # statement(wgn, stat)
def statement(wgn: WasmGenerator, inp: ourlang.Statement) -> None: def statement(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Statement) -> None:
""" """
Compile: any statement Compile: any statement
""" """
if isinstance(inp, ourlang.StatementReturn): if isinstance(inp, ourlang.StatementReturn):
statement_return(wgn, inp) statement_return(wgn, mod, inp)
return return
if isinstance(inp, ourlang.StatementIf): if isinstance(inp, ourlang.StatementIf):
statement_if(wgn, inp) statement_if(wgn, mod, inp)
return return
if isinstance(inp, ourlang.StatementPass): if isinstance(inp, ourlang.StatementPass):
@ -522,7 +457,7 @@ def import_(inp: ourlang.Function) -> wasm.Import:
type3(inp.returns_type3) type3(inp.returns_type3)
) )
def function(inp: ourlang.Function) -> wasm.Function: def function(mod: ourlang.Module, inp: ourlang.Function) -> wasm.Function:
""" """
Compile: function Compile: function
""" """
@ -534,7 +469,7 @@ def function(inp: ourlang.Function) -> wasm.Function:
_generate_struct_constructor(wgn, inp) _generate_struct_constructor(wgn, inp)
else: else:
for stat in inp.statements: for stat in inp.statements:
statement(wgn, stat) statement(wgn, mod, stat)
return wasm.Function( return wasm.Function(
inp.name, inp.name,
@ -724,26 +659,32 @@ def module(inp: ourlang.Module) -> wasm.Module:
stdlib_alloc.__find_free_block__, stdlib_alloc.__find_free_block__,
stdlib_alloc.__alloc__, stdlib_alloc.__alloc__,
stdlib_types.__alloc_bytes__, stdlib_types.__alloc_bytes__,
stdlib_types.__subscript_bytes__, # stdlib_types.__subscript_bytes__,
stdlib_types.__u32_ord_min__, # stdlib_types.__u32_ord_min__,
stdlib_types.__u64_ord_min__, # stdlib_types.__u64_ord_min__,
stdlib_types.__i32_ord_min__, # stdlib_types.__i32_ord_min__,
stdlib_types.__i64_ord_min__, # stdlib_types.__i64_ord_min__,
stdlib_types.__u32_ord_max__, # stdlib_types.__u32_ord_max__,
stdlib_types.__u64_ord_max__, # stdlib_types.__u64_ord_max__,
stdlib_types.__i32_ord_max__, # stdlib_types.__i32_ord_max__,
stdlib_types.__i64_ord_max__, # stdlib_types.__i64_ord_max__,
stdlib_types.__i32_intnum_abs__, # stdlib_types.__i32_intnum_abs__,
stdlib_types.__i64_intnum_abs__, # stdlib_types.__i64_intnum_abs__,
stdlib_types.__u32_pow2__, # stdlib_types.__u32_pow2__,
stdlib_types.__u8_rotl__, # stdlib_types.__u8_rotl__,
stdlib_types.__u8_rotr__, # stdlib_types.__u8_rotr__,
] + [ ] + [
function(x) function(inp, x)
for x in inp.functions.values() for x in inp.functions.values()
if not x.imported if not x.imported
] ]
# Do this after rendering the functions since that's what populates the tables
result.table = {
v: k.name
for k, v in inp.functions_table.items()
}
return result return result
def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None:

View File

@ -1,7 +1,6 @@
""" """
Contains the syntax tree for ourlang Contains the syntax tree for ourlang
""" """
import enum
from typing import Dict, Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from . import prelude from . import prelude
@ -161,6 +160,18 @@ class FunctionCall(Expression):
self.function = function self.function = function
self.arguments = [] self.arguments = []
class FunctionReference(Expression):
"""
An function reference expression within a statement
"""
__slots__ = ('function', )
function: 'Function'
def __init__(self, function: 'Function') -> None:
super().__init__()
self.function = function
class TupleInstantiation(Expression): class TupleInstantiation(Expression):
""" """
Instantiation a tuple Instantiation a tuple
@ -207,36 +218,6 @@ class AccessStructMember(Expression):
self.struct_type3 = struct_type3 self.struct_type3 = struct_type3
self.member = member self.member = member
class Fold(Expression):
"""
A (left or right) fold
"""
class Direction(enum.Enum):
"""
Which direction to fold in
"""
LEFT = 0
RIGHT = 1
dir: Direction
func: 'Function'
base: Expression
iter: Expression
def __init__(
self,
dir_: Direction,
func: 'Function',
base: Expression,
iter_: Expression,
) -> None:
super().__init__()
self.dir = dir_
self.func = func
self.base = base
self.iter = iter_
class Statement: class Statement:
""" """
A statement within a function A statement within a function
@ -397,13 +378,14 @@ class Module:
""" """
A module is a file and consists of functions A module is a file and consists of functions
""" """
__slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'operators', ) __slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'functions_table', 'operators', )
data: ModuleData data: ModuleData
types: dict[str, Type3] types: dict[str, Type3]
struct_definitions: Dict[str, StructDefinition] struct_definitions: Dict[str, StructDefinition]
constant_defs: Dict[str, ModuleConstantDef] constant_defs: Dict[str, ModuleConstantDef]
functions: Dict[str, Function] functions: Dict[str, Function]
functions_table: dict[Function, int]
operators: Dict[str, Type3ClassMethod] operators: Dict[str, Type3ClassMethod]
def __init__(self) -> None: def __init__(self) -> None:
@ -413,3 +395,4 @@ class Module:
self.constant_defs = {} self.constant_defs = {}
self.functions = {} self.functions = {}
self.operators = {} self.operators = {}
self.functions_table = {}

View File

@ -14,10 +14,10 @@ from .ourlang import (
ConstantStruct, ConstantStruct,
ConstantTuple, ConstantTuple,
Expression, Expression,
Fold,
Function, Function,
FunctionCall, FunctionCall,
FunctionParam, FunctionParam,
FunctionReference,
Module, Module,
ModuleConstantDef, ModuleConstantDef,
ModuleDataBlock, ModuleDataBlock,
@ -446,6 +446,9 @@ class OurVisitor:
cdef = module.constant_defs[node.id] cdef = module.constant_defs[node.id]
return VariableReference(cdef) return VariableReference(cdef)
if node.id in module.functions:
return FunctionReference(module.functions[node.id])
_raise_static_error(node, f'Undefined variable {node.id}') _raise_static_error(node, f'Undefined variable {node.id}')
if isinstance(node, ast.Tuple): if isinstance(node, ast.Tuple):
@ -462,7 +465,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as expr in FunctionDef') raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]: def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
if node.keywords: if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet? _raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -475,28 +478,6 @@ class OurVisitor:
if node.func.id in PRELUDE_METHODS: if node.func.id in PRELUDE_METHODS:
func = PRELUDE_METHODS[node.func.id] func = PRELUDE_METHODS[node.func.id]
elif node.func.id == 'foldl':
if 3 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# TODO: This is not generic, you cannot return a function
subnode = node.args[0]
if not isinstance(subnode, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {subnode}')
if not isinstance(subnode.ctx, ast.Load):
_raise_static_error(subnode, 'Must be load context')
if subnode.id not in module.functions:
_raise_static_error(subnode, 'Reference to undefined function')
func = module.functions[subnode.id]
if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
return Fold(
Fold.Direction.LEFT,
func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
)
else: else:
if node.func.id not in module.functions: if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function') _raise_static_error(node, 'Call to undefined function')

View File

@ -20,6 +20,7 @@ from ..type3.types import (
Type3, Type3,
TypeApplication_Nullary, TypeApplication_Nullary,
TypeConstructor_Base, TypeConstructor_Base,
TypeConstructor_Function,
TypeConstructor_StaticArray, TypeConstructor_StaticArray,
TypeConstructor_Struct, TypeConstructor_Struct,
TypeConstructor_Tuple, TypeConstructor_Tuple,
@ -186,6 +187,16 @@ It should be applied with zero or more arguments. It has a compile time
determined length, and each argument can be different. determined length, and each argument can be different.
""" """
def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None:
pass # ? instance_type_class(InternalPassAsPointer, typ)
function = TypeConstructor_Function('function', on_create=fn_on_create)
"""
This is a function.
It should be applied with one or more arguments. The last argument is the 'return' type.
"""
def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None: def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None:
instance_type_class(InternalPassAsPointer, typ) instance_type_class(InternalPassAsPointer, typ)
@ -574,12 +585,16 @@ instance_type_class(Promotable, f32, f64, methods={
Foldable = Type3Class('Foldable', (t, ), methods={ Foldable = Type3Class('Foldable', (t, ), methods={
'sum': [t(a), a], 'sum': [t(a), a],
'foldl': [[b, a, b], b, t(a), b],
'foldr': [[a, b, b], b, t(a), b],
}, operators={}, additional_context={ }, operators={}, additional_context={
'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))], 'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))],
}) })
instance_type_class(Foldable, static_array, methods={ instance_type_class(Foldable, static_array, methods={
'sum': stdtypes.static_array_sum, 'sum': stdtypes.static_array_sum,
'foldl': stdtypes.static_array_foldl,
'foldr': stdtypes.static_array_foldr,
}) })
PRELUDE_TYPE_CLASSES = { PRELUDE_TYPE_CLASSES = {

View File

@ -4,7 +4,7 @@ stdlib: Standard types that are not wasm primitives
from phasm.stdlib import alloc from phasm.stdlib import alloc
from phasm.type3.routers import TypeVariableLookup from phasm.type3.routers import TypeVariableLookup
from phasm.type3.types import IntType3, Type3 from phasm.type3.types import IntType3, Type3
from phasm.wasmgenerator import Generator, func_wrapper from phasm.wasmgenerator import Generator, VarType_Base, func_wrapper
from phasm.wasmgenerator import VarType_i32 as i32 from phasm.wasmgenerator import VarType_i32 as i32
from phasm.wasmgenerator import VarType_i64 as i64 from phasm.wasmgenerator import VarType_i64 as i64
@ -1081,9 +1081,17 @@ def f32_f64_demote(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map del tv_map
g.f32.demote_f64() g.f32.demote_f64()
def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: def static_array_sum(g: Generator, tvl: TypeVariableLookup) -> None:
assert len(tv_map) == 1 tv_map, tc_map = tvl
sa_type, sa_len = next(iter(tv_map.values()))
tvn_map = {
x.name: y
for x, y in tv_map.items()
}
sa_type = tvn_map['a']
sa_len = tvn_map['a*']
assert isinstance(sa_type, Type3) assert isinstance(sa_type, Type3)
assert isinstance(sa_len, IntType3) assert isinstance(sa_len, IntType3)
@ -1166,7 +1174,7 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
g.nop(comment='Add array value') g.nop(comment='Add array value')
g.local.get(sum_adr) g.local.get(sum_adr)
g.add_statement(f'{sa_type_mtyp}.load') g.add_statement(f'{sa_type_mtyp}.load')
sa_type_add_gen(g, {}) sa_type_add_gen(g, ({}, {}, ))
# adr = adr + sa_type_alloc_size # adr = adr + sa_type_alloc_size
# Stack: [sum] -> [sum] # Stack: [sum] -> [sum]
@ -1185,3 +1193,253 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]') g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]')
# End result: [sum] # End result: [sum]
def static_array_foldl(g: Generator, tvl: TypeVariableLookup) -> None:
tv_map, tc_map = tvl
tvn_map = {
x.name: y
for x, y in tv_map.items()
}
sa_type = tvn_map['a']
sa_len = tvn_map['a*']
res_type = tvn_map['b']
assert isinstance(sa_type, Type3)
assert isinstance(sa_len, IntType3)
assert isinstance(res_type, Type3)
if sa_len.value < 1:
raise NotImplementedError('Default value in case foldl is empty')
# FIXME: We should probably use LOAD_STORE_TYPE_MAP for this?
mtyp_map = {
'u32': 'i32',
'u64': 'i64',
'i32': 'i32',
'i64': 'i64',
'f32': 'f32',
'f64': 'f64',
}
mtyp_f_map: dict[str, type[VarType_Base]] = {
'i32': i32,
'i64': i64,
}
# FIXME: We should probably use calc_alloc_size for this?
type_var_size_map = {
'u32': 4,
'u64': 8,
'i32': 4,
'i64': 8,
'f32': 4,
'f64': 8,
}
# By default, constructed types are passed as pointers
# FIXME: We don't know what add function to call
sa_type_mtyp = mtyp_map.get(sa_type.name, 'i32')
sa_type_alloc_size = type_var_size_map.get(sa_type.name, 4)
res_type_mtyp = mtyp_map.get(res_type.name, 'i32')
res_type_mtyp_f = mtyp_f_map[res_type_mtyp]
# Definitions
fold_adr = g.temp_var(i32('fold_adr'))
fold_stop = g.temp_var(i32('fold_stop'))
fold_init = g.temp_var(res_type_mtyp_f('fold_init'))
fold_func = g.temp_var(i32('fold_func'))
with g.block(params=['i32', res_type_mtyp, 'i32'], result=res_type_mtyp, comment=f'foldl a={sa_type.name} a*={sa_len.value} b={res_type.name}'):
# Stack: [[a, b, b], b, t(a), b]
# t(a) == sa_type[sa_len]
# Stack: [fn*, b, sa*]
# adr = {address of what's currently on stack}
# Stack: [fn*, b, sa*] -> [fn*, b]
g.local.set(fold_adr)
# Stack: [fn*, b] -> [fn*]
g.local.set(fold_init)
# Stack: [fn*] -> []
g.local.set(fold_func)
# stop = adr + ar_len * sa_type_alloc_size
# Stack: []
g.nop(comment='Calculate address at which to stop looping')
g.local.get(fold_adr)
g.i32.const(sa_len.value * sa_type_alloc_size)
g.i32.add()
g.local.set(fold_stop)
# Stack: [] -> [b]
g.nop(comment='Get the init value and first array value as starting point')
g.local.get(fold_init)
# Stack: [b] -> [b, *a]
g.local.get(fold_adr)
# Stack: [b] -> [b, a]
g.add_statement(f'{sa_type_mtyp}.load')
g.nop(comment='Call the fold function')
g.local.get(fold_func)
g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})')
# adr = adr + sa_type_alloc_size
# Stack: [b] -> [b]
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_alloc_size)
g.i32.add()
g.local.set(fold_adr)
if sa_len.value > 1:
with g.loop(params=[sa_type_mtyp], result=sa_type_mtyp):
# Stack: [b] -> [b, a]
g.nop(comment='Add array value')
g.local.get(fold_adr)
g.add_statement(f'{sa_type_mtyp}.load')
# Stack [b, a] -> b
g.nop(comment='Call the fold function')
g.local.get(fold_func)
g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})')
# adr = adr + sa_type_alloc_size
# Stack: [fold] -> [fold]
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_alloc_size)
g.i32.add()
g.local.tee(fold_adr)
# loop if adr < stop
g.nop(comment='Check if address exceeds array bounds')
g.local.get(fold_stop)
g.i32.lt_u()
g.br_if(0)
# else: just one value, don't need to loop
# Stack: [b]
def static_array_foldr(g: Generator, tvl: TypeVariableLookup) -> None:
tv_map, tc_map = tvl
tvn_map = {
x.name: y
for x, y in tv_map.items()
}
sa_type = tvn_map['a']
sa_len = tvn_map['a*']
res_type = tvn_map['b']
assert isinstance(sa_type, Type3)
assert isinstance(sa_len, IntType3)
assert isinstance(res_type, Type3)
if sa_len.value < 1:
raise NotImplementedError('Default value in case foldl is empty')
# FIXME: We should probably use LOAD_STORE_TYPE_MAP for this?
mtyp_map = {
'u32': 'i32',
'u64': 'i64',
'i32': 'i32',
'i64': 'i64',
'f32': 'f32',
'f64': 'f64',
}
mtyp_f_map: dict[str, type[VarType_Base]] = {
'i32': i32,
'i64': i64,
}
# FIXME: We should probably use calc_alloc_size for this?
type_var_size_map = {
'u32': 4,
'u64': 8,
'i32': 4,
'i64': 8,
'f32': 4,
'f64': 8,
}
# By default, constructed types are passed as pointers
# FIXME: We don't know what add function to call
sa_type_mtyp = mtyp_map.get(sa_type.name, 'i32')
sa_type_alloc_size = type_var_size_map.get(sa_type.name, 4)
res_type_mtyp = mtyp_map.get(res_type.name, 'i32')
res_type_mtyp_f = mtyp_f_map[res_type_mtyp]
# Definitions
fold_adr = g.temp_var(i32('fold_adr'))
fold_stop = g.temp_var(i32('fold_stop'))
fold_init = g.temp_var(res_type_mtyp_f('fold_init'))
fold_func = g.temp_var(i32('fold_func'))
with g.block(params=['i32', res_type_mtyp, 'i32'], result=res_type_mtyp, comment=f'foldl a={sa_type.name} a*={sa_len.value} b={res_type.name}'):
# Stack: [[a, b, b], b, t(a), b]
# t(a) == sa_type[sa_len]
# Stack: [fn*, b, sa*]
# adr = {address of what's currently on stack}
# Stack: [fn*, b, sa*] -> [fn*, b]
g.local.set(fold_adr)
# Stack: [fn*, b] -> [fn*]
g.local.set(fold_init)
# Stack: [fn*] -> []
g.local.set(fold_func)
# stop = adr + ar_len * sa_type_alloc_size
# Stack: []
g.nop(comment='Calculate address at which to stop looping')
g.local.get(fold_adr)
g.i32.const(sa_len.value * sa_type_alloc_size)
g.i32.add()
g.local.set(fold_stop)
# Stack: [] -> [b]
g.nop(comment='Get the init value and first array value as starting point')
g.local.get(fold_init)
# Stack: [b] -> [b, *a]
g.local.get(fold_adr)
# Stack: [b] -> [b, a]
g.add_statement(f'{sa_type_mtyp}.load')
g.nop(comment='Call the fold function')
g.local.get(fold_func)
g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})')
# adr = adr + sa_type_alloc_size
# Stack: [b] -> [b]
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_alloc_size)
g.i32.add()
g.local.set(fold_adr)
if sa_len.value > 1:
with g.loop(params=[sa_type_mtyp], result=sa_type_mtyp):
# Stack: [b] -> [b, a]
g.nop(comment='Add array value')
g.local.get(fold_adr)
g.add_statement(f'{sa_type_mtyp}.load')
# Stack [b, a] -> b
g.nop(comment='Call the fold function')
g.local.get(fold_func)
g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})')
# adr = adr + sa_type_alloc_size
# Stack: [fold] -> [fold]
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_alloc_size)
g.i32.add()
g.local.tee(fold_adr)
# loop if adr < stop
g.nop(comment='Check if address exceeds array bounds')
g.local.get(fold_stop)
g.i32.lt_u()
g.br_if(0)
# else: just one value, don't need to loop
# Stack: [b]

View File

@ -6,6 +6,7 @@ These need to be resolved before the program can be compiled.
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from .. import ourlang, prelude from .. import ourlang, prelude
from .functions import FunctionArgument, TypeVariable
from .placeholders import PlaceholderForType, Type3OrPlaceholder from .placeholders import PlaceholderForType, Type3OrPlaceholder
from .routers import NoRouteForTypeException, TypeApplicationRouter from .routers import NoRouteForTypeException, TypeApplicationRouter
from .typeclasses import Type3Class from .typeclasses import Type3Class
@ -158,7 +159,7 @@ class SameTypeConstraint(ConstraintBase):
return ( return (
' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))), ' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))),
{ {
't' + str(idx): typ 't' + str(idx): typ.name if isinstance(typ, ourlang.Function) else typ
for idx, typ in enumerate(self.type_list) for idx, typ in enumerate(self.type_list)
}, },
) )
@ -181,7 +182,7 @@ class SameTypeArgumentConstraint(ConstraintBase):
self.arg_var = arg_var self.arg_var = arg_var
def check(self) -> CheckResult: def check(self) -> CheckResult:
if self.tc_var.resolve_as is None or self.arg_var.resolve_as is None: if self.tc_var.resolve_as is None:
return RequireTypeSubstitutes() return RequireTypeSubstitutes()
tc_typ = self.tc_var.resolve_as tc_typ = self.tc_var.resolve_as
@ -201,13 +202,84 @@ class SameTypeArgumentConstraint(ConstraintBase):
# FIXME: This feels sketchy. Shouldn't the type variable # FIXME: This feels sketchy. Shouldn't the type variable
# have the exact same number as arguments? # have the exact same number as arguments?
if isinstance(tc_typ.application, TypeApplication_TypeInt): if isinstance(tc_typ.application, TypeApplication_TypeInt):
if tc_typ.application.arguments[0] == arg_typ: return [SameTypeConstraint(
return None tc_typ.application.arguments[0],
self.arg_var,
return Error(f'{tc_typ.application.arguments[0]:s} must be {arg_typ:s} instead') comment=self.comment,
)]
raise NotImplementedError(tc_typ, arg_typ) raise NotImplementedError(tc_typ, arg_typ)
def human_readable(self) -> HumanReadableRet:
return (
'{tc_var}` == {arg_var}',
{
'tc_var': self.tc_var if self.tc_var.resolve_as is None else self.tc_var,
'arg_var': self.arg_var if self.arg_var.resolve_as is None else self.arg_var,
},
)
class SameFunctionArgumentConstraint(ConstraintBase):
__slots__ = ('type3', 'func_arg', 'type_var_map', )
type3: PlaceholderForType
func_arg: FunctionArgument
type_var_map: dict[TypeVariable, PlaceholderForType]
def __init__(self, type3: PlaceholderForType, func_arg: FunctionArgument, type_var_map: dict[TypeVariable, PlaceholderForType], *, comment: str) -> None:
super().__init__(comment=comment)
self.type3 = type3
self.func_arg = func_arg
self.type_var_map = type_var_map
def check(self) -> CheckResult:
if self.type3.resolve_as is None:
return RequireTypeSubstitutes()
typ = self.type3.resolve_as
if isinstance(typ.application, TypeApplication_Nullary):
return Error(f'{typ:s} must be a function instead')
if not isinstance(typ.application, TypeApplication_TypeStar):
return Error(f'{typ:s} must be a function instead')
type_var_map = {
x: y.resolve_as
for x, y in self.type_var_map.items()
if y.resolve_as is not None
}
exp_type_arg_list = [
tv if isinstance(tv, Type3) else type_var_map[tv]
for tv in self.func_arg.args
if isinstance(tv, Type3) or tv in type_var_map
]
print('self.func_arg.args', self.func_arg.args)
print('exp_type_arg_list', exp_type_arg_list)
if len(exp_type_arg_list) != len(self.func_arg.args):
return RequireTypeSubstitutes()
return [
SameTypeConstraint(
typ,
prelude.function(*exp_type_arg_list),
comment=self.comment,
)
]
def human_readable(self) -> HumanReadableRet:
return (
'{type3} == {func_arg}',
{
'type3': self.type3,
'func_arg': self.func_arg.name,
},
)
class TupleMatchConstraint(ConstraintBase): class TupleMatchConstraint(ConstraintBase):
__slots__ = ('exp_type', 'args', ) __slots__ = ('exp_type', 'args', )
@ -264,14 +336,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
__slots__ = ('context', 'type_class3', 'types', ) __slots__ = ('context', 'type_class3', 'types', )
context: Context context: Context
type_class3: Union[str, Type3Class] type_class3: Type3Class
types: list[Type3OrPlaceholder] types: list[Type3OrPlaceholder]
DATA = { def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
'bytes': {'Foldable'},
}
def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
super().__init__(comment=comment) super().__init__(comment=comment)
self.context = context self.context = context
@ -299,13 +367,9 @@ class MustImplementTypeClassConstraint(ConstraintBase):
assert len(typ_list) == len(self.types) assert len(typ_list) == len(self.types)
if isinstance(self.type_class3, Type3Class):
key = (self.type_class3, tuple(typ_list), ) key = (self.type_class3, tuple(typ_list), )
if key in self.context.type_class_instances_existing: if key in self.context.type_class_instances_existing:
return None return None
else:
if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()):
return None
typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name
typ_name_list = ' '.join(x.name for x in typ_list) typ_name_list = ' '.join(x.name for x in typ_list)

View File

@ -12,12 +12,14 @@ from .constraints import (
Context, Context,
LiteralFitsConstraint, LiteralFitsConstraint,
MustImplementTypeClassConstraint, MustImplementTypeClassConstraint,
SameFunctionArgumentConstraint,
SameTypeArgumentConstraint, SameTypeArgumentConstraint,
SameTypeConstraint, SameTypeConstraint,
TupleMatchConstraint, TupleMatchConstraint,
) )
from .functions import ( from .functions import (
Constraint_TypeClassInstanceExists, Constraint_TypeClassInstanceExists,
FunctionArgument,
FunctionSignature, FunctionSignature,
TypeVariable, TypeVariable,
TypeVariableApplication_Unary, TypeVariableApplication_Unary,
@ -111,6 +113,33 @@ def _expression_function_call(
raise NotImplementedError(constraint) raise NotImplementedError(constraint)
func_var_map = {
x: PlaceholderForType([])
for x in signature.args
if isinstance(x, FunctionArgument)
}
# If some of the function arguments are functions,
# we need to deal with those separately.
for sig_arg in signature.args:
if not isinstance(sig_arg, FunctionArgument):
continue
# Ensure that for all type variables in the function
# there are also type variables available
for func_arg in sig_arg.args:
if isinstance(func_arg, Type3):
continue
type_var_map.setdefault(func_arg, PlaceholderForType([]))
yield SameFunctionArgumentConstraint(
func_var_map[sig_arg],
sig_arg,
type_var_map,
comment=f'Ensure `{sig_arg.name}` matches in {signature}',
)
# If some of the function arguments are type constructors, # If some of the function arguments are type constructors,
# we need to deal with those separately. # we need to deal with those separately.
# That is, given `foo :: t a -> a` we need to ensure # That is, given `foo :: t a -> a` we need to ensure
@ -120,6 +149,9 @@ def _expression_function_call(
# Not a type variable at all # Not a type variable at all
continue continue
if isinstance(sig_arg, FunctionArgument):
continue
if sig_arg.application.constructor is None: if sig_arg.application.constructor is None:
# Not a type variable for a type constructor # Not a type variable for a type constructor
continue continue
@ -150,9 +182,20 @@ def _expression_function_call(
yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment) yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment)
continue continue
if isinstance(sig_part, FunctionArgument):
yield SameTypeConstraint(func_var_map[sig_part], arg_placeholders[arg_expr], comment=comment)
continue
raise NotImplementedError(sig_part) raise NotImplementedError(sig_part)
return return
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator:
yield SameTypeConstraint(
prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3),
phft,
comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})',
)
def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator: def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator:
if isinstance(inp, ourlang.Constant): if isinstance(inp, ourlang.Constant):
yield from constant(ctx, inp, phft) yield from constant(ctx, inp, phft)
@ -171,6 +214,10 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
yield from expression_function_call(ctx, inp, phft) yield from expression_function_call(ctx, inp, phft)
return return
if isinstance(inp, ourlang.FunctionReference):
yield from expression_function_reference(ctx, inp, phft)
return
if isinstance(inp, ourlang.TupleInstantiation): if isinstance(inp, ourlang.TupleInstantiation):
r_type = [] r_type = []
for arg in inp.elements: for arg in inp.elements:
@ -209,19 +256,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}') comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
return return
if isinstance(inp, ourlang.Fold):
base_phft = PlaceholderForType([inp.base])
iter_phft = PlaceholderForType([inp.iter])
yield from expression(ctx, inp.base, base_phft)
yield from expression(ctx, inp.iter, iter_phft)
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft,
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft])
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator:

View File

@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List, Union from __future__ import annotations
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List
if TYPE_CHECKING: if TYPE_CHECKING:
from .typeclasses import Type3Class from .typeclasses import Type3Class
@ -155,15 +157,29 @@ class TypeVariableContext:
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TypeVariableContext({self.constraints!r})' return f'TypeVariableContext({self.constraints!r})'
class FunctionArgument:
__slots__ = ('args', 'name', )
args: list[Type3 | TypeVariable]
name: str
def __init__(self, args: list[Type3 | TypeVariable]) -> None:
self.args = args
self.name = '(' + ' -> '.join(x.name for x in args) + ')'
class FunctionSignature: class FunctionSignature:
__slots__ = ('context', 'args', ) __slots__ = ('context', 'args', )
context: TypeVariableContext context: TypeVariableContext
args: List[Union['Type3', TypeVariable]] args: List[Type3 | TypeVariable | FunctionArgument]
def __init__(self, context: TypeVariableContext, args: Iterable[Union['Type3', TypeVariable]]) -> None: def __init__(self, context: TypeVariableContext, args: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]) -> None:
self.context = context.__copy__() self.context = context.__copy__()
self.args = list(args) self.args = list(
FunctionArgument(x) if isinstance(x, list) else x
for x in args
)
def __str__(self) -> str: def __str__(self) -> str:
return str(self.context) + ' -> '.join(x.name for x in self.args) return str(self.context) + ' -> '.join(x.name for x in self.args)

View File

@ -3,6 +3,7 @@ from typing import Any, Callable
from .functions import ( from .functions import (
TypeConstructorVariable, TypeConstructorVariable,
TypeVariable, TypeVariable,
TypeVariableApplication_Nullary,
TypeVariableApplication_Unary, TypeVariableApplication_Unary,
) )
from .typeclasses import Type3ClassArgs from .typeclasses import Type3ClassArgs
@ -54,7 +55,10 @@ class TypeApplicationRouter[S, R]:
raise NoRouteForTypeException(arg0, typ) raise NoRouteForTypeException(arg0, typ)
TypeVariableLookup = dict[TypeVariable, tuple[KindArgument, ...]] TypeVariableLookup = tuple[
dict[TypeVariable, KindArgument],
dict[TypeConstructorVariable, TypeConstructor_Base[Any]],
]
class TypeClassArgsRouter[S, R]: class TypeClassArgsRouter[S, R]:
""" """
@ -89,11 +93,12 @@ class TypeClassArgsRouter[S, R]:
def __call__(self, arg0: S, tv_map: dict[TypeVariable, Type3]) -> R: def __call__(self, arg0: S, tv_map: dict[TypeVariable, Type3]) -> R:
key: list[Type3 | TypeConstructor_Base[Any]] = [] key: list[Type3 | TypeConstructor_Base[Any]] = []
arguments: TypeVariableLookup = {} arguments: TypeVariableLookup = (dict(tv_map), {}, )
for tc_arg in self.args: for tc_arg in self.args:
if isinstance(tc_arg, TypeVariable): if isinstance(tc_arg, TypeVariable):
key.append(tv_map[tc_arg]) key.append(tv_map[tc_arg])
arguments[0][tc_arg] = tv_map[tc_arg]
continue continue
for tvar, typ in tv_map.items(): for tvar, typ in tv_map.items():
@ -102,16 +107,24 @@ class TypeClassArgsRouter[S, R]:
continue continue
key.append(typ.application.constructor) key.append(typ.application.constructor)
arguments[1][tc_arg] = typ.application.constructor
if isinstance(tvar.application, TypeVariableApplication_Unary): if isinstance(tvar.application, TypeVariableApplication_Unary):
# FIXME: This feels sketchy. Shouldn't the type variable # FIXME: This feels sketchy. Shouldn't the type variable
# have the exact same number as arguments? # have the exact same number as arguments?
if isinstance(typ.application, TypeApplication_TypeInt): if isinstance(typ.application, TypeApplication_TypeInt):
arguments[tvar.application.arguments] = typ.application.arguments sa_type, sa_len = typ.application.arguments
sa_type_tv = tvar.application.arguments
sa_len_tv = TypeVariable(sa_type_tv.name + '*', TypeVariableApplication_Nullary(None, None))
arguments[0][sa_type_tv] = sa_type
arguments[0][sa_len_tv] = sa_len
continue continue
raise NotImplementedError(tvar.application, typ.application) raise NotImplementedError(tvar.application, typ.application)
continue
t_helper = self.data.get(tuple(key)) t_helper = self.data.get(tuple(key))
if t_helper is not None: if t_helper is not None:
return t_helper(arg0, arguments) return t_helper(arg0, arguments)

View File

@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, Mapping, Optional, Union from typing import Dict, Iterable, List, Mapping, Optional
from .functions import ( from .functions import (
Constraint_TypeClassInstanceExists, Constraint_TypeClassInstanceExists,
@ -42,8 +42,8 @@ class Type3Class:
self, self,
name: str, name: str,
args: Type3ClassArgs, args: Type3ClassArgs,
methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]], methods: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]], operators: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
inherited_classes: Optional[List['Type3Class']] = None, inherited_classes: Optional[List['Type3Class']] = None,
additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None, additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None,
) -> None: ) -> None:
@ -71,19 +71,23 @@ class Type3Class:
return self.name return self.name
def _create_signature( def _create_signature(
method_arg_list: Iterable[Type3 | TypeVariable], method_arg_list: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]],
type_class3: Type3Class, type_class3: Type3Class,
) -> FunctionSignature: ) -> FunctionSignature:
context = TypeVariableContext() context = TypeVariableContext()
if not isinstance(type_class3.args[0], TypeConstructorVariable): if not isinstance(type_class3.args[0], TypeConstructorVariable):
context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args)) context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args))
signature_args: list[Type3 | TypeVariable] = [] signature_args: list[Type3 | TypeVariable | list[Type3 | TypeVariable]] = []
for method_arg in method_arg_list: for method_arg in method_arg_list:
if isinstance(method_arg, Type3): if isinstance(method_arg, Type3):
signature_args.append(method_arg) signature_args.append(method_arg)
continue continue
if isinstance(method_arg, list):
signature_args.append(method_arg)
continue
if isinstance(method_arg, TypeVariable): if isinstance(method_arg, TypeVariable):
type_constructor = method_arg.application.constructor type_constructor = method_arg.application.constructor
if type_constructor is None: if type_constructor is None:

View File

@ -239,6 +239,10 @@ class TypeConstructor_Tuple(TypeConstructor_TypeStar):
def make_name(self, key: Tuple[Type3, ...]) -> str: def make_name(self, key: Tuple[Type3, ...]) -> str:
return '(' + ', '.join(x.name for x in key) + ', )' return '(' + ', '.join(x.name for x in key) + ', )'
class TypeConstructor_Function(TypeConstructor_TypeStar):
def make_name(self, key: Tuple[Type3, ...]) -> str:
return '(' + ' -> '.join(x.name for x in key) + ')'
class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]): class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]):
""" """
Constructs struct types Constructs struct types

View File

@ -187,14 +187,17 @@ class Module(WatSerializable):
def __init__(self) -> None: def __init__(self) -> None:
self.imports: List[Import] = [] self.imports: List[Import] = []
self.functions: List[Function] = [] self.functions: List[Function] = []
self.table: dict[int, str] = {}
self.memory = ModuleMemory() self.memory = ModuleMemory()
def to_wat(self) -> str: def to_wat(self) -> str:
""" """
Generates the text version Generates the text version
""" """
return '(module\n {}\n {}\n {})\n'.format( return '(module\n {}\n {}\n {}\n {}\n {})\n'.format(
'\n '.join(x.to_wat() for x in self.imports), '\n '.join(x.to_wat() for x in self.imports),
f'(table {len(self.table)} funcref)',
'\n '.join(f'(elem (i32.const {k}) ${v})' for k, v in self.table.items()),
self.memory.to_wat(), self.memory.to_wat(),
'\n '.join(x.to_wat() for x in self.functions), '\n '.join(x.to_wat() for x in self.functions),
) )

View File

@ -170,11 +170,12 @@ class Generator_Local:
self.generator.add_statement('local.tee', variable.name_ref, comment=comment) self.generator.add_statement('local.tee', variable.name_ref, comment=comment)
class GeneratorBlock: class GeneratorBlock:
def __init__(self, generator: 'Generator', name: str, params: Iterable[str] = (), result: str | None = None) -> None: def __init__(self, generator: 'Generator', name: str, params: Iterable[str] = (), result: str | None = None, comment: str | None = None) -> None:
self.generator = generator self.generator = generator
self.name = name self.name = name
self.params = params self.params = params
self.result = result self.result = result
self.comment = comment
def __enter__(self) -> None: def __enter__(self) -> None:
stmt = self.name stmt = self.name
@ -186,7 +187,7 @@ class GeneratorBlock:
if self.result: if self.result:
stmt = f'{stmt} (result {self.result})' stmt = f'{stmt} (result {self.result})'
self.generator.add_statement(stmt) self.generator.add_statement(stmt, comment=self.comment)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if not exc_type: if not exc_type:
@ -208,7 +209,7 @@ class Generator:
# 2.4.5 Control Instructions # 2.4.5 Control Instructions
self.nop = functools.partial(self.add_statement, 'nop') self.nop = functools.partial(self.add_statement, 'nop')
self.unreachable = functools.partial(self.add_statement, 'unreachable') self.unreachable = functools.partial(self.add_statement, 'unreachable')
# block self.block = functools.partial(GeneratorBlock, self, 'block')
self.loop = functools.partial(GeneratorBlock, self, 'loop') self.loop = functools.partial(GeneratorBlock, self, 'loop')
self.if_ = functools.partial(GeneratorBlock, self, 'if') self.if_ = functools.partial(GeneratorBlock, self, 'if')
# br # br

View File

@ -1,61 +0,0 @@
import pytest
from ..helpers import Suite
@pytest.mark.integration_test
def test_foldl_1():
code_py = """
def u8_or(l: u8, r: u8) -> u8:
return l | r
@exported
def testEntry(b: bytes) -> u8:
return foldl(u8_or, 128, b)
"""
suite = Suite(code_py)
result = suite.run_code(b'')
assert 128 == result.returned_value
result = suite.run_code(b'\x80')
assert 128 == result.returned_value
result = suite.run_code(b'\x80\x40')
assert 192 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10')
assert 240 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
assert 255 == result.returned_value
@pytest.mark.integration_test
def test_foldl_2():
code_py = """
def xor(l: u8, r: u8) -> u8:
return l ^ r
@exported
def testEntry(a: bytes, b: bytes) -> u8:
return foldl(xor, 0, a) ^ foldl(xor, 0, b)
"""
suite = Suite(code_py)
result = suite.run_code(b'\x55\x0F', b'\x33\x80')
assert 233 == result.returned_value
@pytest.mark.integration_test
def test_foldl_3():
code_py = """
def xor(l: u32, r: u8) -> u32:
return l ^ extend(r)
@exported
def testEntry(a: bytes) -> u32:
return foldl(xor, 0, a)
"""
suite = Suite(code_py)
result = suite.run_code(b'\x55\x0F\x33\x80')
assert 233 == result.returned_value

View File

@ -36,6 +36,118 @@ def testEntry(x: Foo[4]) -> Foo:
with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'): with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.parametrize('length', [1, 5, 13])
@pytest.mark.parametrize('direction', ['foldl', 'foldr'])
def test_foldable_foldl_foldr_size(direction, length):
code_py = f"""
def u64_add(l: u64, r: u64) -> u64:
return l + r
@exported
def testEntry(b: u64[{length}]) -> u64:
return {direction}(u64_add, 100, b)
"""
suite = Suite(code_py)
in_put = tuple(range(1, length + 1))
result = suite.run_code(in_put)
assert (100 + sum(in_put)) == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('direction', ['foldr'])
def test_foldable_foldl_foldr_compounded_type(direction):
code_py = f"""
def combine_foldl(b: u64, a: (u32, u32, )) -> u64:
return extend(a[0] * a[1]) + b
def combine_foldr(a: (u32, u32, ), b: u64) -> u64:
return extend(a[0] * a[1]) + b
@exported
def testEntry(b: (u32, u32)[3]) -> u64:
return {direction}(combine_{direction}, 10000, b)
"""
suite = Suite(code_py)
result = suite.run_code(((2, 5), (25, 4), (125, 8)))
assert 11110 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('direction, exp_result', [
('foldl', -55, ),
('foldr', -5, ),
])
def test_foldable_foldl_foldr_result(direction, exp_result):
# See https://stackoverflow.com/a/13280185
code_py = f"""
def i32_sub(l: i32, r: i32) -> i32:
return l - r
@exported
def testEntry(b: i32[10]) -> i32:
return {direction}(i32_sub, 0, b)
"""
suite = Suite(code_py)
result = suite.run_code(tuple(range(1, 11)))
assert exp_result == result.returned_value
@pytest.mark.integration_test
def test_foldable_foldl_bytes():
code_py = """
def u8_or(l: u8, r: u8) -> u8:
return l | r
@exported
def testEntry(b: bytes) -> u8:
return foldl(u8_or, 128, b)
"""
suite = Suite(code_py)
result = suite.run_code(b'')
assert 128 == result.returned_value
result = suite.run_code(b'\x80')
assert 128 == result.returned_value
result = suite.run_code(b'\x80\x40')
assert 192 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10')
assert 240 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
assert 255 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('in_typ', ['i8', 'i8[3]'])
def test_foldable_argument_must_be_a_function(in_typ):
code_py = f"""
@exported
def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32:
return foldl(x, y, z)
"""
r_in_typ = in_typ.replace('[', '\\[').replace(']', '\\]')
with pytest.raises(Type3Exception, match=f'{r_in_typ} must be a function instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_foldable_argument_must_be_right_function():
code_py = """
def foo(l: i32, r: i64) -> i64:
return extend(l) + r
@exported
def testEntry(i: i64, l: i64[3]) -> i64:
return foldr(foo, i, l)
"""
with pytest.raises(Type3Exception, match=r'\(i64 -> i64 -> i64\) must be \(i32 -> i64 -> i64\) instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
def test_foldable_invalid_return_type(): def test_foldable_invalid_return_type():
@ -45,7 +157,7 @@ def testEntry(x: i32[5]) -> f64:
return sum(x) return sum(x)
""" """
with pytest.raises(Type3Exception, match='i32 must be f64 instead'): with pytest.raises(Type3Exception, match='f64 must be i32 instead'):
Suite(code_py).run_code((4, 5, 6, 7, 8, )) Suite(code_py).run_code((4, 5, 6, 7, 8, ))
@pytest.mark.integration_test @pytest.mark.integration_test