Removes the special casing for foldl

Has to implement both functions as arguments and type
place holders (variables) for type constructors.
This commit is contained in:
Johan B.W. de Vries 2025-04-27 12:54:34 +02:00
parent ac4b46bbe7
commit 23f9e60378
7 changed files with 8 additions and 168 deletions

View File

@ -102,10 +102,6 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member}' return f'{expression(inp.varref)}.{inp.member}'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements: def statement(inp: ourlang.Statement) -> Statements:

View File

@ -4,7 +4,7 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
import struct import struct
from typing import List, Optional from typing import List, Optional
from . import codestyle, ourlang, prelude, wasm from . import ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types from .stdlib import types as stdlib_types
@ -376,90 +376,8 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
))) )))
return return
if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp)
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
"""
Compile: Fold expression
"""
assert inp.type3 is not None, TYPE3_ASSERTION_ERROR
if inp.iter.type3 is not prelude.bytes_:
raise NotImplementedError(expression_fold, inp, inp.iter.type3)
wgn.add_statement('nop', comment='acu :: u8')
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
wgn.add_statement('nop', comment='adr :: bytes*')
adr_var = wgn.temp_var_i32('fold_i32_adr')
wgn.add_statement('nop', comment='len :: i32')
len_var = wgn.temp_var_i32('fold_i32_len')
wgn.add_statement('nop', comment='acu = base')
expression(wgn, inp.base)
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='adr = adr(iter)')
expression(wgn, inp.iter)
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='len = len(iter)')
wgn.local.get(adr_var)
wgn.i32.load()
wgn.local.set(len_var)
wgn.add_statement('nop', comment='i = 0')
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
wgn.i32.const(0)
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i < len')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
with wgn.if_():
# From here on, adr_var is the address of byte we're referencing
# This is akin to calling stdlib_types.__subscript_bytes__
# But since we already know we are inside of bounds,
# can just bypass it and load the memory directly.
wgn.local.get(adr_var)
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
wgn.i32.add()
wgn.local.set(adr_var)
wgn.add_statement('nop', comment='while True')
with wgn.loop():
wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
wgn.local.get(acu_var)
# Get the next byte, write back the address
wgn.local.get(adr_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.tee(adr_var)
wgn.i32.load8_u()
wgn.add_statement('call', f'${inp.func.name}')
wgn.local.set(acu_var)
wgn.add_statement('nop', comment='i = i + 1')
wgn.local.get(idx_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i >= len: break')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
wgn.br_if(0)
# return acu
wgn.local.get(acu_var)
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None: def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
""" """
Compile: Return statement Compile: Return statement

View File

@ -1,7 +1,6 @@
""" """
Contains the syntax tree for ourlang Contains the syntax tree for ourlang
""" """
import enum
from typing import Dict, Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from . import prelude from . import prelude
@ -207,36 +206,6 @@ class AccessStructMember(Expression):
self.struct_type3 = struct_type3 self.struct_type3 = struct_type3
self.member = member self.member = member
class Fold(Expression):
"""
A (left or right) fold
"""
class Direction(enum.Enum):
"""
Which direction to fold in
"""
LEFT = 0
RIGHT = 1
dir: Direction
func: 'Function'
base: Expression
iter: Expression
def __init__(
self,
dir_: Direction,
func: 'Function',
base: Expression,
iter_: Expression,
) -> None:
super().__init__()
self.dir = dir_
self.func = func
self.base = base
self.iter = iter_
class Statement: class Statement:
""" """
A statement within a function A statement within a function

View File

@ -14,7 +14,6 @@ from .ourlang import (
ConstantStruct, ConstantStruct,
ConstantTuple, ConstantTuple,
Expression, Expression,
Fold,
Function, Function,
FunctionCall, FunctionCall,
FunctionParam, FunctionParam,
@ -462,7 +461,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as expr in FunctionDef') raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]: def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
if node.keywords: if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet? _raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -475,28 +474,6 @@ class OurVisitor:
if node.func.id in PRELUDE_METHODS: if node.func.id in PRELUDE_METHODS:
func = PRELUDE_METHODS[node.func.id] func = PRELUDE_METHODS[node.func.id]
elif node.func.id == 'foldl':
if 3 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# TODO: This is not generic, you cannot return a function
subnode = node.args[0]
if not isinstance(subnode, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {subnode}')
if not isinstance(subnode.ctx, ast.Load):
_raise_static_error(subnode, 'Must be load context')
if subnode.id not in module.functions:
_raise_static_error(subnode, 'Reference to undefined function')
func = module.functions[subnode.id]
if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
return Fold(
Fold.Direction.LEFT,
func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
)
else: else:
if node.func.id not in module.functions: if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function') _raise_static_error(node, 'Call to undefined function')

View File

@ -574,6 +574,7 @@ instance_type_class(Promotable, f32, f64, methods={
Foldable = Type3Class('Foldable', (t, ), methods={ Foldable = Type3Class('Foldable', (t, ), methods={
'sum': [t(a), a], 'sum': [t(a), a],
'foldl': [[a, b, b], b, t(a), b],
}, operators={}, additional_context={ }, operators={}, additional_context={
'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))], 'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))],
}) })

View File

@ -264,14 +264,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
__slots__ = ('context', 'type_class3', 'types', ) __slots__ = ('context', 'type_class3', 'types', )
context: Context context: Context
type_class3: Union[str, Type3Class] type_class3: Type3Class
types: list[Type3OrPlaceholder] types: list[Type3OrPlaceholder]
DATA = { def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
'bytes': {'Foldable'},
}
def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
super().__init__(comment=comment) super().__init__(comment=comment)
self.context = context self.context = context
@ -299,13 +295,9 @@ class MustImplementTypeClassConstraint(ConstraintBase):
assert len(typ_list) == len(self.types) assert len(typ_list) == len(self.types)
if isinstance(self.type_class3, Type3Class): key = (self.type_class3, tuple(typ_list), )
key = (self.type_class3, tuple(typ_list), ) if key in self.context.type_class_instances_existing:
if key in self.context.type_class_instances_existing: return None
return None
else:
if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()):
return None
typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name
typ_name_list = ' '.join(x.name for x in typ_list) typ_name_list = ' '.join(x.name for x in typ_list)

View File

@ -209,19 +209,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}') comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
return return
if isinstance(inp, ourlang.Fold):
base_phft = PlaceholderForType([inp.base])
iter_phft = PlaceholderForType([inp.iter])
yield from expression(ctx, inp.base, base_phft)
yield from expression(ctx, inp.iter, iter_phft)
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft,
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft])
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: