Removes the special casing for foldl

Now both dynamic and static arrays can be fully fold'ed.

Also, integration tests now don't dump their stuff without
VERBOSE=1, this speeds up the tests suite by a factor of 9.
This commit is contained in:
Johan B.W. de Vries 2025-04-27 12:54:34 +02:00
parent 46b06dbcf1
commit 2091d1d34a
16 changed files with 612 additions and 277 deletions

View File

@ -31,3 +31,5 @@
- Functions don't seem to be a thing on typing level yet? - Functions don't seem to be a thing on typing level yet?
- Related to the FIXME in phasm_type3? - Related to the FIXME in phasm_type3?
- Type constuctor should also be able to constuct placeholders - somehow. - Type constuctor should also be able to constuct placeholders - somehow.
- Read https://bytecodealliance.org/articles/multi-value-all-the-wasm

View File

@ -105,10 +105,6 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member}' return f'{expression(inp.varref)}.{inp.member}'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements: def statement(inp: ourlang.Statement) -> Statements:

View File

@ -4,12 +4,12 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
import struct import struct
from typing import List from typing import List
from . import codestyle, ourlang, prelude, wasm from . import ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types from .stdlib import types as stdlib_types
from .stdlib.types import TYPE_INFO_CONSTRUCTED, TYPE_INFO_MAP from .stdlib.types import TYPE_INFO_CONSTRUCTED, TYPE_INFO_MAP
from .type3.functions import TypeVariable from .type3.functions import FunctionArgument, TypeVariable
from .type3.routers import NoRouteForTypeException, TypeApplicationRouter from .type3.routers import NoRouteForTypeException, TypeApplicationRouter
from .type3.typeclasses import Type3ClassMethod from .type3.typeclasses import Type3ClassMethod
from .type3.types import ( from .type3.types import (
@ -273,6 +273,10 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
type_var_map[type_var] = arg_expr.type3 type_var_map[type_var] = arg_expr.type3
continue continue
if isinstance(type_var, FunctionArgument):
# Fixed type, not part of the lookup requirements
continue
raise NotImplementedError(type_var, arg_expr.type3) raise NotImplementedError(type_var, arg_expr.type3)
router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.operator] router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.operator]
@ -298,6 +302,10 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
type_var_map[type_var] = arg_expr.type3 type_var_map[type_var] = arg_expr.type3
continue continue
if isinstance(type_var, FunctionArgument):
# Fixed type, not part of the lookup requirements
continue
raise NotImplementedError(type_var, arg_expr.type3) raise NotImplementedError(type_var, arg_expr.type3)
router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.function] router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.function]
@ -359,90 +367,8 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
))) )))
return return
if isinstance(inp, ourlang.Fold):
expression_fold(wgn, mod, inp)
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Fold) -> None:
"""
Compile: Fold expression
"""
assert inp.type3 is not None, TYPE3_ASSERTION_ERROR
if inp.iter.type3 is not prelude.bytes_:
raise NotImplementedError(expression_fold, inp, inp.iter.type3)
wgn.add_statement('nop', comment='acu :: u8')
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
wgn.add_statement('nop', comment='adr :: bytes*')
adr_var = wgn.temp_var_i32('fold_i32_adr')
wgn.add_statement('nop', comment='len :: i32')
len_var = wgn.temp_var_i32('fold_i32_len')
wgn.add_statement('nop', comment='acu = base')
expression(wgn, mod, inp.base)
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='adr = adr(iter)')
expression(wgn, mod, inp.iter)
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='len = len(iter)')
wgn.local.get(adr_var)
wgn.i32.load()
wgn.local.set(len_var)
wgn.add_statement('nop', comment='i = 0')
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
wgn.i32.const(0)
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i < len')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
with wgn.if_():
# From here on, adr_var is the address of byte we're referencing
# This is akin to calling stdlib_types.__subscript_bytes__
# But since we already know we are inside of bounds,
# can just bypass it and load the memory directly.
wgn.local.get(adr_var)
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
wgn.i32.add()
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='while True')
with wgn.loop():
wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
wgn.local.get(acu_var)
# Get the next byte, write back the address
wgn.local.get(adr_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.tee(adr_var)
wgn.i32.load8_u()
wgn.add_statement('call', f'${inp.func.name}')
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='i = i + 1')
wgn.local.get(idx_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i >= len: break')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
wgn.br_if(0)
# return acu
wgn.local.get(acu_var)
def statement_return(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementReturn) -> None: def statement_return(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementReturn) -> None:
""" """
Compile: Return statement Compile: Return statement

View File

@ -1,7 +1,6 @@
""" """
Contains the syntax tree for ourlang Contains the syntax tree for ourlang
""" """
import enum
from typing import Dict, Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from . import prelude from . import prelude
@ -219,36 +218,6 @@ class AccessStructMember(Expression):
self.struct_type3 = struct_type3 self.struct_type3 = struct_type3
self.member = member self.member = member
class Fold(Expression):
"""
A (left or right) fold
"""
class Direction(enum.Enum):
"""
Which direction to fold in
"""
LEFT = 0
RIGHT = 1
dir: Direction
func: 'Function'
base: Expression
iter: Expression
def __init__(
self,
dir_: Direction,
func: 'Function',
base: Expression,
iter_: Expression,
) -> None:
super().__init__()
self.dir = dir_
self.func = func
self.base = base
self.iter = iter_
class Statement: class Statement:
""" """
A statement within a function A statement within a function

View File

@ -14,7 +14,6 @@ from .ourlang import (
ConstantStruct, ConstantStruct,
ConstantTuple, ConstantTuple,
Expression, Expression,
Fold,
Function, Function,
FunctionCall, FunctionCall,
FunctionParam, FunctionParam,
@ -467,7 +466,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as expr in FunctionDef') raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]: def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
if node.keywords: if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet? _raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -480,28 +479,6 @@ class OurVisitor:
if node.func.id in PRELUDE_METHODS: if node.func.id in PRELUDE_METHODS:
func = PRELUDE_METHODS[node.func.id] func = PRELUDE_METHODS[node.func.id]
elif node.func.id == 'foldl':
if 3 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# TODO: This is not generic, you cannot return a function
subnode = node.args[0]
if not isinstance(subnode, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {subnode}')
if not isinstance(subnode.ctx, ast.Load):
_raise_static_error(subnode, 'Must be load context')
if subnode.id not in module.functions:
_raise_static_error(subnode, 'Reference to undefined function')
func = module.functions[subnode.id]
if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
return Fold(
Fold.Direction.LEFT,
func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
)
elif node.func.id in our_locals: elif node.func.id in our_locals:
func = our_locals[node.func.id] func = our_locals[node.func.id]
else: else:

View File

@ -553,12 +553,21 @@ instance_type_class(Promotable, f32, f64, methods={
Foldable = Type3Class('Foldable', (t, ), methods={ Foldable = Type3Class('Foldable', (t, ), methods={
'sum': [t(a), a], 'sum': [t(a), a],
'foldl': [[b, a, b], b, t(a), b],
'foldr': [[a, b, b], b, t(a), b],
}, operators={}, additional_context={ }, operators={}, additional_context={
'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))], 'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))],
}) })
instance_type_class(Foldable, dynamic_array, methods={
'sum': stdtypes.dynamic_array_sum,
'foldl': stdtypes.dynamic_array_foldl,
'foldr': stdtypes.dynamic_array_foldr,
})
instance_type_class(Foldable, static_array, methods={ instance_type_class(Foldable, static_array, methods={
'sum': stdtypes.static_array_sum, 'sum': stdtypes.static_array_sum,
'foldl': stdtypes.static_array_foldl,
'foldr': stdtypes.static_array_foldr,
}) })
bytes_ = dynamic_array(u8) bytes_ = dynamic_array(u8)

View File

@ -1059,12 +1059,16 @@ def dynamic_array_sized_len(g: Generator, tv_map: TypeVariableLookup) -> None:
# The length is stored in the first 4 bytes # The length is stored in the first 4 bytes
g.i32.load() g.i32.load()
def static_array_sized_len(g: Generator, tv_map: TypeVariableLookup) -> None: def static_array_sized_len(g: Generator, tvl: TypeVariableLookup) -> None:
assert len(tv_map) == 1 tv_map, tc_map = tvl
sa_type, sa_len = next(iter(tv_map.values()))
assert isinstance(sa_type, Type3)
assert isinstance(sa_len, IntType3)
tvn_map = {
x.name: y
for x, y in tv_map.items()
}
sa_len = tvn_map['a*']
assert isinstance(sa_len, IntType3)
g.i32.const(sa_len.value) g.i32.const(sa_len.value)
## ### ## ###
@ -1137,9 +1141,23 @@ def f32_f64_demote(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map del tv_map
g.f32.demote_f64() g.f32.demote_f64()
def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: ## ###
assert len(tv_map) == 1 ## Foldable
sa_type, sa_len = next(iter(tv_map.values()))
def dynamic_array_sum(g: Generator, tvl: TypeVariableLookup) -> None:
raise NotImplementedError('dynamic_array_sum')
def static_array_sum(g: Generator, tvl: TypeVariableLookup) -> None:
tv_map, tc_map = tvl
tvn_map = {
x.name: y
for x, y in tv_map.items()
}
sa_type = tvn_map['a']
sa_len = tvn_map['a*']
assert isinstance(sa_type, Type3) assert isinstance(sa_type, Type3)
assert isinstance(sa_len, IntType3) assert isinstance(sa_len, IntType3)
@ -1171,7 +1189,7 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
g.nop(comment=f'Start sum for {sa_type.name}[{sa_len.value}]') g.nop(comment=f'Start sum for {sa_type.name}[{sa_len.value}]')
g.local.set(sum_adr) g.local.set(sum_adr)
# stop = adr + ar_len * sa_type_alloc_size # stop = adr + ar_len * sa_type_info.alloc_size
# Stack: [] # Stack: []
g.nop(comment='Calculate address at which to stop looping') g.nop(comment='Calculate address at which to stop looping')
g.local.get(sum_adr) g.local.get(sum_adr)
@ -1186,7 +1204,7 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
g.add_statement(sa_type_info.wasm_load_func) g.add_statement(sa_type_info.wasm_load_func)
# Since we did the first one, increase adr # Since we did the first one, increase adr
# adr = adr + sa_type_alloc_size # adr = adr + sa_type_info.alloc_size
# Stack: [sum] -> [sum] # Stack: [sum] -> [sum]
g.local.get(sum_adr) g.local.get(sum_adr)
g.i32.const(sa_type_info.alloc_size) g.i32.const(sa_type_info.alloc_size)
@ -1194,15 +1212,15 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
g.local.set(sum_adr) g.local.set(sum_adr)
if sa_len.value > 1: if sa_len.value > 1:
with g.loop(params=[sa_type_info.wasm_type().to_wat()], result=sa_type_info.wasm_type().to_wat()): with g.loop(params=[sa_type_info.wasm_type], result=sa_type_info.wasm_type):
# sum = sum + *adr # sum = sum + *adr
# Stack: [sum] -> [sum + *adr] # Stack: [sum] -> [sum + *adr]
g.nop(comment='Add array value') g.nop(comment='Add array value')
g.local.get(sum_adr) g.local.get(sum_adr)
g.add_statement(sa_type_info.wasm_load_func) g.add_statement(sa_type_info.wasm_load_func)
sa_type_add_gen(g, {}) sa_type_add_gen(g, ({}, {}, ))
# adr = adr + sa_type_alloc_size # adr = adr + sa_type_info.alloc_size
# Stack: [sum] -> [sum] # Stack: [sum] -> [sum]
g.nop(comment='Calculate address of the next value') g.nop(comment='Calculate address of the next value')
g.local.get(sum_adr) g.local.get(sum_adr)
@ -1219,3 +1237,294 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]') g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]')
# End result: [sum] # End result: [sum]
def dynamic_array_foldl(g: Generator, tvl: TypeVariableLookup) -> None:
tv_map, tc_map = tvl
tvn_map = {
x.name: y
for x, y in tv_map.items()
}
sa_type = tvn_map['a']
res_type = tvn_map['b']
assert isinstance(sa_type, Type3)
assert isinstance(res_type, Type3)
sa_type_info = TYPE_INFO_MAP.get(sa_type.name, TYPE_INFO_CONSTRUCTED)
res_type_info = TYPE_INFO_MAP.get(res_type.name, TYPE_INFO_CONSTRUCTED)
# Definitions
fold_adr = g.temp_var(i32('fold_adr'))
fold_stop = g.temp_var(i32('fold_stop'))
fold_init = g.temp_var_t(res_type_info.wasm_type, 'fold_init')
fold_func = g.temp_var(i32('fold_func'))
fold_len = g.temp_var(i32('fold_len'))
with g.block(params=['i32', res_type_info.wasm_type, 'i32'], result=res_type_info.wasm_type, comment=f'foldl a={sa_type.name} b={res_type.name}'):
# Stack: [fn*, b, sa*] -> [fn*, b]
g.local.tee(fold_adr) # Store address, but also keep it for loading the length
g.i32.load() # Load the length
g.local.set(fold_len) # Store the length
# Stack: [fn*, b] -> [fn*]
g.local.set(fold_init)
# Stack: [fn*] -> []
g.local.set(fold_func)
# Stack: [] -> [b]
g.nop(comment='No applications if array is empty')
g.local.get(fold_init)
g.local.get(fold_len)
g.i32.eqz() # If the array is empty
g.br_if(0) # Then the base value is the result
# Stack: [b] -> [b] ; fold_adr=fold_adr + 4
g.nop(comment='Skip the header')
g.local.get(fold_adr)
g.i32.const(4)
g.i32.add()
g.local.set(fold_adr)
# Stack: [b] -> [b]
g.nop(comment='Apply the first function call')
g.local.get(fold_adr)
g.add_statement(sa_type_info.wasm_load_func)
g.local.get(fold_func)
g.call_indirect([res_type_info.wasm_type, sa_type_info.wasm_type], res_type_info.wasm_type)
# Stack: [b] -> [b]
g.nop(comment='No loop if there is only one item')
g.local.get(fold_len)
g.i32.const(1)
g.i32.eq()
g.br_if(0) # just one value, don't need to loop
# Stack: [b] -> [b] ; fold_stop=fold_adr + (sa_len.value * sa_type_info.alloc_size)
g.nop(comment='Calculate address at which to stop looping')
g.local.get(fold_adr)
g.local.get(fold_len)
g.i32.const(sa_type_info.alloc_size)
g.i32.mul()
g.i32.add()
g.local.set(fold_stop)
# Stack: [b] -> [b] ; fold_adr = fold_adr + sa_type_info.alloc_size
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_info.alloc_size)
g.i32.add()
g.local.set(fold_adr)
with g.loop(params=[res_type_info.wasm_type], result=res_type_info.wasm_type):
# Stack: [b] -> [b]
g.nop(comment='Apply function call')
g.local.get(fold_adr)
g.add_statement(sa_type_info.wasm_load_func)
g.local.get(fold_func)
g.call_indirect([res_type_info.wasm_type, sa_type_info.wasm_type], res_type_info.wasm_type)
# Stack: [b] -> [b] ; fold_adr = fold_adr + sa_type_info.alloc_size
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_info.alloc_size)
g.i32.add()
g.local.tee(fold_adr)
# loop if adr > stop
# Stack: [b] -> [b]
g.nop(comment='Check if address exceeds array bounds')
g.local.get(fold_stop)
g.i32.lt_u()
g.br_if(0)
# Stack: [b]
def static_array_foldl(g: Generator, tvl: TypeVariableLookup) -> None:
tv_map, tc_map = tvl
tvn_map = {
x.name: y
for x, y in tv_map.items()
}
sa_type = tvn_map['a']
sa_len = tvn_map['a*']
res_type = tvn_map['b']
assert isinstance(sa_type, Type3)
assert isinstance(sa_len, IntType3)
assert isinstance(res_type, Type3)
sa_type_info = TYPE_INFO_MAP.get(sa_type.name, TYPE_INFO_CONSTRUCTED)
res_type_info = TYPE_INFO_MAP.get(res_type.name, TYPE_INFO_CONSTRUCTED)
# Definitions
fold_adr = g.temp_var(i32('fold_adr'))
fold_stop = g.temp_var(i32('fold_stop'))
fold_init = g.temp_var_t(res_type_info.wasm_type, 'fold_init')
fold_func = g.temp_var(i32('fold_func'))
with g.block(params=['i32', res_type_info.wasm_type, 'i32'], result=res_type_info.wasm_type, comment=f'foldl a={sa_type.name} a*={sa_len.value} b={res_type.name}'):
# Stack: [fn*, b, sa*] -> [fn*, b]
g.local.set(fold_adr)
# Stack: [fn*, b] -> [fn*]
g.local.set(fold_init)
# Stack: [fn*] -> []
g.local.set(fold_func)
if sa_len.value < 1:
g.local.get(fold_init)
return
# Stack: [] -> [b]
g.nop(comment='Apply the first function call')
g.local.get(fold_init)
g.local.get(fold_adr)
g.add_statement(sa_type_info.wasm_load_func)
g.local.get(fold_func)
g.call_indirect([res_type_info.wasm_type, sa_type_info.wasm_type], res_type_info.wasm_type)
if sa_len.value > 1:
# Stack: [b] -> [b] ; fold_stop=fold_adr + (sa_len.value * sa_type_info.alloc_size)
g.nop(comment='Calculate address at which to stop looping')
g.local.get(fold_adr)
g.i32.const(sa_len.value * sa_type_info.alloc_size)
g.i32.add()
g.local.set(fold_stop)
# Stack: [b] -> [b] ; fold_adr = fold_adr + sa_type_info.alloc_size
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_info.alloc_size)
g.i32.add()
g.local.set(fold_adr)
with g.loop(params=[res_type_info.wasm_type], result=res_type_info.wasm_type):
# Stack: [b] -> [b]
g.nop(comment='Apply function call')
g.local.get(fold_adr)
g.add_statement(sa_type_info.wasm_load_func)
g.local.get(fold_func)
g.call_indirect([res_type_info.wasm_type, sa_type_info.wasm_type], res_type_info.wasm_type)
# Stack: [b] -> [b] ; fold_adr = fold_adr + sa_type_info.alloc_size
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_info.alloc_size)
g.i32.add()
g.local.tee(fold_adr)
# loop if adr > stop
# Stack: [b] -> [b]
g.nop(comment='Check if address exceeds array bounds')
g.local.get(fold_stop)
g.i32.lt_u()
g.br_if(0)
# else: just one value, don't need to loop
# Stack: [b]
def dynamic_array_foldr(g: Generator, tvl: TypeVariableLookup) -> None:
raise NotImplementedError('dynamic_array_foldr')
def static_array_foldr(g: Generator, tvl: TypeVariableLookup) -> None:
tv_map, tc_map = tvl
tvn_map = {
x.name: y
for x, y in tv_map.items()
}
sa_type = tvn_map['a']
sa_len = tvn_map['a*']
res_type = tvn_map['b']
assert isinstance(sa_type, Type3)
assert isinstance(sa_len, IntType3)
assert isinstance(res_type, Type3)
sa_type_info = TYPE_INFO_MAP.get(sa_type.name, TYPE_INFO_CONSTRUCTED)
res_type_info = TYPE_INFO_MAP.get(res_type.name, TYPE_INFO_CONSTRUCTED)
# Definitions
fold_adr = g.temp_var(i32('fold_adr'))
fold_stop = g.temp_var(i32('fold_stop'))
fold_tmp = g.temp_var_t(res_type_info.wasm_type, 'fold_tmp')
fold_func = g.temp_var(i32('fold_func'))
with g.block(params=['i32', res_type_info.wasm_type, 'i32'], result=res_type_info.wasm_type, comment=f'foldr a={sa_type.name} a*={sa_len.value} b={res_type.name}'):
# Stack: [fn*, b, sa*] -> [fn*, b] ; fold_adr=fn*, fold_tmp=b, fold_func=fn*
g.local.set(fold_adr)
# Stack: [fn*, b] -> [fn*]
g.local.set(fold_tmp)
# Stack: [fn*] -> []
g.local.set(fold_func)
if sa_len.value < 1:
g.local.get(fold_tmp)
return
# Stack: [] -> [] ; fold_stop=fold_adr
g.nop(comment='Calculate address at which to stop looping')
g.local.get(fold_adr)
g.local.set(fold_stop)
# Stack: [] -> [] ; fold_adr=fold_adr + (sa_len.value - 1) * sa_type_info.alloc_size
g.nop(comment='Calculate address at which to stop looping')
g.local.get(fold_adr)
g.i32.const((sa_len.value - 1) * sa_type_info.alloc_size)
g.i32.add()
g.local.set(fold_adr)
# Stack: [] -> [b]
g.nop(comment='Get the init value and first array value as starting point')
g.local.get(fold_adr)
g.add_statement(sa_type_info.wasm_load_func)
g.local.get(fold_tmp)
g.local.get(fold_func)
g.call_indirect([sa_type_info.wasm_type, res_type_info.wasm_type], res_type_info.wasm_type)
if sa_len.value > 1:
# Stack: [b] -> [b] ; fold_adr = fold_adr - sa_type_info.alloc_size
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_info.alloc_size)
g.i32.sub()
g.local.set(fold_adr)
with g.loop(params=[res_type_info.wasm_type], result=res_type_info.wasm_type):
g.nop(comment='Apply function call')
# Stack [b] since we don't have proper stack switching opcodes
# Stack: [b] -> []
g.local.set(fold_tmp)
# Stack: [] -> [a]
g.local.get(fold_adr)
g.add_statement(sa_type_info.wasm_load_func)
# Stack [a] -> [a, b]
g.local.get(fold_tmp)
# Stack [a, b] -> [b]
g.local.get(fold_func)
g.call_indirect([sa_type_info.wasm_type, res_type_info.wasm_type], res_type_info.wasm_type)
# Stack: [b] -> [b] ; fold_adr = fold_adr - sa_type_info.alloc_size
g.nop(comment='Calculate address of the next value')
g.local.get(fold_adr)
g.i32.const(sa_type_info.alloc_size)
g.i32.sub()
g.local.tee(fold_adr)
# loop if adr >= stop
# Stack: [b] -> [b]
g.nop(comment='Check if address exceeds array bounds')
g.local.get(fold_stop)
g.i32.ge_u()
g.br_if(0)
# else: just one value, don't need to loop
# Stack: [b]

