Compare commits

...

1 Commits

Author SHA1 Message Date
Johan B.W. de Vries
c2eaaa7e4a Removes the special casing for foldl
Has to implement both functions as arguments and type
place holders (variables) for type constructors.
2025-04-27 15:34:10 +02:00
7 changed files with 102 additions and 151 deletions

View File

@ -116,10 +116,6 @@ def expression(inp: ourlang.Expression) -> str:
if isinstance(inp, ourlang.AccessStructMember):
return f'{expression(inp.varref)}.{inp.member}'
if isinstance(inp, ourlang.Fold):
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
raise NotImplementedError(expression, inp)
def statement(inp: ourlang.Statement) -> Statements:

View File

@ -4,7 +4,7 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
import struct
from typing import Dict, List, Optional
from . import codestyle, ourlang, prelude, wasm
from . import ourlang, prelude, wasm
from .runtime import calculate_alloc_size, calculate_member_offset
from .stdlib import alloc as stdlib_alloc
from .stdlib import types as stdlib_types
@ -588,89 +588,85 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
)))
return
if isinstance(inp, ourlang.Fold):
expression_fold(wgn, inp)
return
raise NotImplementedError(expression, inp)
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
"""
Compile: Fold expression
"""
assert isinstance(inp.type3, type3types.Type3), type3placeholders.TYPE3_ASSERTION_ERROR
# def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
# """
# Compile: Fold expression
# """
# assert isinstance(inp.type3, type3types.Type3), type3placeholders.TYPE3_ASSERTION_ERROR
if inp.iter.type3 is not prelude.bytes_:
raise NotImplementedError(expression_fold, inp, inp.iter.type3)
# if inp.iter.type3 is not prelude.bytes_:
# raise NotImplementedError(expression_fold, inp, inp.iter.type3)
wgn.add_statement('nop', comment='acu :: u8')
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
wgn.add_statement('nop', comment='adr :: bytes*')
adr_var = wgn.temp_var_i32('fold_i32_adr')
wgn.add_statement('nop', comment='len :: i32')
len_var = wgn.temp_var_i32('fold_i32_len')
# wgn.add_statement('nop', comment='acu :: u8')
# acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
# wgn.add_statement('nop', comment='adr :: bytes*')
# adr_var = wgn.temp_var_i32('fold_i32_adr')
# wgn.add_statement('nop', comment='len :: i32')
# len_var = wgn.temp_var_i32('fold_i32_len')
wgn.add_statement('nop', comment='acu = base')
expression(wgn, inp.base)
wgn.local.set(acu_var)
# wgn.add_statement('nop', comment='acu = base')
# expression(wgn, inp.base)
# wgn.local.set(acu_var)
wgn.add_statement('nop', comment='adr = adr(iter)')
expression(wgn, inp.iter)
wgn.local.set(adr_var)
# wgn.add_statement('nop', comment='adr = adr(iter)')
# expression(wgn, inp.iter)
# wgn.local.set(adr_var)
wgn.add_statement('nop', comment='len = len(iter)')
wgn.local.get(adr_var)
wgn.i32.load()
wgn.local.set(len_var)
# wgn.add_statement('nop', comment='len = len(iter)')
# wgn.local.get(adr_var)
# wgn.i32.load()
# wgn.local.set(len_var)
wgn.add_statement('nop', comment='i = 0')
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
wgn.i32.const(0)
wgn.local.set(idx_var)
# wgn.add_statement('nop', comment='i = 0')
# idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
# wgn.i32.const(0)
# wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i < len')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
with wgn.if_():
# From here on, adr_var is the address of byte we're referencing
# This is akin to calling stdlib_types.__subscript_bytes__
# But since we already know we are inside of bounds,
# can just bypass it and load the memory directly.
wgn.local.get(adr_var)
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
wgn.i32.add()
wgn.local.set(adr_var)
# wgn.add_statement('nop', comment='if i < len')
# wgn.local.get(idx_var)
# wgn.local.get(len_var)
# wgn.i32.lt_u()
# with wgn.if_():
# # From here on, adr_var is the address of byte we're referencing
# # This is akin to calling stdlib_types.__subscript_bytes__
# # But since we already know we are inside of bounds,
# # can just bypass it and load the memory directly.
# wgn.local.get(adr_var)
# wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
# wgn.i32.add()
# wgn.local.set(adr_var)
wgn.add_statement('nop', comment='while True')
with wgn.loop():
wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
wgn.local.get(acu_var)
# wgn.add_statement('nop', comment='while True')
# with wgn.loop():
# wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
# wgn.local.get(acu_var)
# Get the next byte, write back the address
wgn.local.get(adr_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.tee(adr_var)
wgn.i32.load8_u()
# # Get the next byte, write back the address
# wgn.local.get(adr_var)
# wgn.i32.const(1)
# wgn.i32.add()
# wgn.local.tee(adr_var)
# wgn.i32.load8_u()
wgn.add_statement('call', f'${inp.func.name}')
wgn.local.set(acu_var)
# wgn.add_statement('call', f'${inp.func.name}')
# wgn.local.set(acu_var)
wgn.add_statement('nop', comment='i = i + 1')
wgn.local.get(idx_var)
wgn.i32.const(1)
wgn.i32.add()
wgn.local.set(idx_var)
# wgn.add_statement('nop', comment='i = i + 1')
# wgn.local.get(idx_var)
# wgn.i32.const(1)
# wgn.i32.add()
# wgn.local.set(idx_var)
wgn.add_statement('nop', comment='if i >= len: break')
wgn.local.get(idx_var)
wgn.local.get(len_var)
wgn.i32.lt_u()
wgn.br_if(0)
# wgn.add_statement('nop', comment='if i >= len: break')
# wgn.local.get(idx_var)
# wgn.local.get(len_var)
# wgn.i32.lt_u()
# wgn.br_if(0)
# return acu
wgn.local.get(acu_var)
# # return acu
# wgn.local.get(acu_var)
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
"""

View File

@ -1,7 +1,6 @@
"""
Contains the syntax tree for ourlang
"""
import enum
from typing import Dict, Iterable, List, Optional, Union
from . import prelude
@ -223,36 +222,6 @@ class AccessStructMember(Expression):
self.struct_type3 = struct_type3
self.member = member
class Fold(Expression):
"""
A (left or right) fold
"""
class Direction(enum.Enum):
"""
Which direction to fold in
"""
LEFT = 0
RIGHT = 1
dir: Direction
func: 'Function'
base: Expression
iter: Expression
def __init__(
self,
dir_: Direction,
func: 'Function',
base: Expression,
iter_: Expression,
) -> None:
super().__init__()
self.dir = dir_
self.func = func
self.base = base
self.iter = iter_
class Statement:
"""
A statement within a function

View File

@ -14,7 +14,6 @@ from .ourlang import (
ConstantStruct,
ConstantTuple,
Expression,
Fold,
Function,
FunctionCall,
FunctionParam,
@ -476,7 +475,7 @@ class OurVisitor:
raise NotImplementedError(f'{node} as expr in FunctionDef')
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall, UnaryOp]:
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall, UnaryOp]:
if node.keywords:
_raise_static_error(node, 'Keyword calling not supported') # Yet?
@ -499,28 +498,28 @@ class OurVisitor:
)
unary_op.type3 = prelude.u32
return unary_op
elif node.func.id == 'foldl':
if 3 != len(node.args):
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# elif node.func.id == 'foldl':
# if 3 != len(node.args):
# _raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
# TODO: This is not generic, you cannot return a function
subnode = node.args[0]
if not isinstance(subnode, ast.Name):
raise NotImplementedError(f'Calling methods that are not a name {subnode}')
if not isinstance(subnode.ctx, ast.Load):
_raise_static_error(subnode, 'Must be load context')
if subnode.id not in module.functions:
_raise_static_error(subnode, 'Reference to undefined function')
func = module.functions[subnode.id]
if 2 != len(func.posonlyargs):
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
# # TODO: This is not generic, you cannot return a function
# subnode = node.args[0]
# if not isinstance(subnode, ast.Name):
# raise NotImplementedError(f'Calling methods that are not a name {subnode}')
# if not isinstance(subnode.ctx, ast.Load):
# _raise_static_error(subnode, 'Must be load context')
# if subnode.id not in module.functions:
# _raise_static_error(subnode, 'Reference to undefined function')
# func = module.functions[subnode.id]
# if 2 != len(func.posonlyargs):
# _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
return Fold(
Fold.Direction.LEFT,
func,
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
)
# return Fold(
# Fold.Direction.LEFT,
# func,
# self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
# self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
# )
else:
if node.func.id not in module.functions:
_raise_static_error(node, 'Call to undefined function')

View File

@ -2,8 +2,12 @@
The prelude are all the builtin types, type classes and methods
"""
from ..type3.functions import TypeVariable
from ..type3.typeclasses import Type3Class, instance_type_class
from ..type3.typeclasses import (
Type3Class,
TypeConstructorVariable,
TypeVariable,
instance_type_class,
)
from ..type3.types import (
Type3,
TypeConstructor_StaticArray,
@ -122,7 +126,8 @@ PRELUDE_TYPES: dict[str, Type3] = {
}
a = TypeVariable('a')
b = TypeVariable('b')
t = TypeConstructorVariable('t')
InternalPassAsPointer = Type3Class('InternalPassAsPointer', [a], methods={}, operators={})
"""
@ -248,6 +253,10 @@ Sized_ = Type3Class('Sized', [a], methods={
instance_type_class(Sized_, bytes_)
Foldable = Type3Class('Foldable', [t], methods={
'foldl': [[a, b, b], b, t(a), b],
}, operators={})
PRELUDE_TYPE_CLASSES = {
'Eq': Eq,
'Ord': Ord,

View File

@ -241,14 +241,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
"""
__slots__ = ('type_class3', 'type3', )
type_class3: Union[str, typeclasses.Type3Class]
type_class3: typeclasses.Type3Class
type3: placeholders.Type3OrPlaceholder
DATA = {
'bytes': {'Foldable'},
}
def __init__(self, type_class3: Union[str, typeclasses.Type3Class], type3: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
def __init__(self, type_class3: typeclasses.Type3Class, type3: placeholders.Type3OrPlaceholder, comment: Optional[str] = None) -> None:
super().__init__(comment=comment)
self.type_class3 = type_class3
@ -262,12 +258,8 @@ class MustImplementTypeClassConstraint(ConstraintBase):
if isinstance(typ, placeholders.PlaceholderForType):
return RequireTypeSubstitutes()
if isinstance(self.type_class3, typeclasses.Type3Class):
if self.type_class3 in typ.classes:
return None
else:
if self.type_class3 in self.__class__.DATA.get(typ.name, set()):
return None
if self.type_class3 in typ.classes:
return None
return Error(f'{typ.name} does not implement the {self.type_class3} type class')

View File

@ -132,16 +132,6 @@ def expression(ctx: Context, inp: ourlang.Expression) -> ConstraintGenerator:
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
return
if isinstance(inp, ourlang.Fold):
yield from expression(ctx, inp.base)
yield from expression(ctx, inp.iter)
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, inp.base.type3, inp.type3,
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
yield MustImplementTypeClassConstraint('Foldable', inp.iter.type3)
return
raise NotImplementedError(expression, inp)
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: