diff --git a/TODO.md b/TODO.md index f26a9b4..81cd2db 100644 --- a/TODO.md +++ b/TODO.md @@ -31,3 +31,5 @@ - Functions don't seem to be a thing on typing level yet? - Related to the FIXME in phasm_type3? - Type constuctor should also be able to constuct placeholders - somehow. + +- Read https://bytecodealliance.org/articles/multi-value-all-the-wasm diff --git a/phasm/codestyle.py b/phasm/codestyle.py index 3b7e6f4..940ecfa 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -105,10 +105,6 @@ def expression(inp: ourlang.Expression) -> str: if isinstance(inp, ourlang.AccessStructMember): return f'{expression(inp.varref)}.{inp.member}' - if isinstance(inp, ourlang.Fold): - fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' - return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' - raise NotImplementedError(expression, inp) def statement(inp: ourlang.Statement) -> Statements: diff --git a/phasm/compiler.py b/phasm/compiler.py index b758715..4eaf1b2 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -4,12 +4,12 @@ This module contains the code to convert parsed Ourlang into WebAssembly code import struct from typing import List -from . import codestyle, ourlang, prelude, wasm +from . import ourlang, prelude, wasm from .runtime import calculate_alloc_size, calculate_member_offset from .stdlib import alloc as stdlib_alloc from .stdlib import types as stdlib_types from .stdlib.types import TYPE_INFO_CONSTRUCTED, TYPE_INFO_MAP -from .type3.functions import TypeVariable +from .type3.functions import FunctionArgument, TypeVariable from .type3.routers import NoRouteForTypeException, TypeApplicationRouter from .type3.typeclasses import Type3ClassMethod from .type3.types import ( @@ -273,6 +273,10 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) type_var_map[type_var] = arg_expr.type3 continue + if isinstance(type_var, FunctionArgument): + # Fixed type, not part of the lookup requirements + continue + raise NotImplementedError(type_var, arg_expr.type3) router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.operator] @@ -298,6 +302,10 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) type_var_map[type_var] = arg_expr.type3 continue + if isinstance(type_var, FunctionArgument): + # Fixed type, not part of the lookup requirements + continue + raise NotImplementedError(type_var, arg_expr.type3) router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.function] @@ -359,90 +367,8 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) ))) return - if isinstance(inp, ourlang.Fold): - expression_fold(wgn, mod, inp) - return - raise NotImplementedError(expression, inp) -def expression_fold(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Fold) -> None: - """ - Compile: Fold expression - """ - assert inp.type3 is not None, TYPE3_ASSERTION_ERROR - - if inp.iter.type3 is not prelude.bytes_: - raise NotImplementedError(expression_fold, inp, inp.iter.type3) - - wgn.add_statement('nop', comment='acu :: u8') - acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu') - wgn.add_statement('nop', comment='adr :: bytes*') - adr_var = wgn.temp_var_i32('fold_i32_adr') - wgn.add_statement('nop', comment='len :: i32') - len_var = wgn.temp_var_i32('fold_i32_len') - - wgn.add_statement('nop', comment='acu = base') - expression(wgn, mod, inp.base) - wgn.local.set(acu_var) - - wgn.add_statement('nop', comment='adr = adr(iter)') - expression(wgn, mod, inp.iter) - wgn.local.set(adr_var) - - wgn.add_statement('nop', comment='len = len(iter)') - wgn.local.get(adr_var) - wgn.i32.load() - wgn.local.set(len_var) - - wgn.add_statement('nop', comment='i = 0') - idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx') - wgn.i32.const(0) - wgn.local.set(idx_var) - - wgn.add_statement('nop', comment='if i < len') - wgn.local.get(idx_var) - wgn.local.get(len_var) - wgn.i32.lt_u() - with wgn.if_(): - # From here on, adr_var is the address of byte we're referencing - # This is akin to calling stdlib_types.__subscript_bytes__ - # But since we already know we are inside of bounds, - # can just bypass it and load the memory directly. - wgn.local.get(adr_var) - wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop - wgn.i32.add() - wgn.local.set(adr_var) - - wgn.add_statement('nop', comment='while True') - with wgn.loop(): - wgn.add_statement('nop', comment='acu = func(acu, iter[i])') - wgn.local.get(acu_var) - - # Get the next byte, write back the address - wgn.local.get(adr_var) - wgn.i32.const(1) - wgn.i32.add() - wgn.local.tee(adr_var) - wgn.i32.load8_u() - - wgn.add_statement('call', f'${inp.func.name}') - wgn.local.set(acu_var) - - wgn.add_statement('nop', comment='i = i + 1') - wgn.local.get(idx_var) - wgn.i32.const(1) - wgn.i32.add() - wgn.local.set(idx_var) - - wgn.add_statement('nop', comment='if i >= len: break') - wgn.local.get(idx_var) - wgn.local.get(len_var) - wgn.i32.lt_u() - wgn.br_if(0) - - # return acu - wgn.local.get(acu_var) - def statement_return(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementReturn) -> None: """ Compile: Return statement diff --git a/phasm/ourlang.py b/phasm/ourlang.py index 6d828d9..2bd3cac 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -1,7 +1,6 @@ """ Contains the syntax tree for ourlang """ -import enum from typing import Dict, Iterable, List, Optional, Union from . import prelude @@ -219,36 +218,6 @@ class AccessStructMember(Expression): self.struct_type3 = struct_type3 self.member = member -class Fold(Expression): - """ - A (left or right) fold - """ - class Direction(enum.Enum): - """ - Which direction to fold in - """ - LEFT = 0 - RIGHT = 1 - - dir: Direction - func: 'Function' - base: Expression - iter: Expression - - def __init__( - self, - dir_: Direction, - func: 'Function', - base: Expression, - iter_: Expression, - ) -> None: - super().__init__() - - self.dir = dir_ - self.func = func - self.base = base - self.iter = iter_ - class Statement: """ A statement within a function diff --git a/phasm/parser.py b/phasm/parser.py index 4ff906c..e72946d 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -14,7 +14,6 @@ from .ourlang import ( ConstantStruct, ConstantTuple, Expression, - Fold, Function, FunctionCall, FunctionParam, @@ -467,7 +466,7 @@ class OurVisitor: raise NotImplementedError(f'{node} as expr in FunctionDef') - def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]: + def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]: if node.keywords: _raise_static_error(node, 'Keyword calling not supported') # Yet? @@ -480,28 +479,6 @@ class OurVisitor: if node.func.id in PRELUDE_METHODS: func = PRELUDE_METHODS[node.func.id] - elif node.func.id == 'foldl': - if 3 != len(node.args): - _raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given') - - # TODO: This is not generic, you cannot return a function - subnode = node.args[0] - if not isinstance(subnode, ast.Name): - raise NotImplementedError(f'Calling methods that are not a name {subnode}') - if not isinstance(subnode.ctx, ast.Load): - _raise_static_error(subnode, 'Must be load context') - if subnode.id not in module.functions: - _raise_static_error(subnode, 'Reference to undefined function') - func = module.functions[subnode.id] - if 2 != len(func.posonlyargs): - _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') - - return Fold( - Fold.Direction.LEFT, - func, - self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]), - self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]), - ) elif node.func.id in our_locals: func = our_locals[node.func.id] else: diff --git a/phasm/prelude/__init__.py b/phasm/prelude/__init__.py index a440e54..5c5afd8 100644 --- a/phasm/prelude/__init__.py +++ b/phasm/prelude/__init__.py @@ -553,12 +553,21 @@ instance_type_class(Promotable, f32, f64, methods={ Foldable = Type3Class('Foldable', (t, ), methods={ 'sum': [t(a), a], + 'foldl': [[b, a, b], b, t(a), b], + 'foldr': [[a, b, b], b, t(a), b], }, operators={}, additional_context={ 'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))], }) +instance_type_class(Foldable, dynamic_array, methods={ + 'sum': stdtypes.dynamic_array_sum, + 'foldl': stdtypes.dynamic_array_foldl, + 'foldr': stdtypes.dynamic_array_foldr, +}) instance_type_class(Foldable, static_array, methods={ 'sum': stdtypes.static_array_sum, + 'foldl': stdtypes.static_array_foldl, + 'foldr': stdtypes.static_array_foldr, }) bytes_ = dynamic_array(u8) diff --git a/phasm/stdlib/types.py b/phasm/stdlib/types.py index 0f41d9c..680228d 100644 --- a/phasm/stdlib/types.py +++ b/phasm/stdlib/types.py @@ -1059,12 +1059,16 @@ def dynamic_array_sized_len(g: Generator, tv_map: TypeVariableLookup) -> None: # The length is stored in the first 4 bytes g.i32.load() -def static_array_sized_len(g: Generator, tv_map: TypeVariableLookup) -> None: - assert len(tv_map) == 1 - sa_type, sa_len = next(iter(tv_map.values())) - assert isinstance(sa_type, Type3) - assert isinstance(sa_len, IntType3) +def static_array_sized_len(g: Generator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_len = tvn_map['a*'] + assert isinstance(sa_len, IntType3) g.i32.const(sa_len.value) ## ### @@ -1137,9 +1141,23 @@ def f32_f64_demote(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map g.f32.demote_f64() -def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: - assert len(tv_map) == 1 - sa_type, sa_len = next(iter(tv_map.values())) +## ### +## Foldable + +def dynamic_array_sum(g: Generator, tvl: TypeVariableLookup) -> None: + raise NotImplementedError('dynamic_array_sum') + +def static_array_sum(g: Generator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_type = tvn_map['a'] + sa_len = tvn_map['a*'] + assert isinstance(sa_type, Type3) assert isinstance(sa_len, IntType3) @@ -1171,7 +1189,7 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: g.nop(comment=f'Start sum for {sa_type.name}[{sa_len.value}]') g.local.set(sum_adr) - # stop = adr + ar_len * sa_type_alloc_size + # stop = adr + ar_len * sa_type_info.alloc_size # Stack: [] g.nop(comment='Calculate address at which to stop looping') g.local.get(sum_adr) @@ -1186,7 +1204,7 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: g.add_statement(sa_type_info.wasm_load_func) # Since we did the first one, increase adr - # adr = adr + sa_type_alloc_size + # adr = adr + sa_type_info.alloc_size # Stack: [sum] -> [sum] g.local.get(sum_adr) g.i32.const(sa_type_info.alloc_size) @@ -1194,15 +1212,15 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: g.local.set(sum_adr) if sa_len.value > 1: - with g.loop(params=[sa_type_info.wasm_type().to_wat()], result=sa_type_info.wasm_type().to_wat()): + with g.loop(params=[sa_type_info.wasm_type], result=sa_type_info.wasm_type): # sum = sum + *adr # Stack: [sum] -> [sum + *adr] g.nop(comment='Add array value') g.local.get(sum_adr) g.add_statement(sa_type_info.wasm_load_func) - sa_type_add_gen(g, {}) + sa_type_add_gen(g, ({}, {}, )) - # adr = adr + sa_type_alloc_size + # adr = adr + sa_type_info.alloc_size # Stack: [sum] -> [sum] g.nop(comment='Calculate address of the next value') g.local.get(sum_adr) @@ -1219,3 +1237,294 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]') # End result: [sum] + +def dynamic_array_foldl(g: Generator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_type = tvn_map['a'] + res_type = tvn_map['b'] + + assert isinstance(sa_type, Type3) + assert isinstance(res_type, Type3) + + sa_type_info = TYPE_INFO_MAP.get(sa_type.name, TYPE_INFO_CONSTRUCTED) + res_type_info = TYPE_INFO_MAP.get(res_type.name, TYPE_INFO_CONSTRUCTED) + + # Definitions + fold_adr = g.temp_var(i32('fold_adr')) + fold_stop = g.temp_var(i32('fold_stop')) + fold_init = g.temp_var_t(res_type_info.wasm_type, 'fold_init') + fold_func = g.temp_var(i32('fold_func')) + fold_len = g.temp_var(i32('fold_len')) + + with g.block(params=['i32', res_type_info.wasm_type, 'i32'], result=res_type_info.wasm_type, comment=f'foldl a={sa_type.name} b={res_type.name}'): + # Stack: [fn*, b, sa*] -> [fn*, b] + g.local.tee(fold_adr) # Store address, but also keep it for loading the length + g.i32.load() # Load the length + g.local.set(fold_len) # Store the length + + # Stack: [fn*, b] -> [fn*] + g.local.set(fold_init) + # Stack: [fn*] -> [] + g.local.set(fold_func) + + # Stack: [] -> [b] + g.nop(comment='No applications if array is empty') + g.local.get(fold_init) + g.local.get(fold_len) + g.i32.eqz() # If the array is empty + g.br_if(0) # Then the base value is the result + + # Stack: [b] -> [b] ; fold_adr=fold_adr + 4 + g.nop(comment='Skip the header') + g.local.get(fold_adr) + g.i32.const(4) + g.i32.add() + g.local.set(fold_adr) + + # Stack: [b] -> [b] + g.nop(comment='Apply the first function call') + g.local.get(fold_adr) + g.add_statement(sa_type_info.wasm_load_func) + g.local.get(fold_func) + g.call_indirect([res_type_info.wasm_type, sa_type_info.wasm_type], res_type_info.wasm_type) + + # Stack: [b] -> [b] + g.nop(comment='No loop if there is only one item') + g.local.get(fold_len) + g.i32.const(1) + g.i32.eq() + g.br_if(0) # just one value, don't need to loop + + # Stack: [b] -> [b] ; fold_stop=fold_adr + (sa_len.value * sa_type_info.alloc_size) + g.nop(comment='Calculate address at which to stop looping') + g.local.get(fold_adr) + g.local.get(fold_len) + g.i32.const(sa_type_info.alloc_size) + g.i32.mul() + g.i32.add() + g.local.set(fold_stop) + + # Stack: [b] -> [b] ; fold_adr = fold_adr + sa_type_info.alloc_size + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_info.alloc_size) + g.i32.add() + g.local.set(fold_adr) + + with g.loop(params=[res_type_info.wasm_type], result=res_type_info.wasm_type): + # Stack: [b] -> [b] + g.nop(comment='Apply function call') + g.local.get(fold_adr) + g.add_statement(sa_type_info.wasm_load_func) + g.local.get(fold_func) + g.call_indirect([res_type_info.wasm_type, sa_type_info.wasm_type], res_type_info.wasm_type) + + # Stack: [b] -> [b] ; fold_adr = fold_adr + sa_type_info.alloc_size + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_info.alloc_size) + g.i32.add() + g.local.tee(fold_adr) + + # loop if adr > stop + # Stack: [b] -> [b] + g.nop(comment='Check if address exceeds array bounds') + g.local.get(fold_stop) + g.i32.lt_u() + g.br_if(0) + + # Stack: [b] + +def static_array_foldl(g: Generator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_type = tvn_map['a'] + sa_len = tvn_map['a*'] + res_type = tvn_map['b'] + + assert isinstance(sa_type, Type3) + assert isinstance(sa_len, IntType3) + assert isinstance(res_type, Type3) + + sa_type_info = TYPE_INFO_MAP.get(sa_type.name, TYPE_INFO_CONSTRUCTED) + res_type_info = TYPE_INFO_MAP.get(res_type.name, TYPE_INFO_CONSTRUCTED) + + # Definitions + fold_adr = g.temp_var(i32('fold_adr')) + fold_stop = g.temp_var(i32('fold_stop')) + fold_init = g.temp_var_t(res_type_info.wasm_type, 'fold_init') + fold_func = g.temp_var(i32('fold_func')) + + with g.block(params=['i32', res_type_info.wasm_type, 'i32'], result=res_type_info.wasm_type, comment=f'foldl a={sa_type.name} a*={sa_len.value} b={res_type.name}'): + # Stack: [fn*, b, sa*] -> [fn*, b] + g.local.set(fold_adr) + # Stack: [fn*, b] -> [fn*] + g.local.set(fold_init) + # Stack: [fn*] -> [] + g.local.set(fold_func) + + if sa_len.value < 1: + g.local.get(fold_init) + return + + # Stack: [] -> [b] + g.nop(comment='Apply the first function call') + g.local.get(fold_init) + g.local.get(fold_adr) + g.add_statement(sa_type_info.wasm_load_func) + g.local.get(fold_func) + g.call_indirect([res_type_info.wasm_type, sa_type_info.wasm_type], res_type_info.wasm_type) + + if sa_len.value > 1: + # Stack: [b] -> [b] ; fold_stop=fold_adr + (sa_len.value * sa_type_info.alloc_size) + g.nop(comment='Calculate address at which to stop looping') + g.local.get(fold_adr) + g.i32.const(sa_len.value * sa_type_info.alloc_size) + g.i32.add() + g.local.set(fold_stop) + + # Stack: [b] -> [b] ; fold_adr = fold_adr + sa_type_info.alloc_size + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_info.alloc_size) + g.i32.add() + g.local.set(fold_adr) + + with g.loop(params=[res_type_info.wasm_type], result=res_type_info.wasm_type): + # Stack: [b] -> [b] + g.nop(comment='Apply function call') + g.local.get(fold_adr) + g.add_statement(sa_type_info.wasm_load_func) + g.local.get(fold_func) + g.call_indirect([res_type_info.wasm_type, sa_type_info.wasm_type], res_type_info.wasm_type) + + # Stack: [b] -> [b] ; fold_adr = fold_adr + sa_type_info.alloc_size + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_info.alloc_size) + g.i32.add() + g.local.tee(fold_adr) + + # loop if adr > stop + # Stack: [b] -> [b] + g.nop(comment='Check if address exceeds array bounds') + g.local.get(fold_stop) + g.i32.lt_u() + g.br_if(0) + # else: just one value, don't need to loop + + # Stack: [b] + +def dynamic_array_foldr(g: Generator, tvl: TypeVariableLookup) -> None: + raise NotImplementedError('dynamic_array_foldr') + +def static_array_foldr(g: Generator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_type = tvn_map['a'] + sa_len = tvn_map['a*'] + res_type = tvn_map['b'] + + assert isinstance(sa_type, Type3) + assert isinstance(sa_len, IntType3) + assert isinstance(res_type, Type3) + + sa_type_info = TYPE_INFO_MAP.get(sa_type.name, TYPE_INFO_CONSTRUCTED) + res_type_info = TYPE_INFO_MAP.get(res_type.name, TYPE_INFO_CONSTRUCTED) + + # Definitions + fold_adr = g.temp_var(i32('fold_adr')) + fold_stop = g.temp_var(i32('fold_stop')) + fold_tmp = g.temp_var_t(res_type_info.wasm_type, 'fold_tmp') + fold_func = g.temp_var(i32('fold_func')) + + with g.block(params=['i32', res_type_info.wasm_type, 'i32'], result=res_type_info.wasm_type, comment=f'foldr a={sa_type.name} a*={sa_len.value} b={res_type.name}'): + # Stack: [fn*, b, sa*] -> [fn*, b] ; fold_adr=fn*, fold_tmp=b, fold_func=fn* + g.local.set(fold_adr) + # Stack: [fn*, b] -> [fn*] + g.local.set(fold_tmp) + # Stack: [fn*] -> [] + g.local.set(fold_func) + + if sa_len.value < 1: + g.local.get(fold_tmp) + return + + # Stack: [] -> [] ; fold_stop=fold_adr + g.nop(comment='Calculate address at which to stop looping') + g.local.get(fold_adr) + g.local.set(fold_stop) + + # Stack: [] -> [] ; fold_adr=fold_adr + (sa_len.value - 1) * sa_type_info.alloc_size + g.nop(comment='Calculate address at which to stop looping') + g.local.get(fold_adr) + g.i32.const((sa_len.value - 1) * sa_type_info.alloc_size) + g.i32.add() + g.local.set(fold_adr) + + # Stack: [] -> [b] + g.nop(comment='Get the init value and first array value as starting point') + g.local.get(fold_adr) + g.add_statement(sa_type_info.wasm_load_func) + g.local.get(fold_tmp) + g.local.get(fold_func) + g.call_indirect([sa_type_info.wasm_type, res_type_info.wasm_type], res_type_info.wasm_type) + + if sa_len.value > 1: + # Stack: [b] -> [b] ; fold_adr = fold_adr - sa_type_info.alloc_size + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_info.alloc_size) + g.i32.sub() + g.local.set(fold_adr) + + with g.loop(params=[res_type_info.wasm_type], result=res_type_info.wasm_type): + g.nop(comment='Apply function call') + + # Stack [b] since we don't have proper stack switching opcodes + # Stack: [b] -> [] + g.local.set(fold_tmp) + + # Stack: [] -> [a] + g.local.get(fold_adr) + g.add_statement(sa_type_info.wasm_load_func) + + # Stack [a] -> [a, b] + g.local.get(fold_tmp) + + # Stack [a, b] -> [b] + g.local.get(fold_func) + g.call_indirect([sa_type_info.wasm_type, res_type_info.wasm_type], res_type_info.wasm_type) + + # Stack: [b] -> [b] ; fold_adr = fold_adr - sa_type_info.alloc_size + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_info.alloc_size) + g.i32.sub() + g.local.tee(fold_adr) + + # loop if adr >= stop + # Stack: [b] -> [b] + g.nop(comment='Check if address exceeds array bounds') + g.local.get(fold_stop) + g.i32.ge_u() + g.br_if(0) + # else: just one value, don't need to loop + + # Stack: [b] diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py index 6046724..3ed8ccf 100644 --- a/phasm/type3/constraints.py +++ b/phasm/type3/constraints.py @@ -183,7 +183,7 @@ class SameTypeArgumentConstraint(ConstraintBase): self.arg_var = arg_var def check(self) -> CheckResult: - if self.tc_var.resolve_as is None or self.arg_var.resolve_as is None: + if self.tc_var.resolve_as is None: return RequireTypeSubstitutes() tc_typ = self.tc_var.resolve_as @@ -200,12 +200,16 @@ class SameTypeArgumentConstraint(ConstraintBase): # So we can let the MustImplementTypeClassConstraint handle it. return None + if isinstance(tc_typ.application, TypeApplication_Type): + return [SameTypeConstraint( + tc_typ.application.arguments[0], + self.arg_var, + comment=self.comment, + )] + # FIXME: This feels sketchy. Shouldn't the type variable # have the exact same number as arguments? if isinstance(tc_typ.application, TypeApplication_TypeInt): - if tc_typ.application.arguments[0] == arg_typ: - return None - return [SameTypeConstraint( tc_typ.application.arguments[0], self.arg_var, @@ -346,14 +350,10 @@ class MustImplementTypeClassConstraint(ConstraintBase): __slots__ = ('context', 'type_class3', 'types', ) context: Context - type_class3: Union[str, Type3Class] + type_class3: Type3Class types: list[Type3OrPlaceholder] - DATA = { - 'dynamic_array': {'Foldable'}, - } - - def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None: + def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None: super().__init__(comment=comment) self.context = context @@ -381,13 +381,9 @@ class MustImplementTypeClassConstraint(ConstraintBase): assert len(typ_list) == len(self.types) - if isinstance(self.type_class3, Type3Class): - key = (self.type_class3, tuple(typ_list), ) - if key in self.context.type_class_instances_existing: - return None - else: - if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()): - return None + key = (self.type_class3, tuple(typ_list), ) + if key in self.context.type_class_instances_existing: + return None typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name typ_name_list = ' '.join(x.name for x in typ_list) diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index b0183d3..f28b88f 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -202,6 +202,10 @@ def _expression_function_call( yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment) continue + if isinstance(sig_part, FunctionArgument): + yield SameTypeConstraint(func_var_map[sig_part], arg_placeholders[arg_expr], comment=comment) + continue + raise NotImplementedError(sig_part) return @@ -265,19 +269,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}') return - if isinstance(inp, ourlang.Fold): - base_phft = PlaceholderForType([inp.base]) - iter_phft = PlaceholderForType([inp.iter]) - - yield from expression(ctx, inp.base, base_phft) - yield from expression(ctx, inp.iter, iter_phft) - - yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft, - comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b') - yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft]) - - return - raise NotImplementedError(expression, inp) def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: diff --git a/phasm/type3/routers.py b/phasm/type3/routers.py index 90b8fa4..26fa6cc 100644 --- a/phasm/type3/routers.py +++ b/phasm/type3/routers.py @@ -3,6 +3,7 @@ from typing import Any, Callable from .functions import ( TypeConstructorVariable, TypeVariable, + TypeVariableApplication_Nullary, TypeVariableApplication_Unary, ) from .typeclasses import Type3ClassArgs @@ -60,7 +61,10 @@ class TypeApplicationRouter[S, R]: raise NoRouteForTypeException(arg0, typ) -TypeVariableLookup = dict[TypeVariable, tuple[KindArgument, ...]] +TypeVariableLookup = tuple[ + dict[TypeVariable, KindArgument], + dict[TypeConstructorVariable, TypeConstructor_Base[Any]], +] class TypeClassArgsRouter[S, R]: """ @@ -95,11 +99,12 @@ class TypeClassArgsRouter[S, R]: def __call__(self, arg0: S, tv_map: dict[TypeVariable, Type3]) -> R: key: list[Type3 | TypeConstructor_Base[Any]] = [] - arguments: TypeVariableLookup = {} + arguments: TypeVariableLookup = (dict(tv_map), {}, ) for tc_arg in self.args: if isinstance(tc_arg, TypeVariable): key.append(tv_map[tc_arg]) + arguments[0][tc_arg] = tv_map[tc_arg] continue for tvar, typ in tv_map.items(): @@ -108,16 +113,24 @@ class TypeClassArgsRouter[S, R]: continue key.append(typ.application.constructor) + arguments[1][tc_arg] = typ.application.constructor if isinstance(tvar.application, TypeVariableApplication_Unary): if isinstance(typ.application, TypeApplication_Type): - arguments[tvar.application.arguments] = typ.application.arguments + da_type, = typ.application.arguments + sa_type_tv = tvar.application.arguments + arguments[0][sa_type_tv] = da_type continue # FIXME: This feels sketchy. Shouldn't the type variable # have the exact same number as arguments? if isinstance(typ.application, TypeApplication_TypeInt): - arguments[tvar.application.arguments] = typ.application.arguments + sa_type, sa_len = typ.application.arguments + sa_type_tv = tvar.application.arguments + sa_len_tv = TypeVariable(sa_type_tv.name + '*', TypeVariableApplication_Nullary(None, None)) + + arguments[0][sa_type_tv] = sa_type + arguments[0][sa_len_tv] = sa_len continue raise NotImplementedError(tvar.application, typ.application) diff --git a/phasm/type3/typeclasses.py b/phasm/type3/typeclasses.py index 83d87be..02c6ad0 100644 --- a/phasm/type3/typeclasses.py +++ b/phasm/type3/typeclasses.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, List, Mapping, Optional, Union +from typing import Dict, Iterable, List, Mapping, Optional from .functions import ( Constraint_TypeClassInstanceExists, @@ -42,8 +42,8 @@ class Type3Class: self, name: str, args: Type3ClassArgs, - methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]], - operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]], + methods: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]], + operators: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]], inherited_classes: Optional[List['Type3Class']] = None, additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None, ) -> None: @@ -71,19 +71,23 @@ class Type3Class: return self.name def _create_signature( - method_arg_list: Iterable[Type3 | TypeVariable], + method_arg_list: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]], type_class3: Type3Class, ) -> FunctionSignature: context = TypeVariableContext() if not isinstance(type_class3.args[0], TypeConstructorVariable): context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args)) - signature_args: list[Type3 | TypeVariable] = [] + signature_args: list[Type3 | TypeVariable | list[Type3 | TypeVariable]] = [] for method_arg in method_arg_list: if isinstance(method_arg, Type3): signature_args.append(method_arg) continue + if isinstance(method_arg, list): + signature_args.append(method_arg) + continue + if isinstance(method_arg, TypeVariable): type_constructor = method_arg.application.constructor if type_constructor is None: diff --git a/phasm/wasmgenerator.py b/phasm/wasmgenerator.py index 9bee38e..2b6ead0 100644 --- a/phasm/wasmgenerator.py +++ b/phasm/wasmgenerator.py @@ -170,23 +170,32 @@ class Generator_Local: self.generator.add_statement('local.tee', variable.name_ref, comment=comment) class GeneratorBlock: - def __init__(self, generator: 'Generator', name: str, params: Iterable[str] = (), result: str | None = None) -> None: + def __init__( + self, + generator: 'Generator', + name: str, + params: Iterable[str | Type[wasm.WasmType]] = (), + result: str | Type[wasm.WasmType] | None = None, + comment: str | None = None, + ) -> None: self.generator = generator self.name = name self.params = params self.result = result + self.comment = comment def __enter__(self) -> None: stmt = self.name if self.params: stmt = f'{stmt} ' + ' '.join( - f'(param {typ})' + f'(param {typ})' if isinstance(typ, str) else f'(param {typ().to_wat()})' for typ in self.params ) if self.result: - stmt = f'{stmt} (result {self.result})' + result = self.result if isinstance(self.result, str) else self.result().to_wat() + stmt = f'{stmt} (result {result})' - self.generator.add_statement(stmt) + self.generator.add_statement(stmt, comment=self.comment) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if not exc_type: @@ -208,7 +217,7 @@ class Generator: # 2.4.5 Control Instructions self.nop = functools.partial(self.add_statement, 'nop') self.unreachable = functools.partial(self.add_statement, 'unreachable') - # block + self.block = functools.partial(GeneratorBlock, self, 'block') self.loop = functools.partial(GeneratorBlock, self, 'loop') self.if_ = functools.partial(GeneratorBlock, self, 'if') # br @@ -224,6 +233,16 @@ class Generator: def call(self, function: wasm.Function) -> None: self.add_statement('call', f'${function.name}') + def call_indirect(self, params: Iterable[Type[wasm.WasmType]], result: Type[wasm.WasmType]) -> None: + param_str = ' '.join( + x().to_wat() + for x in params + ) + + result_str = result().to_wat() + + self.add_statement(f'call_indirect (param {param_str}) (result {result_str})') + def add_statement(self, name: str, *args: str, comment: Optional[str] = None) -> None: self.statements.append(wasm.Statement(name, *args, comment=comment)) @@ -234,6 +253,28 @@ class Generator: return var.__class__(varname) + def temp_var_t(self, typ: Type[wasm.WasmType], name: str) -> VarType_Base: + idx = 0 + while (varname := f'__{name}_tmp_var_{idx}__') in self.locals: + idx += 1 + + if typ is wasm.WasmTypeInt32: + return VarType_u8(varname) + + if typ is wasm.WasmTypeInt32: + return VarType_i32(varname) + + if typ is wasm.WasmTypeInt64: + return VarType_i64(varname) + + if typ is wasm.WasmTypeFloat32: + return VarType_f32(varname) + + if typ is wasm.WasmTypeFloat64: + return VarType_f64(varname) + + raise NotImplementedError(typ) + def temp_var_i32(self, infix: str) -> VarType_i32: return self.temp_var(VarType_i32(infix)) diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index dad1009..e9b5338 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -1,3 +1,4 @@ +import os import struct import sys from typing import Any, Generator, Iterable, List, TextIO, Union @@ -36,26 +37,39 @@ class Suite: def __init__(self, code_py: str) -> None: self.code_py = code_py - def run_code(self, *args: Any, runtime: str = 'wasmtime', func_name: str = 'testEntry', imports: runners.Imports = None, do_format_check: bool = True) -> Any: + def run_code( + self, + *args: Any, + runtime: str = 'wasmtime', + func_name: str = 'testEntry', + imports: runners.Imports = None, + do_format_check: bool = True, + verbose: bool | None = None, + ) -> Any: """ Compiles the given python code into wasm and then runs it Returned is an object with the results set """ + if verbose is None: + verbose = bool(os.environ.get('VERBOSE')) + class_ = RUNNER_CLASS_MAP[runtime] runner = class_(self.code_py) - write_header(sys.stderr, 'Phasm') - runner.dump_phasm_code(sys.stderr) + if verbose: + write_header(sys.stderr, 'Phasm') + runner.dump_phasm_code(sys.stderr) - runner.parse() + runner.parse(verbose=verbose) runner.compile_ast() runner.compile_wat() - write_header(sys.stderr, 'Assembly') - runner.dump_wasm_wat(sys.stderr) + if verbose: + write_header(sys.stderr, 'Assembly') + runner.dump_wasm_wat(sys.stderr) runner.interpreter_setup() runner.interpreter_load(imports) @@ -70,8 +84,9 @@ class Suite: wasm_args: List[Union[float, int]] = [] if args: - write_header(sys.stderr, 'Memory (pre alloc)') - runner.interpreter_dump_memory(sys.stderr) + if verbose: + write_header(sys.stderr, 'Memory (pre alloc)') + runner.interpreter_dump_memory(sys.stderr) for arg, arg_typ in zip(args, func_args, strict=True): if arg_typ in (prelude.u8, prelude.u32, prelude.u64, ): @@ -95,8 +110,9 @@ class Suite: except NoRouteForTypeException: raise NotImplementedError(arg_typ, arg) - write_header(sys.stderr, 'Memory (pre run)') - runner.interpreter_dump_memory(sys.stderr) + if verbose: + write_header(sys.stderr, 'Memory (pre run)') + runner.interpreter_dump_memory(sys.stderr) result = SuiteResult() result.returned_value = runner.call(func_name, *wasm_args) @@ -107,8 +123,9 @@ class Suite: result.returned_value, ) - write_header(sys.stderr, 'Memory (post run)') - runner.interpreter_dump_memory(sys.stderr) + if verbose: + write_header(sys.stderr, 'Memory (post run)') + runner.interpreter_dump_memory(sys.stderr) return result diff --git a/tests/integration/runners.py b/tests/integration/runners.py index 6afeacf..3957233 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -32,12 +32,12 @@ class RunnerBase: """ _dump_code(textio, self.phasm_code) - def parse(self) -> None: + def parse(self, verbose: bool = True) -> None: """ Parses the Phasm code into an AST """ self.phasm_ast = phasm_parse(self.phasm_code) - phasm_type3(self.phasm_ast, verbose=True) + phasm_type3(self.phasm_ast, verbose=verbose) def compile_ast(self) -> None: """ @@ -120,6 +120,8 @@ class RunnerWasmtime(RunnerBase): if vartype is int: params.append(wasmtime.ValType.i32()) + elif vartype is float: + params.append(wasmtime.ValType.f32()) else: raise NotImplementedError @@ -128,6 +130,8 @@ class RunnerWasmtime(RunnerBase): pass # No return value elif func.__annotations__['return'] is int: results.append(wasmtime.ValType.i32()) + elif func.__annotations__['return'] is float: + results.append(wasmtime.ValType.f32()) else: raise NotImplementedError('Return type', func.__annotations__['return']) diff --git a/tests/integration/test_lang/test_builtins.py b/tests/integration/test_lang/test_builtins.py deleted file mode 100644 index 13400fd..0000000 --- a/tests/integration/test_lang/test_builtins.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -from ..helpers import Suite - - -@pytest.mark.integration_test -def test_foldl_1(): - code_py = """ -def u8_or(l: u8, r: u8) -> u8: - return l | r - -@exported -def testEntry(b: bytes) -> u8: - return foldl(u8_or, 128, b) -""" - suite = Suite(code_py) - - result = suite.run_code(b'') - assert 128 == result.returned_value - - result = suite.run_code(b'\x80') - assert 128 == result.returned_value - - result = suite.run_code(b'\x80\x40') - assert 192 == result.returned_value - - result = suite.run_code(b'\x80\x40\x20\x10') - assert 240 == result.returned_value - - result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01') - assert 255 == result.returned_value - -@pytest.mark.integration_test -def test_foldl_2(): - code_py = """ -def xor(l: u8, r: u8) -> u8: - return l ^ r - -@exported -def testEntry(a: bytes, b: bytes) -> u8: - return foldl(xor, 0, a) ^ foldl(xor, 0, b) -""" - suite = Suite(code_py) - - result = suite.run_code(b'\x55\x0F', b'\x33\x80') - assert 233 == result.returned_value - -@pytest.mark.integration_test -def test_foldl_3(): - code_py = """ -def xor(l: u32, r: u8) -> u32: - return l ^ extend(r) - -@exported -def testEntry(a: bytes) -> u32: - return foldl(xor, 0, a) -""" - suite = Suite(code_py) - - result = suite.run_code(b'\x55\x0F\x33\x80') - assert 233 == result.returned_value diff --git a/tests/integration/test_typeclasses/test_foldable.py b/tests/integration/test_typeclasses/test_foldable.py index 34320b0..c4e1c53 100644 --- a/tests/integration/test_typeclasses/test_foldable.py +++ b/tests/integration/test_typeclasses/test_foldable.py @@ -9,10 +9,13 @@ from .test_natnum import FLOAT_TYPES, INT_TYPES @pytest.mark.integration_test @pytest.mark.parametrize('length', [1, 5, 13]) @pytest.mark.parametrize('a_type', INT_TYPES + FLOAT_TYPES) -def test_foldable_sum(length, a_type): +@pytest.mark.parametrize('static', [True, False]) +def test_foldable_sum(length, a_type, static): + typ_arg = str(length) if static else '...' + code_py = f""" @exported -def testEntry(x: {a_type}[{length}]) -> {a_type}: +def testEntry(x: {a_type}[{typ_arg}]) -> {a_type}: return sum(x) """ @@ -36,6 +39,145 @@ def testEntry(x: Foo[4]) -> Foo: with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'): Suite(code_py).run_code() +@pytest.mark.integration_test +@pytest.mark.parametrize('length', [1, 5, 13]) +@pytest.mark.parametrize('direction', ['foldl', 'foldr']) +def test_foldable_foldl_foldr_size(direction, length): + code_py = f""" +def u64_add(l: u64, r: u64) -> u64: + return l + r + +@exported +def testEntry(b: u64[{length}]) -> u64: + return {direction}(u64_add, 100, b) +""" + suite = Suite(code_py) + + in_put = tuple(range(1, length + 1)) + + result = suite.run_code(in_put) + assert (100 + sum(in_put)) == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('direction', ['foldl', 'foldr']) +def test_foldable_foldl_foldr_compounded_type(direction): + code_py = f""" +def combine_foldl(b: u64, a: (u32, u32, )) -> u64: + return extend(a[0] * a[1]) + b + +def combine_foldr(a: (u32, u32, ), b: u64) -> u64: + return extend(a[0] * a[1]) + b + +@exported +def testEntry(b: (u32, u32, )[3]) -> u64: + return {direction}(combine_{direction}, 10000, b) +""" + suite = Suite(code_py) + + result = suite.run_code(((2, 5), (25, 4), (125, 8))) + assert 11110 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('in_put, direction, exp_result', [ + ([], 'foldl', 0, ), + ([], 'foldr', 0, ), + + ([1], 'foldl', -1, ), + ([1], 'foldr', 1, ), + + ([1,2], 'foldl', -3, ), + ([1,2], 'foldr', -1, ), + + ([1,2,3], 'foldl', -6, ), + ([1,2,3], 'foldr', 2, ), + + ([1,2,3,4], 'foldl', -10, ), + ([1,2,3,4], 'foldr', -2, ), + + ([1,2,3,4,5], 'foldl', -15, ), + ([1,2,3,4,5], 'foldr', 3, ), + + ([1,2,3,4,5,6], 'foldl', -21, ), + ([1,2,3,4,5,6], 'foldr', -3, ), + + ([1,2,3,4,5,6,7], 'foldl', -28, ), + ([1,2,3,4,5,6,7], 'foldr', 4, ), + + ([1,2,3,4,5,6,7,8], 'foldl', -36, ), + ([1,2,3,4,5,6,7,8], 'foldr', -4, ), +]) +@pytest.mark.parametrize('static', [True, False]) +def test_foldable_foldl_foldr_result(direction, in_put, exp_result, static): + typ_arg = str(len(in_put)) if static else '...' + + # See https://stackoverflow.com/a/13280185 + code_py = f""" +def i32_sub(l: i32, r: i32) -> i32: + return l - r + +@exported +def testEntry(b: i32[{typ_arg}]) -> i32: + return {direction}(i32_sub, 0, b) +""" + suite = Suite(code_py) + + result = suite.run_code(tuple(in_put)) + assert exp_result == result.returned_value + +@pytest.mark.integration_test +def test_foldable_foldl_bytes(): + code_py = """ +def u8_or(l: u8, r: u8) -> u8: + return l | r + +@exported +def testEntry(b: bytes) -> u8: + return foldl(u8_or, 0, b) +""" + suite = Suite(code_py) + + result = suite.run_code(b'') + assert 0 == result.returned_value + + result = suite.run_code(b'\x80') + assert 128 == result.returned_value + + result = suite.run_code(b'\x80\x40') + assert 192 == result.returned_value + + result = suite.run_code(b'\x80\x40\x20\x10') + assert 240 == result.returned_value + + result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01') + assert 255 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('in_typ', ['i8', 'i8[3]']) +def test_foldable_argument_must_be_a_function(in_typ): + code_py = f""" +@exported +def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32: + return foldl(x, y, z) +""" + + r_in_typ = in_typ.replace('[', '\\[').replace(']', '\\]') + + with pytest.raises(Type3Exception, match=f'{r_in_typ} must be a function instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_foldable_argument_must_be_right_function(): + code_py = """ +def foo(l: i32, r: i64) -> i64: + return extend(l) + r + +@exported +def testEntry(i: i64, l: i64[3]) -> i64: + return foldr(foo, i, l) +""" + + with pytest.raises(Type3Exception, match=r'Callable\[i64, i64, i64\] must be Callable\[i32, i64, i64\] instead'): + Suite(code_py).run_code() @pytest.mark.integration_test def test_foldable_invalid_return_type():