View File

@ -183,7 +183,7 @@ class SameTypeArgumentConstraint(ConstraintBase):
self.arg_var = arg_var self.arg_var = arg_var
def check(self) -> CheckResult: def check(self) -> CheckResult:
if self.tc_var.resolve_as is None or self.arg_var.resolve_as is None: if self.tc_var.resolve_as is None:
return RequireTypeSubstitutes() return RequireTypeSubstitutes()
tc_typ = self.tc_var.resolve_as tc_typ = self.tc_var.resolve_as
@ -200,12 +200,16 @@ class SameTypeArgumentConstraint(ConstraintBase):
# So we can let the MustImplementTypeClassConstraint handle it. # So we can let the MustImplementTypeClassConstraint handle it.
return None return None
if isinstance(tc_typ.application, TypeApplication_Type):
return [SameTypeConstraint(
tc_typ.application.arguments[0],
self.arg_var,
comment=self.comment,
)]
# FIXME: This feels sketchy. Shouldn't the type variable # FIXME: This feels sketchy. Shouldn't the type variable
# have the exact same number as arguments? # have the exact same number as arguments?
if isinstance(tc_typ.application, TypeApplication_TypeInt): if isinstance(tc_typ.application, TypeApplication_TypeInt):
if tc_typ.application.arguments[0] == arg_typ:
return None
return [SameTypeConstraint( return [SameTypeConstraint(
tc_typ.application.arguments[0], tc_typ.application.arguments[0],
self.arg_var, self.arg_var,
@ -346,14 +350,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
__slots__ = ('context', 'type_class3', 'types', ) __slots__ = ('context', 'type_class3', 'types', )
context: Context context: Context
type_class3: Union[str, Type3Class] type_class3: Type3Class
types: list[Type3OrPlaceholder] types: list[Type3OrPlaceholder]
DATA = { def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
'dynamic_array': {'Foldable'},
}
def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
super().__init__(comment=comment) super().__init__(comment=comment)
self.context = context self.context = context
@ -381,13 +381,9 @@ class MustImplementTypeClassConstraint(ConstraintBase):
assert len(typ_list) == len(self.types) assert len(typ_list) == len(self.types)
if isinstance(self.type_class3, Type3Class):
key = (self.type_class3, tuple(typ_list), ) key = (self.type_class3, tuple(typ_list), )
if key in self.context.type_class_instances_existing: if key in self.context.type_class_instances_existing:
return None return None
else:
if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()):
return None
typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name
typ_name_list = ' '.join(x.name for x in typ_list) typ_name_list = ' '.join(x.name for x in typ_list)

View File

@ -202,6 +202,10 @@ def _expression_function_call(
yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment) yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment)
continue continue
if isinstance(sig_part, FunctionArgument):
yield SameTypeConstraint(func_var_map[sig_part], arg_placeholders[arg_expr], comment=comment)
continue
raise NotImplementedError(sig_part) raise NotImplementedError(sig_part)
return return
@ -265,19 +269,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}') comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
return return
if isinstance(inp, ourlang.Fold):
base_phft = PlaceholderForType([inp.base])
iter_phft = PlaceholderForType([inp.iter])
yield from expression(ctx, inp.base, base_phft)
yield from expression(ctx, inp.iter, iter_phft)
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft,
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft])
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator:

View File

@ -3,6 +3,7 @@ from typing import Any, Callable
from .functions import ( from .functions import (
TypeConstructorVariable, TypeConstructorVariable,
TypeVariable, TypeVariable,
TypeVariableApplication_Nullary,
TypeVariableApplication_Unary, TypeVariableApplication_Unary,
) )
from .typeclasses import Type3ClassArgs from .typeclasses import Type3ClassArgs
@ -60,7 +61,10 @@ class TypeApplicationRouter[S, R]:
raise NoRouteForTypeException(arg0, typ) raise NoRouteForTypeException(arg0, typ)
TypeVariableLookup = dict[TypeVariable, tuple[KindArgument, ...]] TypeVariableLookup = tuple[
dict[TypeVariable, KindArgument],
dict[TypeConstructorVariable, TypeConstructor_Base[Any]],
]
class TypeClassArgsRouter[S, R]: class TypeClassArgsRouter[S, R]:
""" """
@ -95,11 +99,12 @@ class TypeClassArgsRouter[S, R]:
def __call__(self, arg0: S, tv_map: dict[TypeVariable, Type3]) -> R: def __call__(self, arg0: S, tv_map: dict[TypeVariable, Type3]) -> R:
key: list[Type3 | TypeConstructor_Base[Any]] = [] key: list[Type3 | TypeConstructor_Base[Any]] = []
arguments: TypeVariableLookup = {} arguments: TypeVariableLookup = (dict(tv_map), {}, )
for tc_arg in self.args: for tc_arg in self.args:
if isinstance(tc_arg, TypeVariable): if isinstance(tc_arg, TypeVariable):
key.append(tv_map[tc_arg]) key.append(tv_map[tc_arg])
arguments[0][tc_arg] = tv_map[tc_arg]
continue continue
for tvar, typ in tv_map.items(): for tvar, typ in tv_map.items():
@ -108,16 +113,24 @@ class TypeClassArgsRouter[S, R]:
continue continue
key.append(typ.application.constructor) key.append(typ.application.constructor)
arguments[1][tc_arg] = typ.application.constructor
if isinstance(tvar.application, TypeVariableApplication_Unary): if isinstance(tvar.application, TypeVariableApplication_Unary):
if isinstance(typ.application, TypeApplication_Type): if isinstance(typ.application, TypeApplication_Type):
arguments[tvar.application.arguments] = typ.application.arguments da_type, = typ.application.arguments
sa_type_tv = tvar.application.arguments
arguments[0][sa_type_tv] = da_type
continue continue
# FIXME: This feels sketchy. Shouldn't the type variable # FIXME: This feels sketchy. Shouldn't the type variable
# have the exact same number as arguments? # have the exact same number as arguments?
if isinstance(typ.application, TypeApplication_TypeInt): if isinstance(typ.application, TypeApplication_TypeInt):
arguments[tvar.application.arguments] = typ.application.arguments sa_type, sa_len = typ.application.arguments
sa_type_tv = tvar.application.arguments
sa_len_tv = TypeVariable(sa_type_tv.name + '*', TypeVariableApplication_Nullary(None, None))
arguments[0][sa_type_tv] = sa_type
arguments[0][sa_len_tv] = sa_len
continue continue
raise NotImplementedError(tvar.application, typ.application) raise NotImplementedError(tvar.application, typ.application)

View File

@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, Mapping, Optional, Union from typing import Dict, Iterable, List, Mapping, Optional
from .functions import ( from .functions import (
Constraint_TypeClassInstanceExists, Constraint_TypeClassInstanceExists,
@ -42,8 +42,8 @@ class Type3Class:
self, self,
name: str, name: str,
args: Type3ClassArgs, args: Type3ClassArgs,
methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]], methods: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]], operators: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
inherited_classes: Optional[List['Type3Class']] = None, inherited_classes: Optional[List['Type3Class']] = None,
additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None, additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None,
) -> None: ) -> None:
@ -71,19 +71,23 @@ class Type3Class:
return self.name return self.name
def _create_signature( def _create_signature(
method_arg_list: Iterable[Type3 | TypeVariable], method_arg_list: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]],
type_class3: Type3Class, type_class3: Type3Class,
) -> FunctionSignature: ) -> FunctionSignature:
context = TypeVariableContext() context = TypeVariableContext()
if not isinstance(type_class3.args[0], TypeConstructorVariable): if not isinstance(type_class3.args[0], TypeConstructorVariable):
context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args)) context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args))
signature_args: list[Type3 | TypeVariable] = [] signature_args: list[Type3 | TypeVariable | list[Type3 | TypeVariable]] = []
for method_arg in method_arg_list: for method_arg in method_arg_list:
if isinstance(method_arg, Type3): if isinstance(method_arg, Type3):
signature_args.append(method_arg) signature_args.append(method_arg)
continue continue
if isinstance(method_arg, list):
signature_args.append(method_arg)
continue
if isinstance(method_arg, TypeVariable): if isinstance(method_arg, TypeVariable):
type_constructor = method_arg.application.constructor type_constructor = method_arg.application.constructor
if type_constructor is None: if type_constructor is None:

