Removes the special casing for foldl

Has to implement both functions as arguments and type
place holders (variables) for type constructors.
This commit is contained in:
Johan B.W. de Vries 2025-04-27 12:54:34 +02:00
parent d3e38b96b2
commit b40a6a4cdd
8 changed files with 131 additions and 156 deletions

View File

@ -116,10 +116,6 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.AccessStructMember): if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member}' return f'{expression(inp.varref)}.{inp.member}'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements: def statement(inp: ourlang.Statement) -> Statements:

View File

@ -4,7 +4,7 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
import struct import struct
from typing import Dict, List, Optional from typing import Dict, List, Optional
from . import codestyle, ourlang, prelude, wasm from . import ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types from .stdlib import types as stdlib_types
@ -587,89 +587,85 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
))) )))
return return
if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp)
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: # def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
""" # """
Compile: Fold expression # Compile: Fold expression
""" # """
assert isinstance(inp.type3, type3types.Type3), type3placeholders.TYPE3_ASSERTION_ERROR # assert isinstance(inp.type3, type3types.Type3), type3placeholders.TYPE3_ASSERTION_ERROR
if inp.iter.type3 is not prelude.bytes_: # if inp.iter.type3 is not prelude.bytes_:
raise NotImplementedError(expression_fold, inp, inp.iter.type3) # raise NotImplementedError(expression_fold, inp, inp.iter.type3)
wgn.add_statement('nop', comment='acu :: u8') # wgn.add_statement('nop', comment='acu :: u8')
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu') # acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
wgn.add_statement('nop', comment='adr :: bytes*') # wgn.add_statement('nop', comment='adr :: bytes*')
adr_var = wgn.temp_var_i32('fold_i32_adr') # adr_var = wgn.temp_var_i32('fold_i32_adr')
wgn.add_statement('nop', comment='len :: i32') # wgn.add_statement('nop', comment='len :: i32')
len_var = wgn.temp_var_i32('fold_i32_len') # len_var = wgn.temp_var_i32('fold_i32_len')
wgn.add_statement('nop', comment='acu = base') # wgn.add_statement('nop', comment='acu = base')
expression(wgn, inp.base) # expression(wgn, inp.base)
wgn.local.set(acu_var) # wgn.local.set(acu_var)
wgn.add_statement('nop', comment='adr = adr(iter)') # wgn.add_statement('nop', comment='adr = adr(iter)')
expression(wgn, inp.iter) # expression(wgn, inp.iter)
wgn.local.set(adr_var) # wgn.local.set(adr_var)
wgn.add_statement('nop', comment='len = len(iter)') # wgn.add_statement('nop', comment='len = len(iter)')
wgn.local.get(adr_var) # wgn.local.get(adr_var)
wgn.i32.load() # wgn.i32.load()
wgn.local.set(len_var) # wgn.local.set(len_var)
wgn.add_statement('nop', comment='i = 0') # wgn.add_statement('nop', comment='i = 0')
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx') # idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
wgn.i32.const(0) # wgn.i32.const(0)
wgn.local.set(idx_var) # wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i < len') # wgn.add_statement('nop', comment='if i < len')
wgn.local.get(idx_var) # wgn.local.get(idx_var)
wgn.local.get(len_var) # wgn.local.get(len_var)
wgn.i32.lt_u() # wgn.i32.lt_u()
with wgn.if_(): # with wgn.if_():
# From here on, adr_var is the address of byte we're referencing # # From here on, adr_var is the address of byte we're referencing
# This is akin to calling stdlib_types.__subscript_bytes__ # # This is akin to calling stdlib_types.__subscript_bytes__
# But since we already know we are inside of bounds, # # But since we already know we are inside of bounds,
# can just bypass it and load the memory directly. # # can just bypass it and load the memory directly.
wgn.local.get(adr_var) # wgn.local.get(adr_var)
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop # wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
wgn.i32.add() # wgn.i32.add()
wgn.local.set(adr_var) # wgn.local.set(adr_var)
wgn.add_statement('nop', comment='while True') # wgn.add_statement('nop', comment='while True')
with wgn.loop(): # with wgn.loop():
wgn.add_statement('nop', comment='acu = func(acu, iter[i])') # wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
wgn.local.get(acu_var) # wgn.local.get(acu_var)
# Get the next byte, write back the address # # Get the next byte, write back the address
wgn.local.get(adr_var) # wgn.local.get(adr_var)
wgn.i32.const(1) # wgn.i32.const(1)
wgn.i32.add() # wgn.i32.add()
wgn.local.tee(adr_var) # wgn.local.tee(adr_var)
wgn.i32.load8_u() # wgn.i32.load8_u()
wgn.add_statement('call', f'${inp.func.name}') # wgn.add_statement('call', f'${inp.func.name}')
wgn.local.set(acu_var) # wgn.local.set(acu_var)
wgn.add_statement('nop', comment='i = i + 1') # wgn.add_statement('nop', comment='i = i + 1')
wgn.local.get(idx_var) # wgn.local.get(idx_var)
wgn.i32.const(1) # wgn.i32.const(1)
wgn.i32.add() # wgn.i32.add()
wgn.local.set(idx_var) # wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i >= len: break') # wgn.add_statement('nop', comment='if i >= len: break')
wgn.local.get(idx_var) # wgn.local.get(idx_var)
wgn.local.get(len_var) # wgn.local.get(len_var)
wgn.i32.lt_u() # wgn.i32.lt_u()
wgn.br_if(0) # wgn.br_if(0)
# return acu # # return acu
wgn.local.get(acu_var) # wgn.local.get(acu_var)
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None: def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
""" """

View File

@ -1,7 +1,6 @@
""" """
Contains the syntax tree for ourlang Contains the syntax tree for ourlang
""" """
import enum
from typing import Dict, Iterable, List, Optional, Union from typing import Dict, Iterable, List, Optional, Union
from . import prelude from . import prelude
@ -222,36 +221,6 @@ class AccessStructMember(Expression):
self.struct_type3 = struct_type3 self.struct_type3 = struct_type3
self.member = member self.member = member
class Fold(Expression):
"""
A (left or right) fold
"""
class Direction(enum.Enum):
"""
Which direction to fold in
"""
LEFT = 0
RIGHT = 1
dir: Direction
func: 'Function'
base: Expression
iter: Expression
def __init__(
self,
dir_: Direction,
func: 'Function',
base: Expression,
iter_: Expression,
) -> None:
super().__init__()
self.dir = dir_
self.func = func
self.base = base
self.iter = iter_
class Statement: class Statement:
""" """
A statement within a function A statement within a function

View File

@ -14,7 +14,6 @@ from .ourlang import (
ConstantStruct, ConstantStruct,
ConstantTuple, ConstantTuple,
Expression, Expression,
Fold,
Function, Function,
FunctionCall, FunctionCall,
FunctionParam, FunctionParam,
@ -467,7 +466,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as expr in FunctionDef') raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]: def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall, UnaryOp]:
if node.keywords: if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet? _raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -490,28 +489,28 @@ class OurVisitor:
) )
unary_op.type3 = prelude.u32 unary_op.type3 = prelude.u32
return unary_op return unary_op
elif node.func.id == 'foldl': # elif node.func.id == 'foldl':
if 3 != len(node.args): # if 3 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given') # _raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# TODO: This is not generic, you cannot return a function # # TODO: This is not generic, you cannot return a function
subnode = node.args[0] # subnode = node.args[0]
if not isinstance(subnode, ast.Name): # if not isinstance(subnode, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {subnode}') # raise NotImplementedError(f'Calling methods that are not a name {subnode}')
if not isinstance(subnode.ctx, ast.Load): # if not isinstance(subnode.ctx, ast.Load):
_raise_static_error(subnode, 'Must be load context') # _raise_static_error(subnode, 'Must be load context')
if subnode.id not in module.functions: # if subnode.id not in module.functions:
_raise_static_error(subnode, 'Reference to undefined function') # _raise_static_error(subnode, 'Reference to undefined function')
func = module.functions[subnode.id] # func = module.functions[subnode.id]
if 2 != len(func.posonlyargs): # if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') # _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
return Fold( # return Fold(
Fold.Direction.LEFT, # Fold.Direction.LEFT,
func, # func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]), # self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]), # self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
) # )
else: else:
if node.func.id not in module.functions: if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function') _raise_static_error(node, 'Call to undefined function')

View File

@ -2,7 +2,12 @@
The prelude are all the builtin types, type classes and methods The prelude are all the builtin types, type classes and methods
""" """
from ..type3.typeclasses import Type3Class, TypeVariable, instance_type_class from ..type3.typeclasses import (
Type3Class,
TypeConstructorVariable,
TypeVariable,
instance_type_class,
)
from ..type3.types import ( from ..type3.types import (
Type3, Type3,
TypeConstructor_StaticArray, TypeConstructor_StaticArray,
@ -121,7 +126,8 @@ PRELUDE_TYPES: dict[str, Type3] = {
} }
a = TypeVariable('a') a = TypeVariable('a')
b = TypeVariable('b')
t = TypeConstructorVariable('t')
InternalPassAsPointer = Type3Class('InternalPassAsPointer', [a], methods={}, operators={}) InternalPassAsPointer = Type3Class('InternalPassAsPointer', [a], methods={}, operators={})
""" """
@ -247,6 +253,10 @@ Sized_ = Type3Class('Sized', [a], methods={
instance_type_class(Sized_, bytes_) instance_type_class(Sized_, bytes_)
Foldable = Type3Class('Foldable', [t], methods={
'foldl': [[a, b, b], b, t(a), b],
}, operators={})
PRELUDE_TYPE_CLASSES = { PRELUDE_TYPE_CLASSES = {
'Eq': Eq, 'Eq': Eq,
'Ord': Ord, 'Ord': Ord,

View File

@ -241,14 +241,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
""" """
__slots__ = ('type_class3', 'type3', ) __slots__ = ('type_class3', 'type3', )
type_class3: Union[str, typeclasses.Type3Class] type_class3: typeclasses.Type3Class
type3: placeholders.Type3OrPlaceholder type3: placeholders.Type3OrPlaceholder
DATA = { def __init__(self, type_class3: typeclasses.Type3Class, type3: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
'bytes': {'Foldable'},
}
def __init__(self, type_class3: Union[str, typeclasses.Type3Class], type3: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
super().__init__(comment=comment) super().__init__(comment=comment)
self.type_class3 = type_class3 self.type_class3 = type_class3
@ -262,12 +258,8 @@ class MustImplementTypeClassConstraint(ConstraintBase):
if isinstance(typ, placeholders.PlaceholderForType): if isinstance(typ, placeholders.PlaceholderForType):
return RequireTypeSubstitutes() return RequireTypeSubstitutes()
if isinstance(self.type_class3, typeclasses.Type3Class):
if self.type_class3 in typ.classes: if self.type_class3 in typ.classes:
return None return None
else:
if self.type_class3 in self.__class__.DATA.get(typ.name, set()):
return None
return Error(f'{typ.name} does not implement the {self.type_class3} type class') return Error(f'{typ.name} does not implement the {self.type_class3} type class')

View File

@ -161,16 +161,6 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator:
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}') comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
return return
if isinstance(inp, ourlang.Fold):
yield from expression(ctx, inp.base)
yield from expression(ctx, inp.iter)
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, inp.base.type3, inp.type3,
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
yield MustImplementTypeClassConstraint('Foldable', inp.iter.type3)
return
raise NotImplementedError(expression, inp) raise NotImplementedError(expression, inp)
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator:

View File

@ -24,6 +24,27 @@ class TypeVariable:
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TypeVariable({repr(self.letter)})' return f'TypeVariable({repr(self.letter)})'
class TypeConstructorVariable:
__slots__ = ('letter', )
letter: str
def __init__(self, letter: str) -> None:
assert len(letter) == 1, f'{letter} is not a valid type variable'
self.letter = letter
def __hash__(self) -> int:
return hash(self.letter)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, TypeVariable):
raise NotImplementedError
return self.letter == other.letter
def __repr__(self) -> str:
return f'TypeVariable({repr(self.letter)})'
class TypeReference: class TypeReference:
__slots__ = ('name', ) __slots__ = ('name', )
@ -45,14 +66,16 @@ class TypeReference:
def __repr__(self) -> str: def __repr__(self) -> str:
return f'TypeReference({repr(self.name)})' return f'TypeReference({repr(self.name)})'
Signature = List[Union[Type3, TypeVariable, list[TypeVariable]]]
class Type3ClassMethod: class Type3ClassMethod:
__slots__ = ('type3_class', 'name', 'signature', ) __slots__ = ('type3_class', 'name', 'signature', )
type3_class: 'Type3Class' type3_class: 'Type3Class'
name: str name: str
signature: List[Union[Type3, TypeVariable]] signature: Signature
def __init__(self, type3_class: 'Type3Class', name: str, signature: Iterable[Union[Type3, TypeVariable]]) -> None: def __init__(self, type3_class: 'Type3Class', name: str, signature: Signature) -> None:
self.type3_class = type3_class self.type3_class = type3_class
self.name = name self.name = name
self.signature = list(signature) self.signature = list(signature)
@ -64,7 +87,7 @@ class Type3Class:
__slots__ = ('name', 'args', 'methods', 'operators', 'inherited_classes', ) __slots__ = ('name', 'args', 'methods', 'operators', 'inherited_classes', )
name: str name: str
args: List[TypeVariable] args: List[Union[TypeVariable, TypeConstructorVariable]]
methods: Dict[str, Type3ClassMethod] methods: Dict[str, Type3ClassMethod]
operators: Dict[str, Type3ClassMethod] operators: Dict[str, Type3ClassMethod]
inherited_classes: List['Type3Class'] inherited_classes: List['Type3Class']
@ -72,9 +95,9 @@ class Type3Class:
def __init__( def __init__(
self, self,
name: str, name: str,
args: Iterable[TypeVariable], args: Iterable[Union[TypeVariable, TypeConstructorVariable]],
methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]], methods: Mapping[str, Signature],
operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]], operators: Mapping[str, Signature],
inherited_classes: Optional[List['Type3Class']] = None, inherited_classes: Optional[List['Type3Class']] = None,
) -> None: ) -> None:
self.name = name self.name = name