Removes the special casing for foldl
Has to implement both functions as arguments and type place holders (variables) for type constructors. Probably have to introduce a type for functions
This commit is contained in:
parent
ac4b46bbe7
commit
dd4b9373ac
@ -102,10 +102,6 @@ def expression(inp: ourlang.Expression) -> str:
|
|||||||
if isinstance(inp, ourlang.AccessStructMember):
|
if isinstance(inp, ourlang.AccessStructMember):
|
||||||
return f'{expression(inp.varref)}.{inp.member}'
|
return f'{expression(inp.varref)}.{inp.member}'
|
||||||
|
|
||||||
if isinstance(inp, ourlang.Fold):
|
|
||||||
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
|
|
||||||
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
|
|
||||||
|
|
||||||
raise NotImplementedError(expression, inp)
|
raise NotImplementedError(expression, inp)
|
||||||
|
|
||||||
def statement(inp: ourlang.Statement) -> Statements:
|
def statement(inp: ourlang.Statement) -> Statements:
|
||||||
|
|||||||
@ -4,7 +4,7 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
|
|||||||
import struct
|
import struct
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from . import codestyle, ourlang, prelude, wasm
|
from . import ourlang, prelude, wasm
|
||||||
from .runtime import calculate_alloc_size, calculate_member_offset
|
from .runtime import calculate_alloc_size, calculate_member_offset
|
||||||
from .stdlib import alloc as stdlib_alloc
|
from .stdlib import alloc as stdlib_alloc
|
||||||
from .stdlib import types as stdlib_types
|
from .stdlib import types as stdlib_types
|
||||||
@ -376,90 +376,8 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
|||||||
)))
|
)))
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(inp, ourlang.Fold):
|
|
||||||
expression_fold(wgn, inp)
|
|
||||||
return
|
|
||||||
|
|
||||||
raise NotImplementedError(expression, inp)
|
raise NotImplementedError(expression, inp)
|
||||||
|
|
||||||
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
|
|
||||||
"""
|
|
||||||
Compile: Fold expression
|
|
||||||
"""
|
|
||||||
assert inp.type3 is not None, TYPE3_ASSERTION_ERROR
|
|
||||||
|
|
||||||
if inp.iter.type3 is not prelude.bytes_:
|
|
||||||
raise NotImplementedError(expression_fold, inp, inp.iter.type3)
|
|
||||||
|
|
||||||
wgn.add_statement('nop', comment='acu :: u8')
|
|
||||||
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
|
|
||||||
wgn.add_statement('nop', comment='adr :: bytes*')
|
|
||||||
adr_var = wgn.temp_var_i32('fold_i32_adr')
|
|
||||||
wgn.add_statement('nop', comment='len :: i32')
|
|
||||||
len_var = wgn.temp_var_i32('fold_i32_len')
|
|
||||||
|
|
||||||
wgn.add_statement('nop', comment='acu = base')
|
|
||||||
expression(wgn, inp.base)
|
|
||||||
wgn.local.set(acu_var)
|
|
||||||
|
|
||||||
wgn.add_statement('nop', comment='adr = adr(iter)')
|
|
||||||
expression(wgn, inp.iter)
|
|
||||||
wgn.local.set(adr_var)
|
|
||||||
|
|
||||||
wgn.add_statement('nop', comment='len = len(iter)')
|
|
||||||
wgn.local.get(adr_var)
|
|
||||||
wgn.i32.load()
|
|
||||||
wgn.local.set(len_var)
|
|
||||||
|
|
||||||
wgn.add_statement('nop', comment='i = 0')
|
|
||||||
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
|
|
||||||
wgn.i32.const(0)
|
|
||||||
wgn.local.set(idx_var)
|
|
||||||
|
|
||||||
wgn.add_statement('nop', comment='if i < len')
|
|
||||||
wgn.local.get(idx_var)
|
|
||||||
wgn.local.get(len_var)
|
|
||||||
wgn.i32.lt_u()
|
|
||||||
with wgn.if_():
|
|
||||||
# From here on, adr_var is the address of byte we're referencing
|
|
||||||
# This is akin to calling stdlib_types.__subscript_bytes__
|
|
||||||
# But since we already know we are inside of bounds,
|
|
||||||
# can just bypass it and load the memory directly.
|
|
||||||
wgn.local.get(adr_var)
|
|
||||||
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
|
|
||||||
wgn.i32.add()
|
|
||||||
wgn.local.set(adr_var)
|
|
||||||
|
|
||||||
wgn.add_statement('nop', comment='while True')
|
|
||||||
with wgn.loop():
|
|
||||||
wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
|
|
||||||
wgn.local.get(acu_var)
|
|
||||||
|
|
||||||
# Get the next byte, write back the address
|
|
||||||
wgn.local.get(adr_var)
|
|
||||||
wgn.i32.const(1)
|
|
||||||
wgn.i32.add()
|
|
||||||
wgn.local.tee(adr_var)
|
|
||||||
wgn.i32.load8_u()
|
|
||||||
|
|
||||||
wgn.add_statement('call', f'${inp.func.name}')
|
|
||||||
wgn.local.set(acu_var)
|
|
||||||
|
|
||||||
wgn.add_statement('nop', comment='i = i + 1')
|
|
||||||
wgn.local.get(idx_var)
|
|
||||||
wgn.i32.const(1)
|
|
||||||
wgn.i32.add()
|
|
||||||
wgn.local.set(idx_var)
|
|
||||||
|
|
||||||
wgn.add_statement('nop', comment='if i >= len: break')
|
|
||||||
wgn.local.get(idx_var)
|
|
||||||
wgn.local.get(len_var)
|
|
||||||
wgn.i32.lt_u()
|
|
||||||
wgn.br_if(0)
|
|
||||||
|
|
||||||
# return acu
|
|
||||||
wgn.local.get(acu_var)
|
|
||||||
|
|
||||||
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
|
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
|
||||||
"""
|
"""
|
||||||
Compile: Return statement
|
Compile: Return statement
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Contains the syntax tree for ourlang
|
Contains the syntax tree for ourlang
|
||||||
"""
|
"""
|
||||||
import enum
|
|
||||||
from typing import Dict, Iterable, List, Optional, Union
|
from typing import Dict, Iterable, List, Optional, Union
|
||||||
|
|
||||||
from . import prelude
|
from . import prelude
|
||||||
@ -161,6 +160,18 @@ class FunctionCall(Expression):
|
|||||||
self.function = function
|
self.function = function
|
||||||
self.arguments = []
|
self.arguments = []
|
||||||
|
|
||||||
|
class FunctionReference(Expression):
|
||||||
|
"""
|
||||||
|
An function reference expression within a statement
|
||||||
|
"""
|
||||||
|
__slots__ = ('function', )
|
||||||
|
|
||||||
|
function: 'Function'
|
||||||
|
|
||||||
|
def __init__(self, function: 'Function') -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.function = function
|
||||||
|
|
||||||
class TupleInstantiation(Expression):
|
class TupleInstantiation(Expression):
|
||||||
"""
|
"""
|
||||||
Instantiation a tuple
|
Instantiation a tuple
|
||||||
@ -207,36 +218,6 @@ class AccessStructMember(Expression):
|
|||||||
self.struct_type3 = struct_type3
|
self.struct_type3 = struct_type3
|
||||||
self.member = member
|
self.member = member
|
||||||
|
|
||||||
class Fold(Expression):
|
|
||||||
"""
|
|
||||||
A (left or right) fold
|
|
||||||
"""
|
|
||||||
class Direction(enum.Enum):
|
|
||||||
"""
|
|
||||||
Which direction to fold in
|
|
||||||
"""
|
|
||||||
LEFT = 0
|
|
||||||
RIGHT = 1
|
|
||||||
|
|
||||||
dir: Direction
|
|
||||||
func: 'Function'
|
|
||||||
base: Expression
|
|
||||||
iter: Expression
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dir_: Direction,
|
|
||||||
func: 'Function',
|
|
||||||
base: Expression,
|
|
||||||
iter_: Expression,
|
|
||||||
) -> None:
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.dir = dir_
|
|
||||||
self.func = func
|
|
||||||
self.base = base
|
|
||||||
self.iter = iter_
|
|
||||||
|
|
||||||
class Statement:
|
class Statement:
|
||||||
"""
|
"""
|
||||||
A statement within a function
|
A statement within a function
|
||||||
|
|||||||
@ -14,10 +14,10 @@ from .ourlang import (
|
|||||||
ConstantStruct,
|
ConstantStruct,
|
||||||
ConstantTuple,
|
ConstantTuple,
|
||||||
Expression,
|
Expression,
|
||||||
Fold,
|
|
||||||
Function,
|
Function,
|
||||||
FunctionCall,
|
FunctionCall,
|
||||||
FunctionParam,
|
FunctionParam,
|
||||||
|
FunctionReference,
|
||||||
Module,
|
Module,
|
||||||
ModuleConstantDef,
|
ModuleConstantDef,
|
||||||
ModuleDataBlock,
|
ModuleDataBlock,
|
||||||
@ -446,6 +446,9 @@ class OurVisitor:
|
|||||||
cdef = module.constant_defs[node.id]
|
cdef = module.constant_defs[node.id]
|
||||||
return VariableReference(cdef)
|
return VariableReference(cdef)
|
||||||
|
|
||||||
|
if node.id in module.functions:
|
||||||
|
return FunctionReference(module.functions[node.id])
|
||||||
|
|
||||||
_raise_static_error(node, f'Undefined variable {node.id}')
|
_raise_static_error(node, f'Undefined variable {node.id}')
|
||||||
|
|
||||||
if isinstance(node, ast.Tuple):
|
if isinstance(node, ast.Tuple):
|
||||||
@ -462,7 +465,7 @@ class OurVisitor:
|
|||||||
|
|
||||||
raise NotImplementedError(f'{node} as expr in FunctionDef')
|
raise NotImplementedError(f'{node} as expr in FunctionDef')
|
||||||
|
|
||||||
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]:
|
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
|
||||||
if node.keywords:
|
if node.keywords:
|
||||||
_raise_static_error(node, 'Keyword calling not supported') # Yet?
|
_raise_static_error(node, 'Keyword calling not supported') # Yet?
|
||||||
|
|
||||||
@ -475,28 +478,6 @@ class OurVisitor:
|
|||||||
|
|
||||||
if node.func.id in PRELUDE_METHODS:
|
if node.func.id in PRELUDE_METHODS:
|
||||||
func = PRELUDE_METHODS[node.func.id]
|
func = PRELUDE_METHODS[node.func.id]
|
||||||
elif node.func.id == 'foldl':
|
|
||||||
if 3 != len(node.args):
|
|
||||||
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
|
|
||||||
|
|
||||||
# TODO: This is not generic, you cannot return a function
|
|
||||||
subnode = node.args[0]
|
|
||||||
if not isinstance(subnode, ast.Name):
|
|
||||||
raise NotImplementedError(f'Calling methods that are not a name {subnode}')
|
|
||||||
if not isinstance(subnode.ctx, ast.Load):
|
|
||||||
_raise_static_error(subnode, 'Must be load context')
|
|
||||||
if subnode.id not in module.functions:
|
|
||||||
_raise_static_error(subnode, 'Reference to undefined function')
|
|
||||||
func = module.functions[subnode.id]
|
|
||||||
if 2 != len(func.posonlyargs):
|
|
||||||
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
|
|
||||||
|
|
||||||
return Fold(
|
|
||||||
Fold.Direction.LEFT,
|
|
||||||
func,
|
|
||||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
|
|
||||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
if node.func.id not in module.functions:
|
if node.func.id not in module.functions:
|
||||||
_raise_static_error(node, 'Call to undefined function')
|
_raise_static_error(node, 'Call to undefined function')
|
||||||
|
|||||||
@ -574,12 +574,16 @@ instance_type_class(Promotable, f32, f64, methods={
|
|||||||
|
|
||||||
Foldable = Type3Class('Foldable', (t, ), methods={
|
Foldable = Type3Class('Foldable', (t, ), methods={
|
||||||
'sum': [t(a), a],
|
'sum': [t(a), a],
|
||||||
|
'foldl': [[a, b, b], b, t(a), b],
|
||||||
|
'foldr': [[a, b, b], b, t(a), b],
|
||||||
}, operators={}, additional_context={
|
}, operators={}, additional_context={
|
||||||
'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))],
|
'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))],
|
||||||
})
|
})
|
||||||
|
|
||||||
instance_type_class(Foldable, static_array, methods={
|
instance_type_class(Foldable, static_array, methods={
|
||||||
'sum': stdtypes.static_array_sum,
|
'sum': stdtypes.static_array_sum,
|
||||||
|
'foldl': stdtypes.static_array_foldl,
|
||||||
|
'foldr': stdtypes.static_array_foldr,
|
||||||
})
|
})
|
||||||
|
|
||||||
PRELUDE_TYPE_CLASSES = {
|
PRELUDE_TYPE_CLASSES = {
|
||||||
|
|||||||
@ -1185,3 +1185,9 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
|
|||||||
|
|
||||||
g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]')
|
g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]')
|
||||||
# End result: [sum]
|
# End result: [sum]
|
||||||
|
|
||||||
|
def static_array_foldl(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||||
|
raise NotImplementedError(tv_map)
|
||||||
|
|
||||||
|
def static_array_foldr(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||||
|
raise NotImplementedError(tv_map)
|
||||||
|
|||||||
@ -6,6 +6,7 @@ These need to be resolved before the program can be compiled.
|
|||||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
from .. import ourlang, prelude
|
from .. import ourlang, prelude
|
||||||
|
from .functions import FunctionArgument, TypeVariable
|
||||||
from .placeholders import PlaceholderForType, Type3OrPlaceholder
|
from .placeholders import PlaceholderForType, Type3OrPlaceholder
|
||||||
from .routers import NoRouteForTypeException, TypeApplicationRouter
|
from .routers import NoRouteForTypeException, TypeApplicationRouter
|
||||||
from .typeclasses import Type3Class
|
from .typeclasses import Type3Class
|
||||||
@ -110,19 +111,19 @@ class SameTypeConstraint(ConstraintBase):
|
|||||||
"""
|
"""
|
||||||
__slots__ = ('type_list', )
|
__slots__ = ('type_list', )
|
||||||
|
|
||||||
type_list: List[Type3OrPlaceholder]
|
type_list: List[Type3OrPlaceholder | ourlang.Function]
|
||||||
|
|
||||||
def __init__(self, *type_list: Type3OrPlaceholder, comment: Optional[str] = None) -> None:
|
def __init__(self, *type_list: Type3OrPlaceholder | ourlang.Function, comment: Optional[str] = None) -> None:
|
||||||
super().__init__(comment=comment)
|
super().__init__(comment=comment)
|
||||||
|
|
||||||
assert len(type_list) > 1
|
assert len(type_list) > 1
|
||||||
self.type_list = [*type_list]
|
self.type_list = [*type_list]
|
||||||
|
|
||||||
def check(self) -> CheckResult:
|
def check(self) -> CheckResult:
|
||||||
known_types: List[Type3] = []
|
known_types: List[Type3 | ourlang.Function] = []
|
||||||
phft_list = []
|
phft_list = []
|
||||||
for typ in self.type_list:
|
for typ in self.type_list:
|
||||||
if isinstance(typ, Type3):
|
if isinstance(typ, (Type3, ourlang.Function, )):
|
||||||
known_types.append(typ)
|
known_types.append(typ)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@ -158,7 +159,7 @@ class SameTypeConstraint(ConstraintBase):
|
|||||||
return (
|
return (
|
||||||
' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))),
|
' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))),
|
||||||
{
|
{
|
||||||
't' + str(idx): typ
|
't' + str(idx): typ.name if isinstance(typ, ourlang.Function) else typ
|
||||||
for idx, typ in enumerate(self.type_list)
|
for idx, typ in enumerate(self.type_list)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@ -208,6 +209,50 @@ class SameTypeArgumentConstraint(ConstraintBase):
|
|||||||
|
|
||||||
raise NotImplementedError(tc_typ, arg_typ)
|
raise NotImplementedError(tc_typ, arg_typ)
|
||||||
|
|
||||||
|
def human_readable(self) -> HumanReadableRet:
|
||||||
|
return (
|
||||||
|
'{tc_var}` == {arg_var}',
|
||||||
|
{
|
||||||
|
'tc_var': self.tc_var if self.tc_var.resolve_as is None else self.tc_var,
|
||||||
|
'arg_var': self.arg_var if self.arg_var.resolve_as is None else self.arg_var,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
class SameFunctionArgumentConstraint(ConstraintBase):
|
||||||
|
__slots__ = ('type3', 'func_arg', 'type_var_map', )
|
||||||
|
|
||||||
|
type3: PlaceholderForType
|
||||||
|
func_arg: FunctionArgument
|
||||||
|
type_var_map: dict[TypeVariable, PlaceholderForType]
|
||||||
|
|
||||||
|
def __init__(self, type3: PlaceholderForType, func_arg: FunctionArgument, type_var_map: dict[TypeVariable, PlaceholderForType], *, comment: str) -> None:
|
||||||
|
super().__init__(comment=comment)
|
||||||
|
|
||||||
|
self.type3 = type3
|
||||||
|
self.func_arg = func_arg
|
||||||
|
self.type_var_map = type_var_map
|
||||||
|
|
||||||
|
def check(self) -> CheckResult:
|
||||||
|
print([
|
||||||
|
(x, y.resolve_as, )
|
||||||
|
for x, y in self.type_var_map.items()
|
||||||
|
])
|
||||||
|
|
||||||
|
if self.type3.resolve_as is None:
|
||||||
|
return RequireTypeSubstitutes()
|
||||||
|
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def human_readable(self) -> HumanReadableRet:
|
||||||
|
return (
|
||||||
|
'{type3} / {func_arg} / {type_var_map}',
|
||||||
|
{
|
||||||
|
'type3': self.type3,
|
||||||
|
'func_arg': self.func_arg.name,
|
||||||
|
'type_var_map': str(self.type_var_map),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
class TupleMatchConstraint(ConstraintBase):
|
class TupleMatchConstraint(ConstraintBase):
|
||||||
__slots__ = ('exp_type', 'args', )
|
__slots__ = ('exp_type', 'args', )
|
||||||
|
|
||||||
@ -264,14 +309,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
|
|||||||
__slots__ = ('context', 'type_class3', 'types', )
|
__slots__ = ('context', 'type_class3', 'types', )
|
||||||
|
|
||||||
context: Context
|
context: Context
|
||||||
type_class3: Union[str, Type3Class]
|
type_class3: Type3Class
|
||||||
types: list[Type3OrPlaceholder]
|
types: list[Type3OrPlaceholder]
|
||||||
|
|
||||||
DATA = {
|
def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
|
||||||
'bytes': {'Foldable'},
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
|
|
||||||
super().__init__(comment=comment)
|
super().__init__(comment=comment)
|
||||||
|
|
||||||
self.context = context
|
self.context = context
|
||||||
@ -299,13 +340,9 @@ class MustImplementTypeClassConstraint(ConstraintBase):
|
|||||||
|
|
||||||
assert len(typ_list) == len(self.types)
|
assert len(typ_list) == len(self.types)
|
||||||
|
|
||||||
if isinstance(self.type_class3, Type3Class):
|
key = (self.type_class3, tuple(typ_list), )
|
||||||
key = (self.type_class3, tuple(typ_list), )
|
if key in self.context.type_class_instances_existing:
|
||||||
if key in self.context.type_class_instances_existing:
|
return None
|
||||||
return None
|
|
||||||
else:
|
|
||||||
if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()):
|
|
||||||
return None
|
|
||||||
|
|
||||||
typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name
|
typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name
|
||||||
typ_name_list = ' '.join(x.name for x in typ_list)
|
typ_name_list = ' '.join(x.name for x in typ_list)
|
||||||
|
|||||||
@ -12,12 +12,14 @@ from .constraints import (
|
|||||||
Context,
|
Context,
|
||||||
LiteralFitsConstraint,
|
LiteralFitsConstraint,
|
||||||
MustImplementTypeClassConstraint,
|
MustImplementTypeClassConstraint,
|
||||||
|
SameFunctionArgumentConstraint,
|
||||||
SameTypeArgumentConstraint,
|
SameTypeArgumentConstraint,
|
||||||
SameTypeConstraint,
|
SameTypeConstraint,
|
||||||
TupleMatchConstraint,
|
TupleMatchConstraint,
|
||||||
)
|
)
|
||||||
from .functions import (
|
from .functions import (
|
||||||
Constraint_TypeClassInstanceExists,
|
Constraint_TypeClassInstanceExists,
|
||||||
|
FunctionArgument,
|
||||||
FunctionSignature,
|
FunctionSignature,
|
||||||
TypeVariable,
|
TypeVariable,
|
||||||
TypeVariableApplication_Unary,
|
TypeVariableApplication_Unary,
|
||||||
@ -111,6 +113,33 @@ def _expression_function_call(
|
|||||||
|
|
||||||
raise NotImplementedError(constraint)
|
raise NotImplementedError(constraint)
|
||||||
|
|
||||||
|
func_var_map = {
|
||||||
|
x: PlaceholderForType([])
|
||||||
|
for x in signature.args
|
||||||
|
if isinstance(x, FunctionArgument)
|
||||||
|
}
|
||||||
|
|
||||||
|
# If some of the function arguments are functions,
|
||||||
|
# we need to deal with those separately.
|
||||||
|
for sig_arg in signature.args:
|
||||||
|
if not isinstance(sig_arg, FunctionArgument):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Ensure that for all type variables in the function
|
||||||
|
# there are also type variables available
|
||||||
|
for func_arg in sig_arg.args:
|
||||||
|
if isinstance(func_arg, Type3):
|
||||||
|
continue
|
||||||
|
|
||||||
|
type_var_map.setdefault(func_arg, PlaceholderForType([]))
|
||||||
|
|
||||||
|
yield SameFunctionArgumentConstraint(
|
||||||
|
func_var_map[sig_arg],
|
||||||
|
sig_arg,
|
||||||
|
type_var_map,
|
||||||
|
comment=f'Ensure `{sig_arg.name}` matches in {signature}',
|
||||||
|
)
|
||||||
|
|
||||||
# If some of the function arguments are type constructors,
|
# If some of the function arguments are type constructors,
|
||||||
# we need to deal with those separately.
|
# we need to deal with those separately.
|
||||||
# That is, given `foo :: t a -> a` we need to ensure
|
# That is, given `foo :: t a -> a` we need to ensure
|
||||||
@ -120,6 +149,9 @@ def _expression_function_call(
|
|||||||
# Not a type variable at all
|
# Not a type variable at all
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if isinstance(sig_arg, FunctionArgument):
|
||||||
|
continue
|
||||||
|
|
||||||
if sig_arg.application.constructor is None:
|
if sig_arg.application.constructor is None:
|
||||||
# Not a type variable for a type constructor
|
# Not a type variable for a type constructor
|
||||||
continue
|
continue
|
||||||
@ -150,9 +182,20 @@ def _expression_function_call(
|
|||||||
yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment)
|
yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if isinstance(sig_part, FunctionArgument):
|
||||||
|
yield SameTypeConstraint(func_var_map[sig_part], arg_placeholders[arg_expr], comment=comment)
|
||||||
|
continue
|
||||||
|
|
||||||
raise NotImplementedError(sig_part)
|
raise NotImplementedError(sig_part)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator:
|
||||||
|
yield SameTypeConstraint(
|
||||||
|
inp.function,
|
||||||
|
phft,
|
||||||
|
comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})',
|
||||||
|
)
|
||||||
|
|
||||||
def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator:
|
def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator:
|
||||||
if isinstance(inp, ourlang.Constant):
|
if isinstance(inp, ourlang.Constant):
|
||||||
yield from constant(ctx, inp, phft)
|
yield from constant(ctx, inp, phft)
|
||||||
@ -171,6 +214,10 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
|
|||||||
yield from expression_function_call(ctx, inp, phft)
|
yield from expression_function_call(ctx, inp, phft)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if isinstance(inp, ourlang.FunctionReference):
|
||||||
|
yield from expression_function_reference(ctx, inp, phft)
|
||||||
|
return
|
||||||
|
|
||||||
if isinstance(inp, ourlang.TupleInstantiation):
|
if isinstance(inp, ourlang.TupleInstantiation):
|
||||||
r_type = []
|
r_type = []
|
||||||
for arg in inp.elements:
|
for arg in inp.elements:
|
||||||
@ -209,19 +256,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
|
|||||||
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
|
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
|
||||||
return
|
return
|
||||||
|
|
||||||
if isinstance(inp, ourlang.Fold):
|
|
||||||
base_phft = PlaceholderForType([inp.base])
|
|
||||||
iter_phft = PlaceholderForType([inp.iter])
|
|
||||||
|
|
||||||
yield from expression(ctx, inp.base, base_phft)
|
|
||||||
yield from expression(ctx, inp.iter, iter_phft)
|
|
||||||
|
|
||||||
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft,
|
|
||||||
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
|
|
||||||
yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft])
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
raise NotImplementedError(expression, inp)
|
raise NotImplementedError(expression, inp)
|
||||||
|
|
||||||
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator:
|
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator:
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List, Union
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .typeclasses import Type3Class
|
from .typeclasses import Type3Class
|
||||||
@ -155,15 +157,29 @@ class TypeVariableContext:
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f'TypeVariableContext({self.constraints!r})'
|
return f'TypeVariableContext({self.constraints!r})'
|
||||||
|
|
||||||
|
class FunctionArgument:
|
||||||
|
__slots__ = ('args', 'name', )
|
||||||
|
|
||||||
|
args: list[Type3 | TypeVariable]
|
||||||
|
name: str
|
||||||
|
|
||||||
|
def __init__(self, args: list[Type3 | TypeVariable]) -> None:
|
||||||
|
self.args = args
|
||||||
|
|
||||||
|
self.name = '(' + ' -> '.join(x.name for x in args) + ')'
|
||||||
|
|
||||||
class FunctionSignature:
|
class FunctionSignature:
|
||||||
__slots__ = ('context', 'args', )
|
__slots__ = ('context', 'args', )
|
||||||
|
|
||||||
context: TypeVariableContext
|
context: TypeVariableContext
|
||||||
args: List[Union['Type3', TypeVariable]]
|
args: List[Type3 | TypeVariable | FunctionArgument]
|
||||||
|
|
||||||
def __init__(self, context: TypeVariableContext, args: Iterable[Union['Type3', TypeVariable]]) -> None:
|
def __init__(self, context: TypeVariableContext, args: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]) -> None:
|
||||||
self.context = context.__copy__()
|
self.context = context.__copy__()
|
||||||
self.args = list(args)
|
self.args = list(
|
||||||
|
FunctionArgument(x) if isinstance(x, list) else x
|
||||||
|
for x in args
|
||||||
|
)
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return str(self.context) + ' -> '.join(x.name for x in self.args)
|
return str(self.context) + ' -> '.join(x.name for x in self.args)
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, Iterable, List, Mapping, Optional, Union
|
from typing import Dict, Iterable, List, Mapping, Optional
|
||||||
|
|
||||||
from .functions import (
|
from .functions import (
|
||||||
Constraint_TypeClassInstanceExists,
|
Constraint_TypeClassInstanceExists,
|
||||||
@ -42,8 +42,8 @@ class Type3Class:
|
|||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
args: Type3ClassArgs,
|
args: Type3ClassArgs,
|
||||||
methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]],
|
methods: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
|
||||||
operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]],
|
operators: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
|
||||||
inherited_classes: Optional[List['Type3Class']] = None,
|
inherited_classes: Optional[List['Type3Class']] = None,
|
||||||
additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None,
|
additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -71,19 +71,23 @@ class Type3Class:
|
|||||||
return self.name
|
return self.name
|
||||||
|
|
||||||
def _create_signature(
|
def _create_signature(
|
||||||
method_arg_list: Iterable[Type3 | TypeVariable],
|
method_arg_list: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]],
|
||||||
type_class3: Type3Class,
|
type_class3: Type3Class,
|
||||||
) -> FunctionSignature:
|
) -> FunctionSignature:
|
||||||
context = TypeVariableContext()
|
context = TypeVariableContext()
|
||||||
if not isinstance(type_class3.args[0], TypeConstructorVariable):
|
if not isinstance(type_class3.args[0], TypeConstructorVariable):
|
||||||
context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args))
|
context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args))
|
||||||
|
|
||||||
signature_args: list[Type3 | TypeVariable] = []
|
signature_args: list[Type3 | TypeVariable | list[Type3 | TypeVariable]] = []
|
||||||
for method_arg in method_arg_list:
|
for method_arg in method_arg_list:
|
||||||
if isinstance(method_arg, Type3):
|
if isinstance(method_arg, Type3):
|
||||||
signature_args.append(method_arg)
|
signature_args.append(method_arg)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if isinstance(method_arg, list):
|
||||||
|
signature_args.append(method_arg)
|
||||||
|
continue
|
||||||
|
|
||||||
if isinstance(method_arg, TypeVariable):
|
if isinstance(method_arg, TypeVariable):
|
||||||
type_constructor = method_arg.application.constructor
|
type_constructor = method_arg.application.constructor
|
||||||
if type_constructor is None:
|
if type_constructor is None:
|
||||||
|
|||||||
@ -1,61 +0,0 @@
|
|||||||
import pytest
|
|
||||||
|
|
||||||
from ..helpers import Suite
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.integration_test
|
|
||||||
def test_foldl_1():
|
|
||||||
code_py = """
|
|
||||||
def u8_or(l: u8, r: u8) -> u8:
|
|
||||||
return l | r
|
|
||||||
|
|
||||||
@exported
|
|
||||||
def testEntry(b: bytes) -> u8:
|
|
||||||
return foldl(u8_or, 128, b)
|
|
||||||
"""
|
|
||||||
suite = Suite(code_py)
|
|
||||||
|
|
||||||
result = suite.run_code(b'')
|
|
||||||
assert 128 == result.returned_value
|
|
||||||
|
|
||||||
result = suite.run_code(b'\x80')
|
|
||||||
assert 128 == result.returned_value
|
|
||||||
|
|
||||||
result = suite.run_code(b'\x80\x40')
|
|
||||||
assert 192 == result.returned_value
|
|
||||||
|
|
||||||
result = suite.run_code(b'\x80\x40\x20\x10')
|
|
||||||
assert 240 == result.returned_value
|
|
||||||
|
|
||||||
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
|
|
||||||
assert 255 == result.returned_value
|
|
||||||
|
|
||||||
@pytest.mark.integration_test
|
|
||||||
def test_foldl_2():
|
|
||||||
code_py = """
|
|
||||||
def xor(l: u8, r: u8) -> u8:
|
|
||||||
return l ^ r
|
|
||||||
|
|
||||||
@exported
|
|
||||||
def testEntry(a: bytes, b: bytes) -> u8:
|
|
||||||
return foldl(xor, 0, a) ^ foldl(xor, 0, b)
|
|
||||||
"""
|
|
||||||
suite = Suite(code_py)
|
|
||||||
|
|
||||||
result = suite.run_code(b'\x55\x0F', b'\x33\x80')
|
|
||||||
assert 233 == result.returned_value
|
|
||||||
|
|
||||||
@pytest.mark.integration_test
|
|
||||||
def test_foldl_3():
|
|
||||||
code_py = """
|
|
||||||
def xor(l: u32, r: u8) -> u32:
|
|
||||||
return l ^ extend(r)
|
|
||||||
|
|
||||||
@exported
|
|
||||||
def testEntry(a: bytes) -> u32:
|
|
||||||
return foldl(xor, 0, a)
|
|
||||||
"""
|
|
||||||
suite = Suite(code_py)
|
|
||||||
|
|
||||||
result = suite.run_code(b'\x55\x0F\x33\x80')
|
|
||||||
assert 233 == result.returned_value
|
|
||||||
@ -36,6 +36,52 @@ def testEntry(x: Foo[4]) -> Foo:
|
|||||||
with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'):
|
with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'):
|
||||||
Suite(code_py).run_code()
|
Suite(code_py).run_code()
|
||||||
|
|
||||||
|
@pytest.mark.integration_test
|
||||||
|
def test_foldable_foldl_size():
|
||||||
|
code_py = """
|
||||||
|
def u8_or(l: u8, r: u8) -> u8:
|
||||||
|
return l | r
|
||||||
|
|
||||||
|
@exported
|
||||||
|
def testEntry(b: bytes) -> u8:
|
||||||
|
return foldl(u8_or, 128, b)
|
||||||
|
"""
|
||||||
|
suite = Suite(code_py)
|
||||||
|
|
||||||
|
result = suite.run_code(b'')
|
||||||
|
assert 128 == result.returned_value
|
||||||
|
|
||||||
|
result = suite.run_code(b'\x80')
|
||||||
|
assert 128 == result.returned_value
|
||||||
|
|
||||||
|
result = suite.run_code(b'\x80\x40')
|
||||||
|
assert 192 == result.returned_value
|
||||||
|
|
||||||
|
result = suite.run_code(b'\x80\x40\x20\x10')
|
||||||
|
assert 240 == result.returned_value
|
||||||
|
|
||||||
|
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
|
||||||
|
assert 255 == result.returned_value
|
||||||
|
|
||||||
|
@pytest.mark.integration_test
|
||||||
|
@pytest.mark.parametrize('direction, exp_result', [
|
||||||
|
('foldl', -55, ),
|
||||||
|
('foldr', -5, ),
|
||||||
|
])
|
||||||
|
def test_foldable_foldl_foldr(direction, exp_result):
|
||||||
|
# See https://stackoverflow.com/a/13280185
|
||||||
|
code_py = f"""
|
||||||
|
def i32_sub(l: i32, r: i32) -> i32:
|
||||||
|
return l - r
|
||||||
|
|
||||||
|
@exported
|
||||||
|
def testEntry(b: i32[10]) -> i32:
|
||||||
|
return {direction}(i32_sub, 0, b)
|
||||||
|
"""
|
||||||
|
suite = Suite(code_py)
|
||||||
|
|
||||||
|
result = suite.run_code(tuple(range(1, 10)))
|
||||||
|
assert exp_result == result.returned_value
|
||||||
|
|
||||||
@pytest.mark.integration_test
|
@pytest.mark.integration_test
|
||||||
def test_foldable_invalid_return_type():
|
def test_foldable_invalid_return_type():
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user