View File

@ -170,23 +170,32 @@ class Generator_Local:
self.generator.add_statement('local.tee', variable.name_ref, comment=comment) self.generator.add_statement('local.tee', variable.name_ref, comment=comment)
class GeneratorBlock: class GeneratorBlock:
def __init__(self, generator: 'Generator', name: str, params: Iterable[str] = (), result: str | None = None) -> None: def __init__(
self,
generator: 'Generator',
name: str,
params: Iterable[str | Type[wasm.WasmType]] = (),
result: str | Type[wasm.WasmType] | None = None,
comment: str | None = None,
) -> None:
self.generator = generator self.generator = generator
self.name = name self.name = name
self.params = params self.params = params
self.result = result self.result = result
self.comment = comment
def __enter__(self) -> None: def __enter__(self) -> None:
stmt = self.name stmt = self.name
if self.params: if self.params:
stmt = f'{stmt} ' + ' '.join( stmt = f'{stmt} ' + ' '.join(
f'(param {typ})' f'(param {typ})' if isinstance(typ, str) else f'(param {typ().to_wat()})'
for typ in self.params for typ in self.params
) )
if self.result: if self.result:
stmt = f'{stmt} (result {self.result})' result = self.result if isinstance(self.result, str) else self.result().to_wat()
stmt = f'{stmt} (result {result})'
self.generator.add_statement(stmt) self.generator.add_statement(stmt, comment=self.comment)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if not exc_type: if not exc_type:
@ -208,7 +217,7 @@ class Generator:
# 2.4.5 Control Instructions # 2.4.5 Control Instructions
self.nop = functools.partial(self.add_statement, 'nop') self.nop = functools.partial(self.add_statement, 'nop')
self.unreachable = functools.partial(self.add_statement, 'unreachable') self.unreachable = functools.partial(self.add_statement, 'unreachable')
# block self.block = functools.partial(GeneratorBlock, self, 'block')
self.loop = functools.partial(GeneratorBlock, self, 'loop') self.loop = functools.partial(GeneratorBlock, self, 'loop')
self.if_ = functools.partial(GeneratorBlock, self, 'if') self.if_ = functools.partial(GeneratorBlock, self, 'if')
# br # br
@ -224,6 +233,16 @@ class Generator:
def call(self, function: wasm.Function) -> None: def call(self, function: wasm.Function) -> None:
self.add_statement('call', f'${function.name}') self.add_statement('call', f'${function.name}')
def call_indirect(self, params: Iterable[Type[wasm.WasmType]], result: Type[wasm.WasmType]) -> None:
param_str = ' '.join(
x().to_wat()
for x in params
)
result_str = result().to_wat()
self.add_statement(f'call_indirect (param {param_str}) (result {result_str})')
def add_statement(self, name: str, *args: str, comment: Optional[str] = None) -> None: def add_statement(self, name: str, *args: str, comment: Optional[str] = None) -> None:
self.statements.append(wasm.Statement(name, *args, comment=comment)) self.statements.append(wasm.Statement(name, *args, comment=comment))
@ -234,6 +253,28 @@ class Generator:
return var.__class__(varname) return var.__class__(varname)
def temp_var_t(self, typ: Type[wasm.WasmType], name: str) -> VarType_Base:
idx = 0
while (varname := f'__{name}_tmp_var_{idx}__') in self.locals:
idx += 1
if typ is wasm.WasmTypeInt32:
return VarType_u8(varname)
if typ is wasm.WasmTypeInt32:
return VarType_i32(varname)
if typ is wasm.WasmTypeInt64:
return VarType_i64(varname)
if typ is wasm.WasmTypeFloat32:
return VarType_f32(varname)
if typ is wasm.WasmTypeFloat64:
return VarType_f64(varname)
raise NotImplementedError(typ)
def temp_var_i32(self, infix: str) -> VarType_i32: def temp_var_i32(self, infix: str) -> VarType_i32:
return self.temp_var(VarType_i32(infix)) return self.temp_var(VarType_i32(infix))

View File

@ -1,3 +1,4 @@
import os
import struct import struct
import sys import sys
from typing import Any, Generator, Iterable, List, TextIO, Union from typing import Any, Generator, Iterable, List, TextIO, Union
@ -36,24 +37,37 @@ class Suite:
def __init__(self, code_py: str) -> None: def __init__(self, code_py: str) -> None:
self.code_py = code_py self.code_py = code_py
def run_code(self, *args: Any, runtime: str = 'wasmtime', func_name: str = 'testEntry', imports: runners.Imports = None, do_format_check: bool = True) -> Any: def run_code(
self,
*args: Any,
runtime: str = 'wasmtime',
func_name: str = 'testEntry',
imports: runners.Imports = None,
do_format_check: bool = True,
verbose: bool | None = None,
) -> Any:
""" """
Compiles the given python code into wasm and Compiles the given python code into wasm and
then runs it then runs it
Returned is an object with the results set Returned is an object with the results set
""" """
if verbose is None:
verbose = bool(os.environ.get('VERBOSE'))
class_ = RUNNER_CLASS_MAP[runtime] class_ = RUNNER_CLASS_MAP[runtime]
runner = class_(self.code_py) runner = class_(self.code_py)
if verbose:
write_header(sys.stderr, 'Phasm') write_header(sys.stderr, 'Phasm')
runner.dump_phasm_code(sys.stderr) runner.dump_phasm_code(sys.stderr)
runner.parse() runner.parse(verbose=verbose)
runner.compile_ast() runner.compile_ast()
runner.compile_wat() runner.compile_wat()
if verbose:
write_header(sys.stderr, 'Assembly') write_header(sys.stderr, 'Assembly')
runner.dump_wasm_wat(sys.stderr) runner.dump_wasm_wat(sys.stderr)
@ -70,6 +84,7 @@ class Suite:
wasm_args: List[Union[float, int]] = [] wasm_args: List[Union[float, int]] = []
if args: if args:
if verbose:
write_header(sys.stderr, 'Memory (pre alloc)') write_header(sys.stderr, 'Memory (pre alloc)')
runner.interpreter_dump_memory(sys.stderr) runner.interpreter_dump_memory(sys.stderr)
@ -95,6 +110,7 @@ class Suite:
except NoRouteForTypeException: except NoRouteForTypeException:
raise NotImplementedError(arg_typ, arg) raise NotImplementedError(arg_typ, arg)
if verbose:
write_header(sys.stderr, 'Memory (pre run)') write_header(sys.stderr, 'Memory (pre run)')
runner.interpreter_dump_memory(sys.stderr) runner.interpreter_dump_memory(sys.stderr)
@ -107,6 +123,7 @@ class Suite:
result.returned_value, result.returned_value,
) )
if verbose:
write_header(sys.stderr, 'Memory (post run)') write_header(sys.stderr, 'Memory (post run)')
runner.interpreter_dump_memory(sys.stderr) runner.interpreter_dump_memory(sys.stderr)

