Removes the special casing for foldl

Has to implement both functions as arguments and type
place holders (variables) for type constructors.
This commit is contained in:
Johan B.W. de Vries 2025-04-27 12:54:34 +02:00
parent ac4b46bbe7
commit 23f9e60378
7 changed files with 8 additions and 168 deletions

View File

@ -102,10 +102,6 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member}'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements:

View File

@ -4,7 +4,7 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
import struct
from typing import List, Optional
from . import codestyle, ourlang, prelude, wasm
from . import ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types
@ -376,90 +376,8 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
)))
return
if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp)
return
raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
"""
Compile: Fold expression
"""
assert inp.type3 is not None, TYPE3_ASSERTION_ERROR
if inp.iter.type3 is not prelude.bytes_:
raise NotImplementedError(expression_fold, inp, inp.iter.type3)
wgn.add_statement('nop', comment='acu :: u8')
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
wgn.add_statement('nop', comment='adr :: bytes*')
adr_var = wgn.temp_var_i32('fold_i32_adr')
wgn.add_statement('nop', comment='len :: i32')
len_var = wgn.temp_var_i32('fold_i32_len')
wgn.add_statement('nop', comment='acu = base')
expression(wgn, inp.base)
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='adr = adr(iter)')
expression(wgn, inp.iter)
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='len = len(iter)')
wgn.local.get(adr_var)
wgn.i32.load()
wgn.local.set(len_var)
wgn.add_statement('nop', comment='i = 0')
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
wgn.i32.const(0)
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i < len')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
with wgn.if_():
# From here on, adr_var is the address of byte we're referencing
# This is akin to calling stdlib_types.__subscript_bytes__
# But since we already know we are inside of bounds,
# can just bypass it and load the memory directly.
wgn.local.get(adr_var)
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
wgn.i32.add()
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='while True')
with wgn.loop():
wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
wgn.local.get(acu_var)
# Get the next byte, write back the address
wgn.local.get(adr_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.tee(adr_var)
wgn.i32.load8_u()
wgn.add_statement('call', f'${inp.func.name}')
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='i = i + 1')
wgn.local.get(idx_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i >= len: break')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
wgn.br_if(0)
# return acu
wgn.local.get(acu_var)
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
"""
Compile: Return statement

View File

@ -1,7 +1,6 @@
"""
Contains the syntax tree for ourlang
"""
import enum
from typing import Dict, Iterable, List, Optional, Union
from . import prelude
@ -207,36 +206,6 @@ class AccessStructMember(Expression):
self.struct_type3 = struct_type3
self.member = member
class Fold(Expression):
"""
A (left or right) fold
"""
class Direction(enum.Enum):
"""
Which direction to fold in
"""
LEFT = 0
RIGHT = 1
dir: Direction
func: 'Function'
base: Expression
iter: Expression
def __init__(
self,
dir_: Direction,
func: 'Function',
base: Expression,
iter_: Expression,
) -> None:
super().__init__()
self.dir = dir_
self.func = func
self.base = base
self.iter = iter_
class Statement:
"""
A statement within a function

View File

@ -14,7 +14,6 @@ from .ourlang import (
ConstantStruct,
ConstantTuple,
Expression,
Fold,
Function,
FunctionCall,
FunctionParam,
@ -462,7 +461,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]:
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -475,28 +474,6 @@ class OurVisitor:
if node.func.id in PRELUDE_METHODS:
func = PRELUDE_METHODS[node.func.id]
elif node.func.id == 'foldl':
if 3 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# TODO: This is not generic, you cannot return a function
subnode = node.args[0]
if not isinstance(subnode, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {subnode}')
if not isinstance(subnode.ctx, ast.Load):
_raise_static_error(subnode, 'Must be load context')
if subnode.id not in module.functions:
_raise_static_error(subnode, 'Reference to undefined function')
func = module.functions[subnode.id]
if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
return Fold(
Fold.Direction.LEFT,
func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
)
else:
if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function')

View File

@ -574,6 +574,7 @@ instance_type_class(Promotable, f32, f64, methods={
Foldable = Type3Class('Foldable', (t, ), methods={
'sum': [t(a), a],
'foldl': [[a, b, b], b, t(a), b],
}, operators={}, additional_context={
'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))],
})

View File

@ -264,14 +264,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
__slots__ = ('context', 'type_class3', 'types', )
context: Context
type_class3: Union[str, Type3Class]
type_class3: Type3Class
types: list[Type3OrPlaceholder]
DATA = {
'bytes': {'Foldable'},
}
def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
super().__init__(comment=comment)
self.context = context
@ -299,13 +295,9 @@ class MustImplementTypeClassConstraint(ConstraintBase):
assert len(typ_list) == len(self.types)
if isinstance(self.type_class3, Type3Class):
key = (self.type_class3, tuple(typ_list), )
if key in self.context.type_class_instances_existing:
return None
else:
if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()):
return None
typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name
typ_name_list = ' '.join(x.name for x in typ_list)

View File

@ -209,19 +209,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
return
if isinstance(inp, ourlang.Fold):
base_phft = PlaceholderForType([inp.base])
iter_phft = PlaceholderForType([inp.iter])
yield from expression(ctx, inp.base, base_phft)
yield from expression(ctx, inp.iter, iter_phft)
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft,
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft])
return
raise NotImplementedError(expression, inp)
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: