Removes the special casing for foldl

Has to implement both functions as arguments and type
place holders (variables) for type constructors.
This commit is contained in:
Johan B.W. de Vries 2025-04-27 12:54:34 +02:00
parent d3e38b96b2
commit b40a6a4cdd
8 changed files with 131 additions and 156 deletions

View File

@ -116,10 +116,6 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member}'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements:

View File

@ -4,7 +4,7 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
import struct
from typing import Dict, List, Optional
from . import codestyle, ourlang, prelude, wasm
from . import ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types
@ -587,89 +587,85 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
)))
return
if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp)
return
raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
"""
Compile: Fold expression
"""
assert isinstance(inp.type3, type3types.Type3), type3placeholders.TYPE3_ASSERTION_ERROR
# def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
# """
# Compile: Fold expression
# """
# assert isinstance(inp.type3, type3types.Type3), type3placeholders.TYPE3_ASSERTION_ERROR
if inp.iter.type3 is not prelude.bytes_:
raise NotImplementedError(expression_fold, inp, inp.iter.type3)
# if inp.iter.type3 is not prelude.bytes_:
# raise NotImplementedError(expression_fold, inp, inp.iter.type3)
wgn.add_statement('nop', comment='acu :: u8')
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
wgn.add_statement('nop', comment='adr :: bytes*')
adr_var = wgn.temp_var_i32('fold_i32_adr')
wgn.add_statement('nop', comment='len :: i32')
len_var = wgn.temp_var_i32('fold_i32_len')
# wgn.add_statement('nop', comment='acu :: u8')
# acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
# wgn.add_statement('nop', comment='adr :: bytes*')
# adr_var = wgn.temp_var_i32('fold_i32_adr')
# wgn.add_statement('nop', comment='len :: i32')
# len_var = wgn.temp_var_i32('fold_i32_len')
wgn.add_statement('nop', comment='acu = base')
expression(wgn, inp.base)
wgn.local.set(acu_var)
# wgn.add_statement('nop', comment='acu = base')
# expression(wgn, inp.base)
# wgn.local.set(acu_var)
wgn.add_statement('nop', comment='adr = adr(iter)')
expression(wgn, inp.iter)
wgn.local.set(adr_var)
# wgn.add_statement('nop', comment='adr = adr(iter)')
# expression(wgn, inp.iter)
# wgn.local.set(adr_var)
wgn.add_statement('nop', comment='len = len(iter)')
wgn.local.get(adr_var)
wgn.i32.load()
wgn.local.set(len_var)
# wgn.add_statement('nop', comment='len = len(iter)')
# wgn.local.get(adr_var)
# wgn.i32.load()
# wgn.local.set(len_var)
wgn.add_statement('nop', comment='i = 0')
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
wgn.i32.const(0)
wgn.local.set(idx_var)
# wgn.add_statement('nop', comment='i = 0')
# idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
# wgn.i32.const(0)
# wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i < len')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
with wgn.if_():
# From here on, adr_var is the address of byte we're referencing
# This is akin to calling stdlib_types.__subscript_bytes__
# But since we already know we are inside of bounds,
# can just bypass it and load the memory directly.
wgn.local.get(adr_var)
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
wgn.i32.add()
wgn.local.set(adr_var)
# wgn.add_statement('nop', comment='if i < len')
# wgn.local.get(idx_var)
# wgn.local.get(len_var)
# wgn.i32.lt_u()
# with wgn.if_():
# # From here on, adr_var is the address of byte we're referencing
# # This is akin to calling stdlib_types.__subscript_bytes__
# # But since we already know we are inside of bounds,
# # can just bypass it and load the memory directly.
# wgn.local.get(adr_var)
# wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
# wgn.i32.add()
# wgn.local.set(adr_var)
wgn.add_statement('nop', comment='while True')
with wgn.loop():
wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
wgn.local.get(acu_var)
# wgn.add_statement('nop', comment='while True')
# with wgn.loop():
# wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
# wgn.local.get(acu_var)
# Get the next byte, write back the address
wgn.local.get(adr_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.tee(adr_var)
wgn.i32.load8_u()
# # Get the next byte, write back the address
# wgn.local.get(adr_var)
# wgn.i32.const(1)
# wgn.i32.add()
# wgn.local.tee(adr_var)
# wgn.i32.load8_u()
wgn.add_statement('call', f'${inp.func.name}')
wgn.local.set(acu_var)
# wgn.add_statement('call', f'${inp.func.name}')
# wgn.local.set(acu_var)
wgn.add_statement('nop', comment='i = i + 1')
wgn.local.get(idx_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.set(idx_var)
# wgn.add_statement('nop', comment='i = i + 1')
# wgn.local.get(idx_var)
# wgn.i32.const(1)
# wgn.i32.add()
# wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i >= len: break')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
wgn.br_if(0)
# wgn.add_statement('nop', comment='if i >= len: break')
# wgn.local.get(idx_var)
# wgn.local.get(len_var)
# wgn.i32.lt_u()
# wgn.br_if(0)
# return acu
wgn.local.get(acu_var)
# # return acu
# wgn.local.get(acu_var)
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
"""

View File

@ -1,7 +1,6 @@
"""
Contains the syntax tree for ourlang
"""
import enum
from typing import Dict, Iterable, List, Optional, Union
from . import prelude
@ -222,36 +221,6 @@ class AccessStructMember(Expression):
self.struct_type3 = struct_type3
self.member = member
class Fold(Expression):
"""
A (left or right) fold
"""
class Direction(enum.Enum):
"""
Which direction to fold in
"""
LEFT = 0
RIGHT = 1
dir: Direction
func: 'Function'
base: Expression
iter: Expression
def __init__(
self,
dir_: Direction,
func: 'Function',
base: Expression,
iter_: Expression,
) -> None:
super().__init__()
self.dir = dir_
self.func = func
self.base = base
self.iter = iter_
class Statement:
"""
A statement within a function

View File

@ -14,7 +14,6 @@ from .ourlang import (
ConstantStruct,
ConstantTuple,
Expression,
Fold,
Function,
FunctionCall,
FunctionParam,
@ -467,7 +466,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]:
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall, UnaryOp]:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -490,28 +489,28 @@ class OurVisitor:
)
unary_op.type3 = prelude.u32
return unary_op
elif node.func.id == 'foldl':
if 3 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# elif node.func.id == 'foldl':
# if 3 != len(node.args):
# _raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# TODO: This is not generic, you cannot return a function
subnode = node.args[0]
if not isinstance(subnode, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {subnode}')
if not isinstance(subnode.ctx, ast.Load):
_raise_static_error(subnode, 'Must be load context')
if subnode.id not in module.functions:
_raise_static_error(subnode, 'Reference to undefined function')
func = module.functions[subnode.id]
if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
# # TODO: This is not generic, you cannot return a function
# subnode = node.args[0]
# if not isinstance(subnode, ast.Name):
# raise NotImplementedError(f'Calling methods that are not a name {subnode}')
# if not isinstance(subnode.ctx, ast.Load):
# _raise_static_error(subnode, 'Must be load context')
# if subnode.id not in module.functions:
# _raise_static_error(subnode, 'Reference to undefined function')
# func = module.functions[subnode.id]
# if 2 != len(func.posonlyargs):
# _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
return Fold(
Fold.Direction.LEFT,
func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
)
# return Fold(
# Fold.Direction.LEFT,
# func,
# self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
# self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
# )
else:
if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function')

View File

@ -2,7 +2,12 @@
The prelude are all the builtin types, type classes and methods
"""
from ..type3.typeclasses import Type3Class, TypeVariable, instance_type_class
from ..type3.typeclasses import (
Type3Class,
TypeConstructorVariable,
TypeVariable,
instance_type_class,
)
from ..type3.types import (
Type3,
TypeConstructor_StaticArray,
@ -121,7 +126,8 @@ PRELUDE_TYPES: dict[str, Type3] = {
}
a = TypeVariable('a')
b = TypeVariable('b')
t = TypeConstructorVariable('t')
InternalPassAsPointer = Type3Class('InternalPassAsPointer', [a], methods={}, operators={})
"""
@ -247,6 +253,10 @@ Sized_ = Type3Class('Sized', [a], methods={
instance_type_class(Sized_, bytes_)
Foldable = Type3Class('Foldable', [t], methods={
'foldl': [[a, b, b], b, t(a), b],
}, operators={})
PRELUDE_TYPE_CLASSES = {
'Eq': Eq,
'Ord': Ord,

View File

@ -241,14 +241,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
"""
__slots__ = ('type_class3', 'type3', )
type_class3: Union[str, typeclasses.Type3Class]
type_class3: typeclasses.Type3Class
type3: placeholders.Type3OrPlaceholder
DATA = {
'bytes': {'Foldable'},
}
def __init__(self, type_class3: Union[str, typeclasses.Type3Class], type3: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
def __init__(self, type_class3: typeclasses.Type3Class, type3: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
super().__init__(comment=comment)
self.type_class3 = type_class3
@ -262,12 +258,8 @@ class MustImplementTypeClassConstraint(ConstraintBase):
if isinstance(typ, placeholders.PlaceholderForType):
return RequireTypeSubstitutes()
if isinstance(self.type_class3, typeclasses.Type3Class):
if self.type_class3 in typ.classes:
return None
else:
if self.type_class3 in self.__class__.DATA.get(typ.name, set()):
return None
if self.type_class3 in typ.classes:
return None
return Error(f'{typ.name} does not implement the {self.type_class3} type class')

View File

@ -161,16 +161,6 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator:
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
return
if isinstance(inp, ourlang.Fold):
yield from expression(ctx, inp.base)
yield from expression(ctx, inp.iter)
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, inp.base.type3, inp.type3,
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
yield MustImplementTypeClassConstraint('Foldable', inp.iter.type3)
return
raise NotImplementedError(expression, inp)
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator:

View File

@ -24,6 +24,27 @@ class TypeVariable:
def __repr__(self) -> str:
return f'TypeVariable({repr(self.letter)})'
class TypeConstructorVariable:
__slots__ = ('letter', )
letter: str
def __init__(self, letter: str) -> None:
assert len(letter) == 1, f'{letter} is not a valid type variable'
self.letter = letter
def __hash__(self) -> int:
return hash(self.letter)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, TypeVariable):
raise NotImplementedError
return self.letter == other.letter
def __repr__(self) -> str:
return f'TypeVariable({repr(self.letter)})'
class TypeReference:
__slots__ = ('name', )
@ -45,14 +66,16 @@ class TypeReference:
def __repr__(self) -> str:
return f'TypeReference({repr(self.name)})'
Signature = List[Union[Type3, TypeVariable, list[TypeVariable]]]
class Type3ClassMethod:
__slots__ = ('type3_class', 'name', 'signature', )
type3_class: 'Type3Class'
name: str
signature: List[Union[Type3, TypeVariable]]
signature: Signature
def __init__(self, type3_class: 'Type3Class', name: str, signature: Iterable[Union[Type3, TypeVariable]]) -> None:
def __init__(self, type3_class: 'Type3Class', name: str, signature: Signature) -> None:
self.type3_class = type3_class
self.name = name
self.signature = list(signature)
@ -64,7 +87,7 @@ class Type3Class:
__slots__ = ('name', 'args', 'methods', 'operators', 'inherited_classes', )
name: str
args: List[TypeVariable]
args: List[Union[TypeVariable, TypeConstructorVariable]]
methods: Dict[str, Type3ClassMethod]
operators: Dict[str, Type3ClassMethod]
inherited_classes: List['Type3Class']
@ -72,9 +95,9 @@ class Type3Class:
def __init__(
self,
name: str,
args: Iterable[TypeVariable],
methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]],
operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]],
args: Iterable[Union[TypeVariable, TypeConstructorVariable]],
methods: Mapping[str, Signature],
operators: Mapping[str, Signature],
inherited_classes: Optional[List['Type3Class']] = None,
) -> None:
self.name = name