View File

@ -32,12 +32,12 @@ class RunnerBase:
""" """
_dump_code(textio, self.phasm_code) _dump_code(textio, self.phasm_code)
def parse(self) -> None: def parse(self, verbose: bool = True) -> None:
""" """
Parses the Phasm code into an AST Parses the Phasm code into an AST
""" """
self.phasm_ast = phasm_parse(self.phasm_code) self.phasm_ast = phasm_parse(self.phasm_code)
phasm_type3(self.phasm_ast, verbose=True) phasm_type3(self.phasm_ast, verbose=verbose)
def compile_ast(self) -> None: def compile_ast(self) -> None:
""" """
@ -120,6 +120,8 @@ class RunnerWasmtime(RunnerBase):
if vartype is int: if vartype is int:
params.append(wasmtime.ValType.i32()) params.append(wasmtime.ValType.i32())
elif vartype is float:
params.append(wasmtime.ValType.f32())
else: else:
raise NotImplementedError raise NotImplementedError
@ -128,6 +130,8 @@ class RunnerWasmtime(RunnerBase):
pass # No return value pass # No return value
elif func.__annotations__['return'] is int: elif func.__annotations__['return'] is int:
results.append(wasmtime.ValType.i32()) results.append(wasmtime.ValType.i32())
elif func.__annotations__['return'] is float:
results.append(wasmtime.ValType.f32())
else: else:
raise NotImplementedError('Return type', func.__annotations__['return']) raise NotImplementedError('Return type', func.__annotations__['return'])

View File

@ -1,61 +0,0 @@
import pytest
from ..helpers import Suite
@pytest.mark.integration_test
def test_foldl_1():
code_py = """
def u8_or(l: u8, r: u8) -> u8:
return l | r
@exported
def testEntry(b: bytes) -> u8:
return foldl(u8_or, 128, b)
"""
suite = Suite(code_py)
result = suite.run_code(b'')
assert 128 == result.returned_value
result = suite.run_code(b'\x80')
assert 128 == result.returned_value
result = suite.run_code(b'\x80\x40')
assert 192 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10')
assert 240 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
assert 255 == result.returned_value
@pytest.mark.integration_test
def test_foldl_2():
code_py = """
def xor(l: u8, r: u8) -> u8:
return l ^ r
@exported
def testEntry(a: bytes, b: bytes) -> u8:
return foldl(xor, 0, a) ^ foldl(xor, 0, b)
"""
suite = Suite(code_py)
result = suite.run_code(b'\x55\x0F', b'\x33\x80')
assert 233 == result.returned_value
@pytest.mark.integration_test
def test_foldl_3():
code_py = """
def xor(l: u32, r: u8) -> u32:
return l ^ extend(r)
@exported
def testEntry(a: bytes) -> u32:
return foldl(xor, 0, a)
"""
suite = Suite(code_py)
result = suite.run_code(b'\x55\x0F\x33\x80')
assert 233 == result.returned_value

