Removes the special casing for foldl
Had to implement both functions as arguments and type place holders (variables) for type constructors. Had to implement functions as a type as well. Still have to figure out how to pass functions around.
This commit is contained in:
parent
ac4b46bbe7
commit
bfb3d2b3a0
2
TODO.md
2
TODO.md
@ -30,3 +30,5 @@
|
||||
- Functions don't seem to be a thing on typing level yet?
|
||||
- Related to the FIXME in phasm_type3?
|
||||
- Type constuctor should also be able to constuct placeholders - somehow.
|
||||
|
||||
- Read https://bytecodealliance.org/articles/multi-value-all-the-wasm
|
||||
|
||||
@ -85,6 +85,9 @@ def expression(inp: ourlang.Expression) -> str:
|
||||
|
||||
return f'{inp.function.name}({args})'
|
||||
|
||||
if isinstance(inp, ourlang.FunctionReference):
|
||||
return str(inp.function.name)
|
||||
|
||||
if isinstance(inp, ourlang.TupleInstantiation):
|
||||
args = ', '.join(
|
||||
expression(arg)
|
||||
@ -102,10 +105,6 @@ def expression(inp: ourlang.Expression) -> str:
|
||||
if isinstance(inp, ourlang.AccessStructMember):
|
||||
return f'{expression(inp.varref)}.{inp.member}'
|
||||
|
||||
if isinstance(inp, ourlang.Fold):
|
||||
fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr'
|
||||
return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})'
|
||||
|
||||
raise NotImplementedError(expression, inp)
|
||||
|
||||
def statement(inp: ourlang.Statement) -> Statements:
|
||||
|
||||
@ -4,11 +4,11 @@ This module contains the code to convert parsed Ourlang into WebAssembly code
|
||||
import struct
|
||||
from typing import List, Optional
|
||||
|
||||
from . import codestyle, ourlang, prelude, wasm
|
||||
from . import ourlang, prelude, wasm
|
||||
from .runtime import calculate_alloc_size, calculate_member_offset
|
||||
from .stdlib import alloc as stdlib_alloc
|
||||
from .stdlib import types as stdlib_types
|
||||
from .type3.functions import TypeVariable
|
||||
from .type3.functions import FunctionArgument, TypeVariable
|
||||
from .type3.routers import NoRouteForTypeException, TypeApplicationRouter
|
||||
from .type3.typeclasses import Type3ClassMethod
|
||||
from .type3.types import (
|
||||
@ -100,7 +100,7 @@ def type3(inp: Type3) -> wasm.WasmType:
|
||||
|
||||
raise NotImplementedError(type3, inp)
|
||||
|
||||
def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> None:
|
||||
def tuple_instantiation(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.TupleInstantiation) -> None:
|
||||
"""
|
||||
Compile: Instantiation (allocation) of a tuple
|
||||
"""
|
||||
@ -150,7 +150,7 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) ->
|
||||
|
||||
wgn.add_statement('nop', comment='PRE')
|
||||
wgn.local.get(tmp_var)
|
||||
expression(wgn, element)
|
||||
expression(wgn, mod, element)
|
||||
wgn.add_statement(f'{mtyp}.store', 'offset=' + str(offset))
|
||||
wgn.add_statement('nop', comment='POST')
|
||||
|
||||
@ -160,29 +160,29 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) ->
|
||||
wgn.local.get(tmp_var)
|
||||
|
||||
def expression_subscript_bytes(
|
||||
attrs: tuple[WasmGenerator, ourlang.Subscript],
|
||||
attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript],
|
||||
) -> None:
|
||||
wgn, inp = attrs
|
||||
wgn, mod, inp = attrs
|
||||
|
||||
expression(wgn, inp.varref)
|
||||
expression(wgn, inp.index)
|
||||
expression(wgn, mod, inp.varref)
|
||||
expression(wgn, mod, inp.index)
|
||||
wgn.call(stdlib_types.__subscript_bytes__)
|
||||
|
||||
def expression_subscript_static_array(
|
||||
attrs: tuple[WasmGenerator, ourlang.Subscript],
|
||||
attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript],
|
||||
args: tuple[Type3, IntType3],
|
||||
) -> None:
|
||||
wgn, inp = attrs
|
||||
wgn, mod, inp = attrs
|
||||
|
||||
el_type, el_len = args
|
||||
|
||||
# OPTIMIZE: If index is a constant, we can use offset instead of multiply
|
||||
# and we don't need to do the out of bounds check
|
||||
|
||||
expression(wgn, inp.varref)
|
||||
expression(wgn, mod, inp.varref)
|
||||
|
||||
tmp_var = wgn.temp_var_i32('index')
|
||||
expression(wgn, inp.index)
|
||||
expression(wgn, mod, inp.index)
|
||||
wgn.local.tee(tmp_var)
|
||||
|
||||
# Out of bounds check based on el_len.value
|
||||
@ -201,10 +201,10 @@ def expression_subscript_static_array(
|
||||
wgn.add_statement(f'{mtyp}.load')
|
||||
|
||||
def expression_subscript_tuple(
|
||||
attrs: tuple[WasmGenerator, ourlang.Subscript],
|
||||
attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript],
|
||||
args: tuple[Type3, ...],
|
||||
) -> None:
|
||||
wgn, inp = attrs
|
||||
wgn, mod, inp = attrs
|
||||
|
||||
assert isinstance(inp.index, ourlang.ConstantPrimitive)
|
||||
assert isinstance(inp.index.value, int)
|
||||
@ -217,7 +217,7 @@ def expression_subscript_tuple(
|
||||
el_type = args[inp.index.value]
|
||||
assert el_type is not None, TYPE3_ASSERTION_ERROR
|
||||
|
||||
expression(wgn, inp.varref)
|
||||
expression(wgn, mod, inp.varref)
|
||||
|
||||
if (prelude.InternalPassAsPointer, (el_type, )) in prelude.PRELUDE_TYPE_CLASS_INSTANCES_EXISTING:
|
||||
mtyp = 'i32'
|
||||
@ -226,12 +226,12 @@ def expression_subscript_tuple(
|
||||
|
||||
wgn.add_statement(f'{mtyp}.load', f'offset={offset}')
|
||||
|
||||
SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Subscript], None]()
|
||||
SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], None]()
|
||||
SUBSCRIPT_ROUTER.add_n(prelude.bytes_, expression_subscript_bytes)
|
||||
SUBSCRIPT_ROUTER.add(prelude.static_array, expression_subscript_static_array)
|
||||
SUBSCRIPT_ROUTER.add(prelude.tuple_, expression_subscript_tuple)
|
||||
|
||||
def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) -> None:
|
||||
"""
|
||||
Compile: Any expression
|
||||
"""
|
||||
@ -291,14 +291,14 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
wgn.i32.const(address)
|
||||
return
|
||||
|
||||
expression(wgn, inp.variable.constant)
|
||||
expression(wgn, mod, inp.variable.constant)
|
||||
return
|
||||
|
||||
raise NotImplementedError(expression, inp.variable)
|
||||
|
||||
if isinstance(inp, ourlang.BinaryOp):
|
||||
expression(wgn, inp.left)
|
||||
expression(wgn, inp.right)
|
||||
expression(wgn, mod, inp.left)
|
||||
expression(wgn, mod, inp.right)
|
||||
|
||||
type_var_map: dict[TypeVariable, Type3] = {}
|
||||
|
||||
@ -313,6 +313,10 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
type_var_map[type_var] = arg_expr.type3
|
||||
continue
|
||||
|
||||
if isinstance(type_var, FunctionArgument):
|
||||
# Fixed type, not part of the lookup requirements
|
||||
continue
|
||||
|
||||
raise NotImplementedError(type_var, arg_expr.type3)
|
||||
|
||||
router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.operator]
|
||||
@ -321,7 +325,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
|
||||
if isinstance(inp, ourlang.FunctionCall):
|
||||
for arg in inp.arguments:
|
||||
expression(wgn, arg)
|
||||
expression(wgn, mod, arg)
|
||||
|
||||
if isinstance(inp.function, Type3ClassMethod):
|
||||
# FIXME: Duplicate code with BinaryOp
|
||||
@ -338,6 +342,10 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
type_var_map[type_var] = arg_expr.type3
|
||||
continue
|
||||
|
||||
if isinstance(type_var, FunctionArgument):
|
||||
# Fixed type, not part of the lookup requirements
|
||||
continue
|
||||
|
||||
raise NotImplementedError(type_var, arg_expr.type3)
|
||||
|
||||
router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.function]
|
||||
@ -350,15 +358,24 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
wgn.add_statement('call', '${}'.format(inp.function.name))
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.FunctionReference):
|
||||
idx = mod.functions_table.get(inp.function)
|
||||
if idx is None:
|
||||
idx = len(mod.functions_table)
|
||||
mod.functions_table[inp.function] = idx
|
||||
|
||||
wgn.add_statement('i32.const', str(idx), comment=inp.function.name)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.TupleInstantiation):
|
||||
tuple_instantiation(wgn, inp)
|
||||
tuple_instantiation(wgn, mod, inp)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.Subscript):
|
||||
assert inp.varref.type3 is not None, TYPE3_ASSERTION_ERROR
|
||||
|
||||
# Type checker guarantees we don't get routing errors
|
||||
SUBSCRIPT_ROUTER((wgn, inp, ), inp.varref.type3)
|
||||
SUBSCRIPT_ROUTER((wgn, mod, inp, ), inp.varref.type3)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.AccessStructMember):
|
||||
@ -370,111 +387,29 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None:
|
||||
|
||||
mtyp = LOAD_STORE_TYPE_MAP[member_type.name]
|
||||
|
||||
expression(wgn, inp.varref)
|
||||
expression(wgn, mod, inp.varref)
|
||||
wgn.add_statement(f'{mtyp}.load', 'offset=' + str(calculate_member_offset(
|
||||
inp.struct_type3.name, inp.struct_type3.application.arguments, inp.member
|
||||
)))
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.Fold):
|
||||
expression_fold(wgn, inp)
|
||||
return
|
||||
|
||||
raise NotImplementedError(expression, inp)
|
||||
|
||||
def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None:
|
||||
"""
|
||||
Compile: Fold expression
|
||||
"""
|
||||
assert inp.type3 is not None, TYPE3_ASSERTION_ERROR
|
||||
|
||||
if inp.iter.type3 is not prelude.bytes_:
|
||||
raise NotImplementedError(expression_fold, inp, inp.iter.type3)
|
||||
|
||||
wgn.add_statement('nop', comment='acu :: u8')
|
||||
acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu')
|
||||
wgn.add_statement('nop', comment='adr :: bytes*')
|
||||
adr_var = wgn.temp_var_i32('fold_i32_adr')
|
||||
wgn.add_statement('nop', comment='len :: i32')
|
||||
len_var = wgn.temp_var_i32('fold_i32_len')
|
||||
|
||||
wgn.add_statement('nop', comment='acu = base')
|
||||
expression(wgn, inp.base)
|
||||
wgn.local.set(acu_var)
|
||||
|
||||
wgn.add_statement('nop', comment='adr = adr(iter)')
|
||||
expression(wgn, inp.iter)
|
||||
wgn.local.set(adr_var)
|
||||
|
||||
wgn.add_statement('nop', comment='len = len(iter)')
|
||||
wgn.local.get(adr_var)
|
||||
wgn.i32.load()
|
||||
wgn.local.set(len_var)
|
||||
|
||||
wgn.add_statement('nop', comment='i = 0')
|
||||
idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx')
|
||||
wgn.i32.const(0)
|
||||
wgn.local.set(idx_var)
|
||||
|
||||
wgn.add_statement('nop', comment='if i < len')
|
||||
wgn.local.get(idx_var)
|
||||
wgn.local.get(len_var)
|
||||
wgn.i32.lt_u()
|
||||
with wgn.if_():
|
||||
# From here on, adr_var is the address of byte we're referencing
|
||||
# This is akin to calling stdlib_types.__subscript_bytes__
|
||||
# But since we already know we are inside of bounds,
|
||||
# can just bypass it and load the memory directly.
|
||||
wgn.local.get(adr_var)
|
||||
wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop
|
||||
wgn.i32.add()
|
||||
wgn.local.set(adr_var)
|
||||
|
||||
wgn.add_statement('nop', comment='while True')
|
||||
with wgn.loop():
|
||||
wgn.add_statement('nop', comment='acu = func(acu, iter[i])')
|
||||
wgn.local.get(acu_var)
|
||||
|
||||
# Get the next byte, write back the address
|
||||
wgn.local.get(adr_var)
|
||||
wgn.i32.const(1)
|
||||
wgn.i32.add()
|
||||
wgn.local.tee(adr_var)
|
||||
wgn.i32.load8_u()
|
||||
|
||||
wgn.add_statement('call', f'${inp.func.name}')
|
||||
wgn.local.set(acu_var)
|
||||
|
||||
wgn.add_statement('nop', comment='i = i + 1')
|
||||
wgn.local.get(idx_var)
|
||||
wgn.i32.const(1)
|
||||
wgn.i32.add()
|
||||
wgn.local.set(idx_var)
|
||||
|
||||
wgn.add_statement('nop', comment='if i >= len: break')
|
||||
wgn.local.get(idx_var)
|
||||
wgn.local.get(len_var)
|
||||
wgn.i32.lt_u()
|
||||
wgn.br_if(0)
|
||||
|
||||
# return acu
|
||||
wgn.local.get(acu_var)
|
||||
|
||||
def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None:
|
||||
def statement_return(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementReturn) -> None:
|
||||
"""
|
||||
Compile: Return statement
|
||||
"""
|
||||
expression(wgn, inp.value)
|
||||
expression(wgn, mod, inp.value)
|
||||
wgn.return_()
|
||||
|
||||
def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None:
|
||||
def statement_if(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementIf) -> None:
|
||||
"""
|
||||
Compile: If statement
|
||||
"""
|
||||
expression(wgn, inp.test)
|
||||
expression(wgn, mod, inp.test)
|
||||
with wgn.if_():
|
||||
for stat in inp.statements:
|
||||
statement(wgn, stat)
|
||||
statement(wgn, mod, stat)
|
||||
|
||||
if inp.else_statements:
|
||||
raise NotImplementedError
|
||||
@ -482,16 +417,16 @@ def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None:
|
||||
# for stat in inp.else_statements:
|
||||
# statement(wgn, stat)
|
||||
|
||||
def statement(wgn: WasmGenerator, inp: ourlang.Statement) -> None:
|
||||
def statement(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Statement) -> None:
|
||||
"""
|
||||
Compile: any statement
|
||||
"""
|
||||
if isinstance(inp, ourlang.StatementReturn):
|
||||
statement_return(wgn, inp)
|
||||
statement_return(wgn, mod, inp)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.StatementIf):
|
||||
statement_if(wgn, inp)
|
||||
statement_if(wgn, mod, inp)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.StatementPass):
|
||||
@ -522,7 +457,7 @@ def import_(inp: ourlang.Function) -> wasm.Import:
|
||||
type3(inp.returns_type3)
|
||||
)
|
||||
|
||||
def function(inp: ourlang.Function) -> wasm.Function:
|
||||
def function(mod: ourlang.Module, inp: ourlang.Function) -> wasm.Function:
|
||||
"""
|
||||
Compile: function
|
||||
"""
|
||||
@ -534,7 +469,7 @@ def function(inp: ourlang.Function) -> wasm.Function:
|
||||
_generate_struct_constructor(wgn, inp)
|
||||
else:
|
||||
for stat in inp.statements:
|
||||
statement(wgn, stat)
|
||||
statement(wgn, mod, stat)
|
||||
|
||||
return wasm.Function(
|
||||
inp.name,
|
||||
@ -724,26 +659,32 @@ def module(inp: ourlang.Module) -> wasm.Module:
|
||||
stdlib_alloc.__find_free_block__,
|
||||
stdlib_alloc.__alloc__,
|
||||
stdlib_types.__alloc_bytes__,
|
||||
stdlib_types.__subscript_bytes__,
|
||||
stdlib_types.__u32_ord_min__,
|
||||
stdlib_types.__u64_ord_min__,
|
||||
stdlib_types.__i32_ord_min__,
|
||||
stdlib_types.__i64_ord_min__,
|
||||
stdlib_types.__u32_ord_max__,
|
||||
stdlib_types.__u64_ord_max__,
|
||||
stdlib_types.__i32_ord_max__,
|
||||
stdlib_types.__i64_ord_max__,
|
||||
stdlib_types.__i32_intnum_abs__,
|
||||
stdlib_types.__i64_intnum_abs__,
|
||||
stdlib_types.__u32_pow2__,
|
||||
stdlib_types.__u8_rotl__,
|
||||
stdlib_types.__u8_rotr__,
|
||||
# stdlib_types.__subscript_bytes__,
|
||||
# stdlib_types.__u32_ord_min__,
|
||||
# stdlib_types.__u64_ord_min__,
|
||||
# stdlib_types.__i32_ord_min__,
|
||||
# stdlib_types.__i64_ord_min__,
|
||||
# stdlib_types.__u32_ord_max__,
|
||||
# stdlib_types.__u64_ord_max__,
|
||||
# stdlib_types.__i32_ord_max__,
|
||||
# stdlib_types.__i64_ord_max__,
|
||||
# stdlib_types.__i32_intnum_abs__,
|
||||
# stdlib_types.__i64_intnum_abs__,
|
||||
# stdlib_types.__u32_pow2__,
|
||||
# stdlib_types.__u8_rotl__,
|
||||
# stdlib_types.__u8_rotr__,
|
||||
] + [
|
||||
function(x)
|
||||
function(inp, x)
|
||||
for x in inp.functions.values()
|
||||
if not x.imported
|
||||
]
|
||||
|
||||
# Do this after rendering the functions since that's what populates the tables
|
||||
result.table = {
|
||||
v: k.name
|
||||
for k, v in inp.functions_table.items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None:
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Contains the syntax tree for ourlang
|
||||
"""
|
||||
import enum
|
||||
from typing import Dict, Iterable, List, Optional, Union
|
||||
|
||||
from . import prelude
|
||||
@ -161,6 +160,18 @@ class FunctionCall(Expression):
|
||||
self.function = function
|
||||
self.arguments = []
|
||||
|
||||
class FunctionReference(Expression):
|
||||
"""
|
||||
An function reference expression within a statement
|
||||
"""
|
||||
__slots__ = ('function', )
|
||||
|
||||
function: 'Function'
|
||||
|
||||
def __init__(self, function: 'Function') -> None:
|
||||
super().__init__()
|
||||
self.function = function
|
||||
|
||||
class TupleInstantiation(Expression):
|
||||
"""
|
||||
Instantiation a tuple
|
||||
@ -207,36 +218,6 @@ class AccessStructMember(Expression):
|
||||
self.struct_type3 = struct_type3
|
||||
self.member = member
|
||||
|
||||
class Fold(Expression):
|
||||
"""
|
||||
A (left or right) fold
|
||||
"""
|
||||
class Direction(enum.Enum):
|
||||
"""
|
||||
Which direction to fold in
|
||||
"""
|
||||
LEFT = 0
|
||||
RIGHT = 1
|
||||
|
||||
dir: Direction
|
||||
func: 'Function'
|
||||
base: Expression
|
||||
iter: Expression
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dir_: Direction,
|
||||
func: 'Function',
|
||||
base: Expression,
|
||||
iter_: Expression,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dir = dir_
|
||||
self.func = func
|
||||
self.base = base
|
||||
self.iter = iter_
|
||||
|
||||
class Statement:
|
||||
"""
|
||||
A statement within a function
|
||||
@ -397,13 +378,14 @@ class Module:
|
||||
"""
|
||||
A module is a file and consists of functions
|
||||
"""
|
||||
__slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'operators', )
|
||||
__slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'functions_table', 'operators', )
|
||||
|
||||
data: ModuleData
|
||||
types: dict[str, Type3]
|
||||
struct_definitions: Dict[str, StructDefinition]
|
||||
constant_defs: Dict[str, ModuleConstantDef]
|
||||
functions: Dict[str, Function]
|
||||
functions_table: dict[Function, int]
|
||||
operators: Dict[str, Type3ClassMethod]
|
||||
|
||||
def __init__(self) -> None:
|
||||
@ -413,3 +395,4 @@ class Module:
|
||||
self.constant_defs = {}
|
||||
self.functions = {}
|
||||
self.operators = {}
|
||||
self.functions_table = {}
|
||||
|
||||
@ -14,10 +14,10 @@ from .ourlang import (
|
||||
ConstantStruct,
|
||||
ConstantTuple,
|
||||
Expression,
|
||||
Fold,
|
||||
Function,
|
||||
FunctionCall,
|
||||
FunctionParam,
|
||||
FunctionReference,
|
||||
Module,
|
||||
ModuleConstantDef,
|
||||
ModuleDataBlock,
|
||||
@ -446,6 +446,9 @@ class OurVisitor:
|
||||
cdef = module.constant_defs[node.id]
|
||||
return VariableReference(cdef)
|
||||
|
||||
if node.id in module.functions:
|
||||
return FunctionReference(module.functions[node.id])
|
||||
|
||||
_raise_static_error(node, f'Undefined variable {node.id}')
|
||||
|
||||
if isinstance(node, ast.Tuple):
|
||||
@ -462,7 +465,7 @@ class OurVisitor:
|
||||
|
||||
raise NotImplementedError(f'{node} as expr in FunctionDef')
|
||||
|
||||
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]:
|
||||
def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]:
|
||||
if node.keywords:
|
||||
_raise_static_error(node, 'Keyword calling not supported') # Yet?
|
||||
|
||||
@ -475,28 +478,6 @@ class OurVisitor:
|
||||
|
||||
if node.func.id in PRELUDE_METHODS:
|
||||
func = PRELUDE_METHODS[node.func.id]
|
||||
elif node.func.id == 'foldl':
|
||||
if 3 != len(node.args):
|
||||
_raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given')
|
||||
|
||||
# TODO: This is not generic, you cannot return a function
|
||||
subnode = node.args[0]
|
||||
if not isinstance(subnode, ast.Name):
|
||||
raise NotImplementedError(f'Calling methods that are not a name {subnode}')
|
||||
if not isinstance(subnode.ctx, ast.Load):
|
||||
_raise_static_error(subnode, 'Must be load context')
|
||||
if subnode.id not in module.functions:
|
||||
_raise_static_error(subnode, 'Reference to undefined function')
|
||||
func = module.functions[subnode.id]
|
||||
if 2 != len(func.posonlyargs):
|
||||
_raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given')
|
||||
|
||||
return Fold(
|
||||
Fold.Direction.LEFT,
|
||||
func,
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]),
|
||||
self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]),
|
||||
)
|
||||
else:
|
||||
if node.func.id not in module.functions:
|
||||
_raise_static_error(node, 'Call to undefined function')
|
||||
|
||||
@ -20,6 +20,7 @@ from ..type3.types import (
|
||||
Type3,
|
||||
TypeApplication_Nullary,
|
||||
TypeConstructor_Base,
|
||||
TypeConstructor_Function,
|
||||
TypeConstructor_StaticArray,
|
||||
TypeConstructor_Struct,
|
||||
TypeConstructor_Tuple,
|
||||
@ -186,6 +187,16 @@ It should be applied with zero or more arguments. It has a compile time
|
||||
determined length, and each argument can be different.
|
||||
"""
|
||||
|
||||
def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None:
|
||||
pass # ? instance_type_class(InternalPassAsPointer, typ)
|
||||
|
||||
function = TypeConstructor_Function('function', on_create=fn_on_create)
|
||||
"""
|
||||
This is a function.
|
||||
|
||||
It should be applied with one or more arguments. The last argument is the 'return' type.
|
||||
"""
|
||||
|
||||
def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None:
|
||||
instance_type_class(InternalPassAsPointer, typ)
|
||||
|
||||
@ -574,12 +585,16 @@ instance_type_class(Promotable, f32, f64, methods={
|
||||
|
||||
Foldable = Type3Class('Foldable', (t, ), methods={
|
||||
'sum': [t(a), a],
|
||||
'foldl': [[b, a, b], b, t(a), b],
|
||||
'foldr': [[a, b, b], b, t(a), b],
|
||||
}, operators={}, additional_context={
|
||||
'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))],
|
||||
})
|
||||
|
||||
instance_type_class(Foldable, static_array, methods={
|
||||
'sum': stdtypes.static_array_sum,
|
||||
'foldl': stdtypes.static_array_foldl,
|
||||
'foldr': stdtypes.static_array_foldr,
|
||||
})
|
||||
|
||||
PRELUDE_TYPE_CLASSES = {
|
||||
|
||||
@ -4,7 +4,7 @@ stdlib: Standard types that are not wasm primitives
|
||||
from phasm.stdlib import alloc
|
||||
from phasm.type3.routers import TypeVariableLookup
|
||||
from phasm.type3.types import IntType3, Type3
|
||||
from phasm.wasmgenerator import Generator, func_wrapper
|
||||
from phasm.wasmgenerator import Generator, VarType_Base, func_wrapper
|
||||
from phasm.wasmgenerator import VarType_i32 as i32
|
||||
from phasm.wasmgenerator import VarType_i64 as i64
|
||||
|
||||
@ -1081,9 +1081,17 @@ def f32_f64_demote(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.f32.demote_f64()
|
||||
|
||||
def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
assert len(tv_map) == 1
|
||||
sa_type, sa_len = next(iter(tv_map.values()))
|
||||
def static_array_sum(g: Generator, tvl: TypeVariableLookup) -> None:
|
||||
tv_map, tc_map = tvl
|
||||
|
||||
tvn_map = {
|
||||
x.name: y
|
||||
for x, y in tv_map.items()
|
||||
}
|
||||
|
||||
sa_type = tvn_map['a']
|
||||
sa_len = tvn_map['a*']
|
||||
|
||||
assert isinstance(sa_type, Type3)
|
||||
assert isinstance(sa_len, IntType3)
|
||||
|
||||
@ -1166,7 +1174,7 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
g.nop(comment='Add array value')
|
||||
g.local.get(sum_adr)
|
||||
g.add_statement(f'{sa_type_mtyp}.load')
|
||||
sa_type_add_gen(g, {})
|
||||
sa_type_add_gen(g, ({}, {}, ))
|
||||
|
||||
# adr = adr + sa_type_alloc_size
|
||||
# Stack: [sum] -> [sum]
|
||||
@ -1185,3 +1193,253 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
|
||||
g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]')
|
||||
# End result: [sum]
|
||||
|
||||
def static_array_foldl(g: Generator, tvl: TypeVariableLookup) -> None:
|
||||
tv_map, tc_map = tvl
|
||||
|
||||
tvn_map = {
|
||||
x.name: y
|
||||
for x, y in tv_map.items()
|
||||
}
|
||||
|
||||
sa_type = tvn_map['a']
|
||||
sa_len = tvn_map['a*']
|
||||
res_type = tvn_map['b']
|
||||
|
||||
assert isinstance(sa_type, Type3)
|
||||
assert isinstance(sa_len, IntType3)
|
||||
assert isinstance(res_type, Type3)
|
||||
|
||||
if sa_len.value < 1:
|
||||
raise NotImplementedError('Default value in case foldl is empty')
|
||||
|
||||
# FIXME: We should probably use LOAD_STORE_TYPE_MAP for this?
|
||||
mtyp_map = {
|
||||
'u32': 'i32',
|
||||
'u64': 'i64',
|
||||
'i32': 'i32',
|
||||
'i64': 'i64',
|
||||
'f32': 'f32',
|
||||
'f64': 'f64',
|
||||
}
|
||||
mtyp_f_map: dict[str, type[VarType_Base]] = {
|
||||
'i32': i32,
|
||||
'i64': i64,
|
||||
}
|
||||
|
||||
# FIXME: We should probably use calc_alloc_size for this?
|
||||
type_var_size_map = {
|
||||
'u32': 4,
|
||||
'u64': 8,
|
||||
'i32': 4,
|
||||
'i64': 8,
|
||||
'f32': 4,
|
||||
'f64': 8,
|
||||
}
|
||||
|
||||
# By default, constructed types are passed as pointers
|
||||
# FIXME: We don't know what add function to call
|
||||
sa_type_mtyp = mtyp_map.get(sa_type.name, 'i32')
|
||||
sa_type_alloc_size = type_var_size_map.get(sa_type.name, 4)
|
||||
res_type_mtyp = mtyp_map.get(res_type.name, 'i32')
|
||||
res_type_mtyp_f = mtyp_f_map[res_type_mtyp]
|
||||
|
||||
# Definitions
|
||||
fold_adr = g.temp_var(i32('fold_adr'))
|
||||
fold_stop = g.temp_var(i32('fold_stop'))
|
||||
fold_init = g.temp_var(res_type_mtyp_f('fold_init'))
|
||||
fold_func = g.temp_var(i32('fold_func'))
|
||||
|
||||
with g.block(params=['i32', res_type_mtyp, 'i32'], result=res_type_mtyp, comment=f'foldl a={sa_type.name} a*={sa_len.value} b={res_type.name}'):
|
||||
# Stack: [[a, b, b], b, t(a), b]
|
||||
# t(a) == sa_type[sa_len]
|
||||
# Stack: [fn*, b, sa*]
|
||||
|
||||
# adr = {address of what's currently on stack}
|
||||
# Stack: [fn*, b, sa*] -> [fn*, b]
|
||||
g.local.set(fold_adr)
|
||||
# Stack: [fn*, b] -> [fn*]
|
||||
g.local.set(fold_init)
|
||||
# Stack: [fn*] -> []
|
||||
g.local.set(fold_func)
|
||||
|
||||
# stop = adr + ar_len * sa_type_alloc_size
|
||||
# Stack: []
|
||||
g.nop(comment='Calculate address at which to stop looping')
|
||||
g.local.get(fold_adr)
|
||||
g.i32.const(sa_len.value * sa_type_alloc_size)
|
||||
g.i32.add()
|
||||
g.local.set(fold_stop)
|
||||
|
||||
# Stack: [] -> [b]
|
||||
g.nop(comment='Get the init value and first array value as starting point')
|
||||
g.local.get(fold_init)
|
||||
# Stack: [b] -> [b, *a]
|
||||
g.local.get(fold_adr)
|
||||
# Stack: [b] -> [b, a]
|
||||
g.add_statement(f'{sa_type_mtyp}.load')
|
||||
g.nop(comment='Call the fold function')
|
||||
g.local.get(fold_func)
|
||||
g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})')
|
||||
|
||||
# adr = adr + sa_type_alloc_size
|
||||
# Stack: [b] -> [b]
|
||||
g.nop(comment='Calculate address of the next value')
|
||||
g.local.get(fold_adr)
|
||||
g.i32.const(sa_type_alloc_size)
|
||||
g.i32.add()
|
||||
g.local.set(fold_adr)
|
||||
|
||||
if sa_len.value > 1:
|
||||
with g.loop(params=[sa_type_mtyp], result=sa_type_mtyp):
|
||||
# Stack: [b] -> [b, a]
|
||||
g.nop(comment='Add array value')
|
||||
g.local.get(fold_adr)
|
||||
g.add_statement(f'{sa_type_mtyp}.load')
|
||||
|
||||
# Stack [b, a] -> b
|
||||
g.nop(comment='Call the fold function')
|
||||
g.local.get(fold_func)
|
||||
g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})')
|
||||
|
||||
# adr = adr + sa_type_alloc_size
|
||||
# Stack: [fold] -> [fold]
|
||||
g.nop(comment='Calculate address of the next value')
|
||||
g.local.get(fold_adr)
|
||||
g.i32.const(sa_type_alloc_size)
|
||||
g.i32.add()
|
||||
g.local.tee(fold_adr)
|
||||
|
||||
# loop if adr < stop
|
||||
g.nop(comment='Check if address exceeds array bounds')
|
||||
g.local.get(fold_stop)
|
||||
g.i32.lt_u()
|
||||
g.br_if(0)
|
||||
# else: just one value, don't need to loop
|
||||
|
||||
# Stack: [b]
|
||||
|
||||
def static_array_foldr(g: Generator, tvl: TypeVariableLookup) -> None:
|
||||
tv_map, tc_map = tvl
|
||||
|
||||
tvn_map = {
|
||||
x.name: y
|
||||
for x, y in tv_map.items()
|
||||
}
|
||||
|
||||
sa_type = tvn_map['a']
|
||||
sa_len = tvn_map['a*']
|
||||
res_type = tvn_map['b']
|
||||
|
||||
assert isinstance(sa_type, Type3)
|
||||
assert isinstance(sa_len, IntType3)
|
||||
assert isinstance(res_type, Type3)
|
||||
|
||||
if sa_len.value < 1:
|
||||
raise NotImplementedError('Default value in case foldl is empty')
|
||||
|
||||
# FIXME: We should probably use LOAD_STORE_TYPE_MAP for this?
|
||||
mtyp_map = {
|
||||
'u32': 'i32',
|
||||
'u64': 'i64',
|
||||
'i32': 'i32',
|
||||
'i64': 'i64',
|
||||
'f32': 'f32',
|
||||
'f64': 'f64',
|
||||
}
|
||||
mtyp_f_map: dict[str, type[VarType_Base]] = {
|
||||
'i32': i32,
|
||||
'i64': i64,
|
||||
}
|
||||
|
||||
# FIXME: We should probably use calc_alloc_size for this?
|
||||
type_var_size_map = {
|
||||
'u32': 4,
|
||||
'u64': 8,
|
||||
'i32': 4,
|
||||
'i64': 8,
|
||||
'f32': 4,
|
||||
'f64': 8,
|
||||
}
|
||||
|
||||
# By default, constructed types are passed as pointers
|
||||
# FIXME: We don't know what add function to call
|
||||
sa_type_mtyp = mtyp_map.get(sa_type.name, 'i32')
|
||||
sa_type_alloc_size = type_var_size_map.get(sa_type.name, 4)
|
||||
res_type_mtyp = mtyp_map.get(res_type.name, 'i32')
|
||||
res_type_mtyp_f = mtyp_f_map[res_type_mtyp]
|
||||
|
||||
# Definitions
|
||||
fold_adr = g.temp_var(i32('fold_adr'))
|
||||
fold_stop = g.temp_var(i32('fold_stop'))
|
||||
fold_init = g.temp_var(res_type_mtyp_f('fold_init'))
|
||||
fold_func = g.temp_var(i32('fold_func'))
|
||||
|
||||
with g.block(params=['i32', res_type_mtyp, 'i32'], result=res_type_mtyp, comment=f'foldl a={sa_type.name} a*={sa_len.value} b={res_type.name}'):
|
||||
# Stack: [[a, b, b], b, t(a), b]
|
||||
# t(a) == sa_type[sa_len]
|
||||
# Stack: [fn*, b, sa*]
|
||||
|
||||
# adr = {address of what's currently on stack}
|
||||
# Stack: [fn*, b, sa*] -> [fn*, b]
|
||||
g.local.set(fold_adr)
|
||||
# Stack: [fn*, b] -> [fn*]
|
||||
g.local.set(fold_init)
|
||||
# Stack: [fn*] -> []
|
||||
g.local.set(fold_func)
|
||||
|
||||
# stop = adr + ar_len * sa_type_alloc_size
|
||||
# Stack: []
|
||||
g.nop(comment='Calculate address at which to stop looping')
|
||||
g.local.get(fold_adr)
|
||||
g.i32.const(sa_len.value * sa_type_alloc_size)
|
||||
g.i32.add()
|
||||
g.local.set(fold_stop)
|
||||
|
||||
# Stack: [] -> [b]
|
||||
g.nop(comment='Get the init value and first array value as starting point')
|
||||
g.local.get(fold_init)
|
||||
# Stack: [b] -> [b, *a]
|
||||
g.local.get(fold_adr)
|
||||
# Stack: [b] -> [b, a]
|
||||
g.add_statement(f'{sa_type_mtyp}.load')
|
||||
g.nop(comment='Call the fold function')
|
||||
g.local.get(fold_func)
|
||||
g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})')
|
||||
|
||||
# adr = adr + sa_type_alloc_size
|
||||
# Stack: [b] -> [b]
|
||||
g.nop(comment='Calculate address of the next value')
|
||||
g.local.get(fold_adr)
|
||||
g.i32.const(sa_type_alloc_size)
|
||||
g.i32.add()
|
||||
g.local.set(fold_adr)
|
||||
|
||||
if sa_len.value > 1:
|
||||
with g.loop(params=[sa_type_mtyp], result=sa_type_mtyp):
|
||||
# Stack: [b] -> [b, a]
|
||||
g.nop(comment='Add array value')
|
||||
g.local.get(fold_adr)
|
||||
g.add_statement(f'{sa_type_mtyp}.load')
|
||||
|
||||
# Stack [b, a] -> b
|
||||
g.nop(comment='Call the fold function')
|
||||
g.local.get(fold_func)
|
||||
g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})')
|
||||
|
||||
# adr = adr + sa_type_alloc_size
|
||||
# Stack: [fold] -> [fold]
|
||||
g.nop(comment='Calculate address of the next value')
|
||||
g.local.get(fold_adr)
|
||||
g.i32.const(sa_type_alloc_size)
|
||||
g.i32.add()
|
||||
g.local.tee(fold_adr)
|
||||
|
||||
# loop if adr < stop
|
||||
g.nop(comment='Check if address exceeds array bounds')
|
||||
g.local.get(fold_stop)
|
||||
g.i32.lt_u()
|
||||
g.br_if(0)
|
||||
# else: just one value, don't need to loop
|
||||
|
||||
# Stack: [b]
|
||||
|
||||
@ -6,6 +6,7 @@ These need to be resolved before the program can be compiled.
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
from .. import ourlang, prelude
|
||||
from .functions import FunctionArgument, TypeVariable
|
||||
from .placeholders import PlaceholderForType, Type3OrPlaceholder
|
||||
from .routers import NoRouteForTypeException, TypeApplicationRouter
|
||||
from .typeclasses import Type3Class
|
||||
@ -158,7 +159,7 @@ class SameTypeConstraint(ConstraintBase):
|
||||
return (
|
||||
' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))),
|
||||
{
|
||||
't' + str(idx): typ
|
||||
't' + str(idx): typ.name if isinstance(typ, ourlang.Function) else typ
|
||||
for idx, typ in enumerate(self.type_list)
|
||||
},
|
||||
)
|
||||
@ -181,7 +182,7 @@ class SameTypeArgumentConstraint(ConstraintBase):
|
||||
self.arg_var = arg_var
|
||||
|
||||
def check(self) -> CheckResult:
|
||||
if self.tc_var.resolve_as is None or self.arg_var.resolve_as is None:
|
||||
if self.tc_var.resolve_as is None:
|
||||
return RequireTypeSubstitutes()
|
||||
|
||||
tc_typ = self.tc_var.resolve_as
|
||||
@ -201,13 +202,84 @@ class SameTypeArgumentConstraint(ConstraintBase):
|
||||
# FIXME: This feels sketchy. Shouldn't the type variable
|
||||
# have the exact same number as arguments?
|
||||
if isinstance(tc_typ.application, TypeApplication_TypeInt):
|
||||
if tc_typ.application.arguments[0] == arg_typ:
|
||||
return None
|
||||
|
||||
return Error(f'{tc_typ.application.arguments[0]:s} must be {arg_typ:s} instead')
|
||||
return [SameTypeConstraint(
|
||||
tc_typ.application.arguments[0],
|
||||
self.arg_var,
|
||||
comment=self.comment,
|
||||
)]
|
||||
|
||||
raise NotImplementedError(tc_typ, arg_typ)
|
||||
|
||||
def human_readable(self) -> HumanReadableRet:
|
||||
return (
|
||||
'{tc_var}` == {arg_var}',
|
||||
{
|
||||
'tc_var': self.tc_var if self.tc_var.resolve_as is None else self.tc_var,
|
||||
'arg_var': self.arg_var if self.arg_var.resolve_as is None else self.arg_var,
|
||||
},
|
||||
)
|
||||
|
||||
class SameFunctionArgumentConstraint(ConstraintBase):
|
||||
__slots__ = ('type3', 'func_arg', 'type_var_map', )
|
||||
|
||||
type3: PlaceholderForType
|
||||
func_arg: FunctionArgument
|
||||
type_var_map: dict[TypeVariable, PlaceholderForType]
|
||||
|
||||
def __init__(self, type3: PlaceholderForType, func_arg: FunctionArgument, type_var_map: dict[TypeVariable, PlaceholderForType], *, comment: str) -> None:
|
||||
super().__init__(comment=comment)
|
||||
|
||||
self.type3 = type3
|
||||
self.func_arg = func_arg
|
||||
self.type_var_map = type_var_map
|
||||
|
||||
def check(self) -> CheckResult:
|
||||
if self.type3.resolve_as is None:
|
||||
return RequireTypeSubstitutes()
|
||||
|
||||
typ = self.type3.resolve_as
|
||||
|
||||
if isinstance(typ.application, TypeApplication_Nullary):
|
||||
return Error(f'{typ:s} must be a function instead')
|
||||
|
||||
if not isinstance(typ.application, TypeApplication_TypeStar):
|
||||
return Error(f'{typ:s} must be a function instead')
|
||||
|
||||
type_var_map = {
|
||||
x: y.resolve_as
|
||||
for x, y in self.type_var_map.items()
|
||||
if y.resolve_as is not None
|
||||
}
|
||||
|
||||
exp_type_arg_list = [
|
||||
tv if isinstance(tv, Type3) else type_var_map[tv]
|
||||
for tv in self.func_arg.args
|
||||
if isinstance(tv, Type3) or tv in type_var_map
|
||||
]
|
||||
|
||||
print('self.func_arg.args', self.func_arg.args)
|
||||
print('exp_type_arg_list', exp_type_arg_list)
|
||||
|
||||
if len(exp_type_arg_list) != len(self.func_arg.args):
|
||||
return RequireTypeSubstitutes()
|
||||
|
||||
return [
|
||||
SameTypeConstraint(
|
||||
typ,
|
||||
prelude.function(*exp_type_arg_list),
|
||||
comment=self.comment,
|
||||
)
|
||||
]
|
||||
|
||||
def human_readable(self) -> HumanReadableRet:
|
||||
return (
|
||||
'{type3} == {func_arg}',
|
||||
{
|
||||
'type3': self.type3,
|
||||
'func_arg': self.func_arg.name,
|
||||
},
|
||||
)
|
||||
|
||||
class TupleMatchConstraint(ConstraintBase):
|
||||
__slots__ = ('exp_type', 'args', )
|
||||
|
||||
@ -264,14 +336,10 @@ class MustImplementTypeClassConstraint(ConstraintBase):
|
||||
__slots__ = ('context', 'type_class3', 'types', )
|
||||
|
||||
context: Context
|
||||
type_class3: Union[str, Type3Class]
|
||||
type_class3: Type3Class
|
||||
types: list[Type3OrPlaceholder]
|
||||
|
||||
DATA = {
|
||||
'bytes': {'Foldable'},
|
||||
}
|
||||
|
||||
def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
|
||||
def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None:
|
||||
super().__init__(comment=comment)
|
||||
|
||||
self.context = context
|
||||
@ -299,13 +367,9 @@ class MustImplementTypeClassConstraint(ConstraintBase):
|
||||
|
||||
assert len(typ_list) == len(self.types)
|
||||
|
||||
if isinstance(self.type_class3, Type3Class):
|
||||
key = (self.type_class3, tuple(typ_list), )
|
||||
if key in self.context.type_class_instances_existing:
|
||||
return None
|
||||
else:
|
||||
if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()):
|
||||
return None
|
||||
key = (self.type_class3, tuple(typ_list), )
|
||||
if key in self.context.type_class_instances_existing:
|
||||
return None
|
||||
|
||||
typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name
|
||||
typ_name_list = ' '.join(x.name for x in typ_list)
|
||||
|
||||
@ -12,12 +12,14 @@ from .constraints import (
|
||||
Context,
|
||||
LiteralFitsConstraint,
|
||||
MustImplementTypeClassConstraint,
|
||||
SameFunctionArgumentConstraint,
|
||||
SameTypeArgumentConstraint,
|
||||
SameTypeConstraint,
|
||||
TupleMatchConstraint,
|
||||
)
|
||||
from .functions import (
|
||||
Constraint_TypeClassInstanceExists,
|
||||
FunctionArgument,
|
||||
FunctionSignature,
|
||||
TypeVariable,
|
||||
TypeVariableApplication_Unary,
|
||||
@ -111,6 +113,33 @@ def _expression_function_call(
|
||||
|
||||
raise NotImplementedError(constraint)
|
||||
|
||||
func_var_map = {
|
||||
x: PlaceholderForType([])
|
||||
for x in signature.args
|
||||
if isinstance(x, FunctionArgument)
|
||||
}
|
||||
|
||||
# If some of the function arguments are functions,
|
||||
# we need to deal with those separately.
|
||||
for sig_arg in signature.args:
|
||||
if not isinstance(sig_arg, FunctionArgument):
|
||||
continue
|
||||
|
||||
# Ensure that for all type variables in the function
|
||||
# there are also type variables available
|
||||
for func_arg in sig_arg.args:
|
||||
if isinstance(func_arg, Type3):
|
||||
continue
|
||||
|
||||
type_var_map.setdefault(func_arg, PlaceholderForType([]))
|
||||
|
||||
yield SameFunctionArgumentConstraint(
|
||||
func_var_map[sig_arg],
|
||||
sig_arg,
|
||||
type_var_map,
|
||||
comment=f'Ensure `{sig_arg.name}` matches in {signature}',
|
||||
)
|
||||
|
||||
# If some of the function arguments are type constructors,
|
||||
# we need to deal with those separately.
|
||||
# That is, given `foo :: t a -> a` we need to ensure
|
||||
@ -120,6 +149,9 @@ def _expression_function_call(
|
||||
# Not a type variable at all
|
||||
continue
|
||||
|
||||
if isinstance(sig_arg, FunctionArgument):
|
||||
continue
|
||||
|
||||
if sig_arg.application.constructor is None:
|
||||
# Not a type variable for a type constructor
|
||||
continue
|
||||
@ -150,9 +182,20 @@ def _expression_function_call(
|
||||
yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment)
|
||||
continue
|
||||
|
||||
if isinstance(sig_part, FunctionArgument):
|
||||
yield SameTypeConstraint(func_var_map[sig_part], arg_placeholders[arg_expr], comment=comment)
|
||||
continue
|
||||
|
||||
raise NotImplementedError(sig_part)
|
||||
return
|
||||
|
||||
def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator:
|
||||
yield SameTypeConstraint(
|
||||
prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3),
|
||||
phft,
|
||||
comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})',
|
||||
)
|
||||
|
||||
def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator:
|
||||
if isinstance(inp, ourlang.Constant):
|
||||
yield from constant(ctx, inp, phft)
|
||||
@ -171,6 +214,10 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
|
||||
yield from expression_function_call(ctx, inp, phft)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.FunctionReference):
|
||||
yield from expression_function_reference(ctx, inp, phft)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.TupleInstantiation):
|
||||
r_type = []
|
||||
for arg in inp.elements:
|
||||
@ -209,19 +256,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType)
|
||||
comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}')
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.Fold):
|
||||
base_phft = PlaceholderForType([inp.base])
|
||||
iter_phft = PlaceholderForType([inp.iter])
|
||||
|
||||
yield from expression(ctx, inp.base, base_phft)
|
||||
yield from expression(ctx, inp.iter, iter_phft)
|
||||
|
||||
yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft,
|
||||
comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b')
|
||||
yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft])
|
||||
|
||||
return
|
||||
|
||||
raise NotImplementedError(expression, inp)
|
||||
|
||||
def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator:
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Hashable, Iterable, List
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .typeclasses import Type3Class
|
||||
@ -155,15 +157,29 @@ class TypeVariableContext:
|
||||
def __repr__(self) -> str:
|
||||
return f'TypeVariableContext({self.constraints!r})'
|
||||
|
||||
class FunctionArgument:
|
||||
__slots__ = ('args', 'name', )
|
||||
|
||||
args: list[Type3 | TypeVariable]
|
||||
name: str
|
||||
|
||||
def __init__(self, args: list[Type3 | TypeVariable]) -> None:
|
||||
self.args = args
|
||||
|
||||
self.name = '(' + ' -> '.join(x.name for x in args) + ')'
|
||||
|
||||
class FunctionSignature:
|
||||
__slots__ = ('context', 'args', )
|
||||
|
||||
context: TypeVariableContext
|
||||
args: List[Union['Type3', TypeVariable]]
|
||||
args: List[Type3 | TypeVariable | FunctionArgument]
|
||||
|
||||
def __init__(self, context: TypeVariableContext, args: Iterable[Union['Type3', TypeVariable]]) -> None:
|
||||
def __init__(self, context: TypeVariableContext, args: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]) -> None:
|
||||
self.context = context.__copy__()
|
||||
self.args = list(args)
|
||||
self.args = list(
|
||||
FunctionArgument(x) if isinstance(x, list) else x
|
||||
for x in args
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return str(self.context) + ' -> '.join(x.name for x in self.args)
|
||||
|
||||
@ -3,6 +3,7 @@ from typing import Any, Callable
|
||||
from .functions import (
|
||||
TypeConstructorVariable,
|
||||
TypeVariable,
|
||||
TypeVariableApplication_Nullary,
|
||||
TypeVariableApplication_Unary,
|
||||
)
|
||||
from .typeclasses import Type3ClassArgs
|
||||
@ -54,7 +55,10 @@ class TypeApplicationRouter[S, R]:
|
||||
|
||||
raise NoRouteForTypeException(arg0, typ)
|
||||
|
||||
TypeVariableLookup = dict[TypeVariable, tuple[KindArgument, ...]]
|
||||
TypeVariableLookup = tuple[
|
||||
dict[TypeVariable, KindArgument],
|
||||
dict[TypeConstructorVariable, TypeConstructor_Base[Any]],
|
||||
]
|
||||
|
||||
class TypeClassArgsRouter[S, R]:
|
||||
"""
|
||||
@ -89,11 +93,12 @@ class TypeClassArgsRouter[S, R]:
|
||||
|
||||
def __call__(self, arg0: S, tv_map: dict[TypeVariable, Type3]) -> R:
|
||||
key: list[Type3 | TypeConstructor_Base[Any]] = []
|
||||
arguments: TypeVariableLookup = {}
|
||||
arguments: TypeVariableLookup = (dict(tv_map), {}, )
|
||||
|
||||
for tc_arg in self.args:
|
||||
if isinstance(tc_arg, TypeVariable):
|
||||
key.append(tv_map[tc_arg])
|
||||
arguments[0][tc_arg] = tv_map[tc_arg]
|
||||
continue
|
||||
|
||||
for tvar, typ in tv_map.items():
|
||||
@ -102,16 +107,24 @@ class TypeClassArgsRouter[S, R]:
|
||||
continue
|
||||
|
||||
key.append(typ.application.constructor)
|
||||
arguments[1][tc_arg] = typ.application.constructor
|
||||
|
||||
if isinstance(tvar.application, TypeVariableApplication_Unary):
|
||||
# FIXME: This feels sketchy. Shouldn't the type variable
|
||||
# have the exact same number as arguments?
|
||||
if isinstance(typ.application, TypeApplication_TypeInt):
|
||||
arguments[tvar.application.arguments] = typ.application.arguments
|
||||
sa_type, sa_len = typ.application.arguments
|
||||
sa_type_tv = tvar.application.arguments
|
||||
sa_len_tv = TypeVariable(sa_type_tv.name + '*', TypeVariableApplication_Nullary(None, None))
|
||||
|
||||
arguments[0][sa_type_tv] = sa_type
|
||||
arguments[0][sa_len_tv] = sa_len
|
||||
continue
|
||||
|
||||
raise NotImplementedError(tvar.application, typ.application)
|
||||
|
||||
continue
|
||||
|
||||
t_helper = self.data.get(tuple(key))
|
||||
if t_helper is not None:
|
||||
return t_helper(arg0, arguments)
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Dict, Iterable, List, Mapping, Optional, Union
|
||||
from typing import Dict, Iterable, List, Mapping, Optional
|
||||
|
||||
from .functions import (
|
||||
Constraint_TypeClassInstanceExists,
|
||||
@ -42,8 +42,8 @@ class Type3Class:
|
||||
self,
|
||||
name: str,
|
||||
args: Type3ClassArgs,
|
||||
methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]],
|
||||
operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]],
|
||||
methods: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
|
||||
operators: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]],
|
||||
inherited_classes: Optional[List['Type3Class']] = None,
|
||||
additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None,
|
||||
) -> None:
|
||||
@ -71,19 +71,23 @@ class Type3Class:
|
||||
return self.name
|
||||
|
||||
def _create_signature(
|
||||
method_arg_list: Iterable[Type3 | TypeVariable],
|
||||
method_arg_list: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]],
|
||||
type_class3: Type3Class,
|
||||
) -> FunctionSignature:
|
||||
context = TypeVariableContext()
|
||||
if not isinstance(type_class3.args[0], TypeConstructorVariable):
|
||||
context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args))
|
||||
|
||||
signature_args: list[Type3 | TypeVariable] = []
|
||||
signature_args: list[Type3 | TypeVariable | list[Type3 | TypeVariable]] = []
|
||||
for method_arg in method_arg_list:
|
||||
if isinstance(method_arg, Type3):
|
||||
signature_args.append(method_arg)
|
||||
continue
|
||||
|
||||
if isinstance(method_arg, list):
|
||||
signature_args.append(method_arg)
|
||||
continue
|
||||
|
||||
if isinstance(method_arg, TypeVariable):
|
||||
type_constructor = method_arg.application.constructor
|
||||
if type_constructor is None:
|
||||
|
||||
@ -239,6 +239,10 @@ class TypeConstructor_Tuple(TypeConstructor_TypeStar):
|
||||
def make_name(self, key: Tuple[Type3, ...]) -> str:
|
||||
return '(' + ', '.join(x.name for x in key) + ', )'
|
||||
|
||||
class TypeConstructor_Function(TypeConstructor_TypeStar):
|
||||
def make_name(self, key: Tuple[Type3, ...]) -> str:
|
||||
return '(' + ' -> '.join(x.name for x in key) + ')'
|
||||
|
||||
class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]):
|
||||
"""
|
||||
Constructs struct types
|
||||
|
||||
@ -187,14 +187,17 @@ class Module(WatSerializable):
|
||||
def __init__(self) -> None:
|
||||
self.imports: List[Import] = []
|
||||
self.functions: List[Function] = []
|
||||
self.table: dict[int, str] = {}
|
||||
self.memory = ModuleMemory()
|
||||
|
||||
def to_wat(self) -> str:
|
||||
"""
|
||||
Generates the text version
|
||||
"""
|
||||
return '(module\n {}\n {}\n {})\n'.format(
|
||||
return '(module\n {}\n {}\n {}\n {}\n {})\n'.format(
|
||||
'\n '.join(x.to_wat() for x in self.imports),
|
||||
f'(table {len(self.table)} funcref)',
|
||||
'\n '.join(f'(elem (i32.const {k}) ${v})' for k, v in self.table.items()),
|
||||
self.memory.to_wat(),
|
||||
'\n '.join(x.to_wat() for x in self.functions),
|
||||
)
|
||||
|
||||
@ -170,11 +170,12 @@ class Generator_Local:
|
||||
self.generator.add_statement('local.tee', variable.name_ref, comment=comment)
|
||||
|
||||
class GeneratorBlock:
|
||||
def __init__(self, generator: 'Generator', name: str, params: Iterable[str] = (), result: str | None = None) -> None:
|
||||
def __init__(self, generator: 'Generator', name: str, params: Iterable[str] = (), result: str | None = None, comment: str | None = None) -> None:
|
||||
self.generator = generator
|
||||
self.name = name
|
||||
self.params = params
|
||||
self.result = result
|
||||
self.comment = comment
|
||||
|
||||
def __enter__(self) -> None:
|
||||
stmt = self.name
|
||||
@ -186,7 +187,7 @@ class GeneratorBlock:
|
||||
if self.result:
|
||||
stmt = f'{stmt} (result {self.result})'
|
||||
|
||||
self.generator.add_statement(stmt)
|
||||
self.generator.add_statement(stmt, comment=self.comment)
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
if not exc_type:
|
||||
@ -208,7 +209,7 @@ class Generator:
|
||||
# 2.4.5 Control Instructions
|
||||
self.nop = functools.partial(self.add_statement, 'nop')
|
||||
self.unreachable = functools.partial(self.add_statement, 'unreachable')
|
||||
# block
|
||||
self.block = functools.partial(GeneratorBlock, self, 'block')
|
||||
self.loop = functools.partial(GeneratorBlock, self, 'loop')
|
||||
self.if_ = functools.partial(GeneratorBlock, self, 'if')
|
||||
# br
|
||||
|
||||
@ -1,61 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from ..helpers import Suite
|
||||
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_foldl_1():
|
||||
code_py = """
|
||||
def u8_or(l: u8, r: u8) -> u8:
|
||||
return l | r
|
||||
|
||||
@exported
|
||||
def testEntry(b: bytes) -> u8:
|
||||
return foldl(u8_or, 128, b)
|
||||
"""
|
||||
suite = Suite(code_py)
|
||||
|
||||
result = suite.run_code(b'')
|
||||
assert 128 == result.returned_value
|
||||
|
||||
result = suite.run_code(b'\x80')
|
||||
assert 128 == result.returned_value
|
||||
|
||||
result = suite.run_code(b'\x80\x40')
|
||||
assert 192 == result.returned_value
|
||||
|
||||
result = suite.run_code(b'\x80\x40\x20\x10')
|
||||
assert 240 == result.returned_value
|
||||
|
||||
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
|
||||
assert 255 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_foldl_2():
|
||||
code_py = """
|
||||
def xor(l: u8, r: u8) -> u8:
|
||||
return l ^ r
|
||||
|
||||
@exported
|
||||
def testEntry(a: bytes, b: bytes) -> u8:
|
||||
return foldl(xor, 0, a) ^ foldl(xor, 0, b)
|
||||
"""
|
||||
suite = Suite(code_py)
|
||||
|
||||
result = suite.run_code(b'\x55\x0F', b'\x33\x80')
|
||||
assert 233 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_foldl_3():
|
||||
code_py = """
|
||||
def xor(l: u32, r: u8) -> u32:
|
||||
return l ^ extend(r)
|
||||
|
||||
@exported
|
||||
def testEntry(a: bytes) -> u32:
|
||||
return foldl(xor, 0, a)
|
||||
"""
|
||||
suite = Suite(code_py)
|
||||
|
||||
result = suite.run_code(b'\x55\x0F\x33\x80')
|
||||
assert 233 == result.returned_value
|
||||
@ -36,6 +36,118 @@ def testEntry(x: Foo[4]) -> Foo:
|
||||
with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@pytest.mark.parametrize('length', [1, 5, 13])
|
||||
@pytest.mark.parametrize('direction', ['foldl', 'foldr'])
|
||||
def test_foldable_foldl_foldr_size(direction, length):
|
||||
code_py = f"""
|
||||
def u64_add(l: u64, r: u64) -> u64:
|
||||
return l + r
|
||||
|
||||
@exported
|
||||
def testEntry(b: u64[{length}]) -> u64:
|
||||
return {direction}(u64_add, 100, b)
|
||||
"""
|
||||
suite = Suite(code_py)
|
||||
|
||||
in_put = tuple(range(1, length + 1))
|
||||
|
||||
result = suite.run_code(in_put)
|
||||
assert (100 + sum(in_put)) == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@pytest.mark.parametrize('direction', ['foldr'])
|
||||
def test_foldable_foldl_foldr_compounded_type(direction):
|
||||
code_py = f"""
|
||||
def combine_foldl(b: u64, a: (u32, u32, )) -> u64:
|
||||
return extend(a[0] * a[1]) + b
|
||||
|
||||
def combine_foldr(a: (u32, u32, ), b: u64) -> u64:
|
||||
return extend(a[0] * a[1]) + b
|
||||
|
||||
@exported
|
||||
def testEntry(b: (u32, u32)[3]) -> u64:
|
||||
return {direction}(combine_{direction}, 10000, b)
|
||||
"""
|
||||
suite = Suite(code_py)
|
||||
|
||||
result = suite.run_code(((2, 5), (25, 4), (125, 8)))
|
||||
assert 11110 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@pytest.mark.parametrize('direction, exp_result', [
|
||||
('foldl', -55, ),
|
||||
('foldr', -5, ),
|
||||
])
|
||||
def test_foldable_foldl_foldr_result(direction, exp_result):
|
||||
# See https://stackoverflow.com/a/13280185
|
||||
code_py = f"""
|
||||
def i32_sub(l: i32, r: i32) -> i32:
|
||||
return l - r
|
||||
|
||||
@exported
|
||||
def testEntry(b: i32[10]) -> i32:
|
||||
return {direction}(i32_sub, 0, b)
|
||||
"""
|
||||
suite = Suite(code_py)
|
||||
|
||||
result = suite.run_code(tuple(range(1, 11)))
|
||||
assert exp_result == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_foldable_foldl_bytes():
|
||||
code_py = """
|
||||
def u8_or(l: u8, r: u8) -> u8:
|
||||
return l | r
|
||||
|
||||
@exported
|
||||
def testEntry(b: bytes) -> u8:
|
||||
return foldl(u8_or, 128, b)
|
||||
"""
|
||||
suite = Suite(code_py)
|
||||
|
||||
result = suite.run_code(b'')
|
||||
assert 128 == result.returned_value
|
||||
|
||||
result = suite.run_code(b'\x80')
|
||||
assert 128 == result.returned_value
|
||||
|
||||
result = suite.run_code(b'\x80\x40')
|
||||
assert 192 == result.returned_value
|
||||
|
||||
result = suite.run_code(b'\x80\x40\x20\x10')
|
||||
assert 240 == result.returned_value
|
||||
|
||||
result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01')
|
||||
assert 255 == result.returned_value
|
||||
|
||||
@pytest.mark.integration_test
|
||||
@pytest.mark.parametrize('in_typ', ['i8', 'i8[3]'])
|
||||
def test_foldable_argument_must_be_a_function(in_typ):
|
||||
code_py = f"""
|
||||
@exported
|
||||
def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32:
|
||||
return foldl(x, y, z)
|
||||
"""
|
||||
|
||||
r_in_typ = in_typ.replace('[', '\\[').replace(']', '\\]')
|
||||
|
||||
with pytest.raises(Type3Exception, match=f'{r_in_typ} must be a function instead'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_foldable_argument_must_be_right_function():
|
||||
code_py = """
|
||||
def foo(l: i32, r: i64) -> i64:
|
||||
return extend(l) + r
|
||||
|
||||
@exported
|
||||
def testEntry(i: i64, l: i64[3]) -> i64:
|
||||
return foldr(foo, i, l)
|
||||
"""
|
||||
|
||||
with pytest.raises(Type3Exception, match=r'\(i64 -> i64 -> i64\) must be \(i32 -> i64 -> i64\) instead'):
|
||||
Suite(code_py).run_code()
|
||||
|
||||
@pytest.mark.integration_test
|
||||
def test_foldable_invalid_return_type():
|
||||
@ -45,7 +157,7 @@ def testEntry(x: i32[5]) -> f64:
|
||||
return sum(x)
|
||||
"""
|
||||
|
||||
with pytest.raises(Type3Exception, match='i32 must be f64 instead'):
|
||||
with pytest.raises(Type3Exception, match='f64 must be i32 instead'):
|
||||
Suite(code_py).run_code((4, 5, 6, 7, 8, ))
|
||||
|
||||
@pytest.mark.integration_test
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user