View File

@ -9,10 +9,13 @@ from .test_natnum import FLOAT_TYPES, INT_TYPES
@pytest.mark.integration_test @pytest.mark.integration_test
@pytest.mark.parametrize('length', [1, 5, 13]) @pytest.mark.parametrize('length', [1, 5, 13])
@pytest.mark.parametrize('a_type', INT_TYPES + FLOAT_TYPES) @pytest.mark.parametrize('a_type', INT_TYPES + FLOAT_TYPES)
def test_foldable_sum(length, a_type): @pytest.mark.parametrize('static', [True, False])
def test_foldable_sum(length, a_type, static):
typ_arg = str(length) if static else '...'
code_py = f""" code_py = f"""
@exported @exported
def testEntry(x: {a_type}[{length}]) -> {a_type}: def testEntry(x: {a_type}[{typ_arg}]) -> {a_type}:
return sum(x) return sum(x)
""" """
@ -36,6 +39,145 @@ def testEntry(x: Foo[4]) -> Foo:
with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'): with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'):
Suite(code_py).run_code() Suite(code_py).run_code()
@pytest.mark.integration_test
@pytest.mark.parametrize('length', [1, 5, 13])
@pytest.mark.parametrize('direction', ['foldl', 'foldr'])
def test_foldable_foldl_foldr_size(direction, length):
code_py = f"""
def u64_add(l: u64, r: u64) -> u64:
return l + r
@exported
def testEntry(b: u64[{length}]) -> u64:
return {direction}(u64_add, 100, b)
"""
suite = Suite(code_py)
in_put = tuple(range(1, length + 1))
result = suite.run_code(in_put)
assert (100 + sum(in_put)) == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('direction', ['foldl', 'foldr'])
def test_foldable_foldl_foldr_compounded_type(direction):
code_py = f"""
def combine_foldl(b: u64, a: (u32, u32, )) -> u64:
return extend(a[0] * a[1]) + b
def combine_foldr(a: (u32, u32, ), b: u64) -> u64:
return extend(a[0] * a[1]) + b
@exported
def testEntry(b: (u32, u32, )[3]) -> u64:
return {direction}(combine_{direction}, 10000, b)
"""
suite = Suite(code_py)
result = suite.run_code(((2, 5), (25, 4), (125, 8)))
assert 11110 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('in_put, direction, exp_result', [
([], 'foldl', 0, ),
([], 'foldr', 0, ),
([1], 'foldl', -1, ),
([1], 'foldr', 1, ),
([1,2], 'foldl', -3, ),
([1,2], 'foldr', -1, ),
([1,2,3], 'foldl', -6, ),
([1,2,3], 'foldr', 2, ),
([1,2,3,4], 'foldl', -10, ),
([1,2,3,4], 'foldr', -2, ),
([1,2,3,4,5], 'foldl', -15, ),
([1,2,3,4,5], 'foldr', 3, ),
([1,2,3,4,5,6], 'foldl', -21, ),
([1,2,3,4,5,6], 'foldr', -3, ),
([1,2,3,4,5,6,7], 'foldl', -28, ),
([1,2,3,4,5,6,7], 'foldr', 4, ),
([1,2,3,4,5,6,7,8], 'foldl', -36, ),
([1,2,3,4,5,6,7,8], 'foldr', -4, ),
])
@pytest.mark.parametrize('static', [True, False])
def test_foldable_foldl_foldr_result(direction, in_put, exp_result, static):
typ_arg = str(len(in_put)) if static else '...'
# See https://stackoverflow.com/a/13280185
code_py = f"""
def i32_sub(l: i32, r: i32) -> i32:
return l - r
@exported
def testEntry(b: i32[{typ_arg}]) -> i32:
return {direction}(i32_sub, 0, b)
"""
suite = Suite(code_py)
result = suite.run_code(tuple(in_put))
assert exp_result == result.returned_value
@pytest.mark.integration_test
def test_foldable_foldl_bytes():
code_py = """
def u8_or(l: u8, r: u8) -> u8:
return l | r
@exported
def testEntry(b: bytes) -> u8:
return foldl(u8_or, 0, b)
"""
suite = Suite(code_py)
result = suite.run_code(b'')
assert 0 == result.returned_value
result = suite.run_code(b'\x80')
assert 128 == result.returned_value
result = suite.run_code(b'\x80\x40')
assert 192 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10')
assert 240 == result.returned_value
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
assert 255 == result.returned_value
@pytest.mark.integration_test
@pytest.mark.parametrize('in_typ', ['i8', 'i8[3]'])
def test_foldable_argument_must_be_a_function(in_typ):
code_py = f"""
@exported
def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32:
return foldl(x, y, z)
"""
r_in_typ = in_typ.replace('[', '\\[').replace(']', '\\]')
with pytest.raises(Type3Exception, match=f'{r_in_typ} must be a function instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test
def test_foldable_argument_must_be_right_function():
code_py = """
def foo(l: i32, r: i64) -> i64:
return extend(l) + r
@exported
def testEntry(i: i64, l: i64[3]) -> i64:
return foldr(foo, i, l)
"""
with pytest.raises(Type3Exception, match=r'Callable\[i64, i64, i64\] must be Callable\[i32, i64, i64\] instead'):
Suite(code_py).run_code()
@pytest.mark.integration_test @pytest.mark.integration_test
def test_foldable_invalid_return_type(): def test_foldable_invalid_return_type():