diff --git a/TODO.md b/TODO.md index e5cc852..2cc5986 100644 --- a/TODO.md +++ b/TODO.md @@ -30,3 +30,5 @@ - Functions don't seem to be a thing on typing level yet? - Related to the FIXME in phasm_type3? - Type constuctor should also be able to constuct placeholders - somehow. + +- Read https://bytecodealliance.org/articles/multi-value-all-the-wasm diff --git a/phasm/codestyle.py b/phasm/codestyle.py index fc40868..940ecfa 100644 --- a/phasm/codestyle.py +++ b/phasm/codestyle.py @@ -85,6 +85,9 @@ def expression(inp: ourlang.Expression) -> str: return f'{inp.function.name}({args})' + if isinstance(inp, ourlang.FunctionReference): + return str(inp.function.name) + if isinstance(inp, ourlang.TupleInstantiation): args = ', '.join( expression(arg) @@ -102,10 +105,6 @@ def expression(inp: ourlang.Expression) -> str: if isinstance(inp, ourlang.AccessStructMember): return f'{expression(inp.varref)}.{inp.member}' - if isinstance(inp, ourlang.Fold): - fold_name = 'foldl' if ourlang.Fold.Direction.LEFT == inp.dir else 'foldr' - return f'{fold_name}({inp.func.name}, {expression(inp.base)}, {expression(inp.iter)})' - raise NotImplementedError(expression, inp) def statement(inp: ourlang.Statement) -> Statements: diff --git a/phasm/compiler.py b/phasm/compiler.py index d098499..a199a41 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -4,11 +4,11 @@ This module contains the code to convert parsed Ourlang into WebAssembly code import struct from typing import List, Optional -from . import codestyle, ourlang, prelude, wasm +from . import ourlang, prelude, wasm from .runtime import calculate_alloc_size, calculate_member_offset from .stdlib import alloc as stdlib_alloc from .stdlib import types as stdlib_types -from .type3.functions import TypeVariable +from .type3.functions import FunctionArgument, TypeVariable from .type3.routers import NoRouteForTypeException, TypeApplicationRouter from .type3.typeclasses import Type3ClassMethod from .type3.types import ( @@ -100,7 +100,7 @@ def type3(inp: Type3) -> wasm.WasmType: raise NotImplementedError(type3, inp) -def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> None: +def tuple_instantiation(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.TupleInstantiation) -> None: """ Compile: Instantiation (allocation) of a tuple """ @@ -150,7 +150,7 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> wgn.add_statement('nop', comment='PRE') wgn.local.get(tmp_var) - expression(wgn, element) + expression(wgn, mod, element) wgn.add_statement(f'{mtyp}.store', 'offset=' + str(offset)) wgn.add_statement('nop', comment='POST') @@ -160,29 +160,29 @@ def tuple_instantiation(wgn: WasmGenerator, inp: ourlang.TupleInstantiation) -> wgn.local.get(tmp_var) def expression_subscript_bytes( - attrs: tuple[WasmGenerator, ourlang.Subscript], + attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], ) -> None: - wgn, inp = attrs + wgn, mod, inp = attrs - expression(wgn, inp.varref) - expression(wgn, inp.index) + expression(wgn, mod, inp.varref) + expression(wgn, mod, inp.index) wgn.call(stdlib_types.__subscript_bytes__) def expression_subscript_static_array( - attrs: tuple[WasmGenerator, ourlang.Subscript], + attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], args: tuple[Type3, IntType3], ) -> None: - wgn, inp = attrs + wgn, mod, inp = attrs el_type, el_len = args # OPTIMIZE: If index is a constant, we can use offset instead of multiply # and we don't need to do the out of bounds check - expression(wgn, inp.varref) + expression(wgn, mod, inp.varref) tmp_var = wgn.temp_var_i32('index') - expression(wgn, inp.index) + expression(wgn, mod, inp.index) wgn.local.tee(tmp_var) # Out of bounds check based on el_len.value @@ -201,10 +201,10 @@ def expression_subscript_static_array( wgn.add_statement(f'{mtyp}.load') def expression_subscript_tuple( - attrs: tuple[WasmGenerator, ourlang.Subscript], + attrs: tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], args: tuple[Type3, ...], ) -> None: - wgn, inp = attrs + wgn, mod, inp = attrs assert isinstance(inp.index, ourlang.ConstantPrimitive) assert isinstance(inp.index.value, int) @@ -217,7 +217,7 @@ def expression_subscript_tuple( el_type = args[inp.index.value] assert el_type is not None, TYPE3_ASSERTION_ERROR - expression(wgn, inp.varref) + expression(wgn, mod, inp.varref) if (prelude.InternalPassAsPointer, (el_type, )) in prelude.PRELUDE_TYPE_CLASS_INSTANCES_EXISTING: mtyp = 'i32' @@ -226,12 +226,12 @@ def expression_subscript_tuple( wgn.add_statement(f'{mtyp}.load', f'offset={offset}') -SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Subscript], None]() +SUBSCRIPT_ROUTER = TypeApplicationRouter[tuple[WasmGenerator, ourlang.Module, ourlang.Subscript], None]() SUBSCRIPT_ROUTER.add_n(prelude.bytes_, expression_subscript_bytes) SUBSCRIPT_ROUTER.add(prelude.static_array, expression_subscript_static_array) SUBSCRIPT_ROUTER.add(prelude.tuple_, expression_subscript_tuple) -def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: +def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) -> None: """ Compile: Any expression """ @@ -291,14 +291,14 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.i32.const(address) return - expression(wgn, inp.variable.constant) + expression(wgn, mod, inp.variable.constant) return raise NotImplementedError(expression, inp.variable) if isinstance(inp, ourlang.BinaryOp): - expression(wgn, inp.left) - expression(wgn, inp.right) + expression(wgn, mod, inp.left) + expression(wgn, mod, inp.right) type_var_map: dict[TypeVariable, Type3] = {} @@ -313,6 +313,10 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: type_var_map[type_var] = arg_expr.type3 continue + if isinstance(type_var, FunctionArgument): + # Fixed type, not part of the lookup requirements + continue + raise NotImplementedError(type_var, arg_expr.type3) router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.operator] @@ -321,7 +325,7 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: if isinstance(inp, ourlang.FunctionCall): for arg in inp.arguments: - expression(wgn, arg) + expression(wgn, mod, arg) if isinstance(inp.function, Type3ClassMethod): # FIXME: Duplicate code with BinaryOp @@ -338,6 +342,10 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: type_var_map[type_var] = arg_expr.type3 continue + if isinstance(type_var, FunctionArgument): + # Fixed type, not part of the lookup requirements + continue + raise NotImplementedError(type_var, arg_expr.type3) router = prelude.PRELUDE_TYPE_CLASS_INSTANCE_METHODS[inp.function] @@ -350,15 +358,24 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: wgn.add_statement('call', '${}'.format(inp.function.name)) return + if isinstance(inp, ourlang.FunctionReference): + idx = mod.functions_table.get(inp.function) + if idx is None: + idx = len(mod.functions_table) + mod.functions_table[inp.function] = idx + + wgn.add_statement('i32.const', str(idx), comment=inp.function.name) + return + if isinstance(inp, ourlang.TupleInstantiation): - tuple_instantiation(wgn, inp) + tuple_instantiation(wgn, mod, inp) return if isinstance(inp, ourlang.Subscript): assert inp.varref.type3 is not None, TYPE3_ASSERTION_ERROR # Type checker guarantees we don't get routing errors - SUBSCRIPT_ROUTER((wgn, inp, ), inp.varref.type3) + SUBSCRIPT_ROUTER((wgn, mod, inp, ), inp.varref.type3) return if isinstance(inp, ourlang.AccessStructMember): @@ -370,111 +387,29 @@ def expression(wgn: WasmGenerator, inp: ourlang.Expression) -> None: mtyp = LOAD_STORE_TYPE_MAP[member_type.name] - expression(wgn, inp.varref) + expression(wgn, mod, inp.varref) wgn.add_statement(f'{mtyp}.load', 'offset=' + str(calculate_member_offset( inp.struct_type3.name, inp.struct_type3.application.arguments, inp.member ))) return - if isinstance(inp, ourlang.Fold): - expression_fold(wgn, inp) - return - raise NotImplementedError(expression, inp) -def expression_fold(wgn: WasmGenerator, inp: ourlang.Fold) -> None: - """ - Compile: Fold expression - """ - assert inp.type3 is not None, TYPE3_ASSERTION_ERROR - - if inp.iter.type3 is not prelude.bytes_: - raise NotImplementedError(expression_fold, inp, inp.iter.type3) - - wgn.add_statement('nop', comment='acu :: u8') - acu_var = wgn.temp_var_u8(f'fold_{codestyle.type3(inp.type3)}_acu') - wgn.add_statement('nop', comment='adr :: bytes*') - adr_var = wgn.temp_var_i32('fold_i32_adr') - wgn.add_statement('nop', comment='len :: i32') - len_var = wgn.temp_var_i32('fold_i32_len') - - wgn.add_statement('nop', comment='acu = base') - expression(wgn, inp.base) - wgn.local.set(acu_var) - - wgn.add_statement('nop', comment='adr = adr(iter)') - expression(wgn, inp.iter) - wgn.local.set(adr_var) - - wgn.add_statement('nop', comment='len = len(iter)') - wgn.local.get(adr_var) - wgn.i32.load() - wgn.local.set(len_var) - - wgn.add_statement('nop', comment='i = 0') - idx_var = wgn.temp_var_i32(f'fold_{codestyle.type3(inp.type3)}_idx') - wgn.i32.const(0) - wgn.local.set(idx_var) - - wgn.add_statement('nop', comment='if i < len') - wgn.local.get(idx_var) - wgn.local.get(len_var) - wgn.i32.lt_u() - with wgn.if_(): - # From here on, adr_var is the address of byte we're referencing - # This is akin to calling stdlib_types.__subscript_bytes__ - # But since we already know we are inside of bounds, - # can just bypass it and load the memory directly. - wgn.local.get(adr_var) - wgn.i32.const(3) # Bytes header -1, since we do a +1 every loop - wgn.i32.add() - wgn.local.set(adr_var) - - wgn.add_statement('nop', comment='while True') - with wgn.loop(): - wgn.add_statement('nop', comment='acu = func(acu, iter[i])') - wgn.local.get(acu_var) - - # Get the next byte, write back the address - wgn.local.get(adr_var) - wgn.i32.const(1) - wgn.i32.add() - wgn.local.tee(adr_var) - wgn.i32.load8_u() - - wgn.add_statement('call', f'${inp.func.name}') - wgn.local.set(acu_var) - - wgn.add_statement('nop', comment='i = i + 1') - wgn.local.get(idx_var) - wgn.i32.const(1) - wgn.i32.add() - wgn.local.set(idx_var) - - wgn.add_statement('nop', comment='if i >= len: break') - wgn.local.get(idx_var) - wgn.local.get(len_var) - wgn.i32.lt_u() - wgn.br_if(0) - - # return acu - wgn.local.get(acu_var) - -def statement_return(wgn: WasmGenerator, inp: ourlang.StatementReturn) -> None: +def statement_return(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementReturn) -> None: """ Compile: Return statement """ - expression(wgn, inp.value) + expression(wgn, mod, inp.value) wgn.return_() -def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None: +def statement_if(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.StatementIf) -> None: """ Compile: If statement """ - expression(wgn, inp.test) + expression(wgn, mod, inp.test) with wgn.if_(): for stat in inp.statements: - statement(wgn, stat) + statement(wgn, mod, stat) if inp.else_statements: raise NotImplementedError @@ -482,16 +417,16 @@ def statement_if(wgn: WasmGenerator, inp: ourlang.StatementIf) -> None: # for stat in inp.else_statements: # statement(wgn, stat) -def statement(wgn: WasmGenerator, inp: ourlang.Statement) -> None: +def statement(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Statement) -> None: """ Compile: any statement """ if isinstance(inp, ourlang.StatementReturn): - statement_return(wgn, inp) + statement_return(wgn, mod, inp) return if isinstance(inp, ourlang.StatementIf): - statement_if(wgn, inp) + statement_if(wgn, mod, inp) return if isinstance(inp, ourlang.StatementPass): @@ -522,7 +457,7 @@ def import_(inp: ourlang.Function) -> wasm.Import: type3(inp.returns_type3) ) -def function(inp: ourlang.Function) -> wasm.Function: +def function(mod: ourlang.Module, inp: ourlang.Function) -> wasm.Function: """ Compile: function """ @@ -534,7 +469,7 @@ def function(inp: ourlang.Function) -> wasm.Function: _generate_struct_constructor(wgn, inp) else: for stat in inp.statements: - statement(wgn, stat) + statement(wgn, mod, stat) return wasm.Function( inp.name, @@ -724,26 +659,32 @@ def module(inp: ourlang.Module) -> wasm.Module: stdlib_alloc.__find_free_block__, stdlib_alloc.__alloc__, stdlib_types.__alloc_bytes__, - stdlib_types.__subscript_bytes__, - stdlib_types.__u32_ord_min__, - stdlib_types.__u64_ord_min__, - stdlib_types.__i32_ord_min__, - stdlib_types.__i64_ord_min__, - stdlib_types.__u32_ord_max__, - stdlib_types.__u64_ord_max__, - stdlib_types.__i32_ord_max__, - stdlib_types.__i64_ord_max__, - stdlib_types.__i32_intnum_abs__, - stdlib_types.__i64_intnum_abs__, - stdlib_types.__u32_pow2__, - stdlib_types.__u8_rotl__, - stdlib_types.__u8_rotr__, + # stdlib_types.__subscript_bytes__, + # stdlib_types.__u32_ord_min__, + # stdlib_types.__u64_ord_min__, + # stdlib_types.__i32_ord_min__, + # stdlib_types.__i64_ord_min__, + # stdlib_types.__u32_ord_max__, + # stdlib_types.__u64_ord_max__, + # stdlib_types.__i32_ord_max__, + # stdlib_types.__i64_ord_max__, + # stdlib_types.__i32_intnum_abs__, + # stdlib_types.__i64_intnum_abs__, + # stdlib_types.__u32_pow2__, + # stdlib_types.__u8_rotl__, + # stdlib_types.__u8_rotr__, ] + [ - function(x) + function(inp, x) for x in inp.functions.values() if not x.imported ] + # Do this after rendering the functions since that's what populates the tables + result.table = { + v: k.name + for k, v in inp.functions_table.items() + } + return result def _generate_struct_constructor(wgn: WasmGenerator, inp: ourlang.StructConstructor) -> None: diff --git a/phasm/ourlang.py b/phasm/ourlang.py index df97b23..eca02de 100644 --- a/phasm/ourlang.py +++ b/phasm/ourlang.py @@ -1,7 +1,6 @@ """ Contains the syntax tree for ourlang """ -import enum from typing import Dict, Iterable, List, Optional, Union from . import prelude @@ -161,6 +160,18 @@ class FunctionCall(Expression): self.function = function self.arguments = [] +class FunctionReference(Expression): + """ + An function reference expression within a statement + """ + __slots__ = ('function', ) + + function: 'Function' + + def __init__(self, function: 'Function') -> None: + super().__init__() + self.function = function + class TupleInstantiation(Expression): """ Instantiation a tuple @@ -207,36 +218,6 @@ class AccessStructMember(Expression): self.struct_type3 = struct_type3 self.member = member -class Fold(Expression): - """ - A (left or right) fold - """ - class Direction(enum.Enum): - """ - Which direction to fold in - """ - LEFT = 0 - RIGHT = 1 - - dir: Direction - func: 'Function' - base: Expression - iter: Expression - - def __init__( - self, - dir_: Direction, - func: 'Function', - base: Expression, - iter_: Expression, - ) -> None: - super().__init__() - - self.dir = dir_ - self.func = func - self.base = base - self.iter = iter_ - class Statement: """ A statement within a function @@ -397,13 +378,14 @@ class Module: """ A module is a file and consists of functions """ - __slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'operators', ) + __slots__ = ('data', 'types', 'struct_definitions', 'constant_defs', 'functions', 'functions_table', 'operators', ) data: ModuleData types: dict[str, Type3] struct_definitions: Dict[str, StructDefinition] constant_defs: Dict[str, ModuleConstantDef] functions: Dict[str, Function] + functions_table: dict[Function, int] operators: Dict[str, Type3ClassMethod] def __init__(self) -> None: @@ -413,3 +395,4 @@ class Module: self.constant_defs = {} self.functions = {} self.operators = {} + self.functions_table = {} diff --git a/phasm/parser.py b/phasm/parser.py index 9944dc2..2346a71 100644 --- a/phasm/parser.py +++ b/phasm/parser.py @@ -14,10 +14,10 @@ from .ourlang import ( ConstantStruct, ConstantTuple, Expression, - Fold, Function, FunctionCall, FunctionParam, + FunctionReference, Module, ModuleConstantDef, ModuleDataBlock, @@ -446,6 +446,9 @@ class OurVisitor: cdef = module.constant_defs[node.id] return VariableReference(cdef) + if node.id in module.functions: + return FunctionReference(module.functions[node.id]) + _raise_static_error(node, f'Undefined variable {node.id}') if isinstance(node, ast.Tuple): @@ -462,7 +465,7 @@ class OurVisitor: raise NotImplementedError(f'{node} as expr in FunctionDef') - def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[Fold, FunctionCall]: + def visit_Module_FunctionDef_Call(self, module: Module, function: Function, our_locals: OurLocals, node: ast.Call) -> Union[FunctionCall]: if node.keywords: _raise_static_error(node, 'Keyword calling not supported') # Yet? @@ -475,28 +478,6 @@ class OurVisitor: if node.func.id in PRELUDE_METHODS: func = PRELUDE_METHODS[node.func.id] - elif node.func.id == 'foldl': - if 3 != len(node.args): - _raise_static_error(node, f'Function {node.func.id} requires 3 arguments but {len(node.args)} are given') - - # TODO: This is not generic, you cannot return a function - subnode = node.args[0] - if not isinstance(subnode, ast.Name): - raise NotImplementedError(f'Calling methods that are not a name {subnode}') - if not isinstance(subnode.ctx, ast.Load): - _raise_static_error(subnode, 'Must be load context') - if subnode.id not in module.functions: - _raise_static_error(subnode, 'Reference to undefined function') - func = module.functions[subnode.id] - if 2 != len(func.posonlyargs): - _raise_static_error(node, f'Function {node.func.id} requires a function with 2 arguments but a function with {len(func.posonlyargs)} args is given') - - return Fold( - Fold.Direction.LEFT, - func, - self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[1]), - self.visit_Module_FunctionDef_expr(module, function, our_locals, node.args[2]), - ) else: if node.func.id not in module.functions: _raise_static_error(node, 'Call to undefined function') diff --git a/phasm/prelude/__init__.py b/phasm/prelude/__init__.py index e269590..29a266d 100644 --- a/phasm/prelude/__init__.py +++ b/phasm/prelude/__init__.py @@ -20,6 +20,7 @@ from ..type3.types import ( Type3, TypeApplication_Nullary, TypeConstructor_Base, + TypeConstructor_Function, TypeConstructor_StaticArray, TypeConstructor_Struct, TypeConstructor_Tuple, @@ -186,6 +187,16 @@ It should be applied with zero or more arguments. It has a compile time determined length, and each argument can be different. """ +def fn_on_create(args: tuple[Type3, ...], typ: Type3) -> None: + pass # ? instance_type_class(InternalPassAsPointer, typ) + +function = TypeConstructor_Function('function', on_create=fn_on_create) +""" +This is a function. + +It should be applied with one or more arguments. The last argument is the 'return' type. +""" + def st_on_create(args: tuple[tuple[str, Type3], ...], typ: Type3) -> None: instance_type_class(InternalPassAsPointer, typ) @@ -574,12 +585,16 @@ instance_type_class(Promotable, f32, f64, methods={ Foldable = Type3Class('Foldable', (t, ), methods={ 'sum': [t(a), a], + 'foldl': [[b, a, b], b, t(a), b], + 'foldr': [[a, b, b], b, t(a), b], }, operators={}, additional_context={ 'sum': [Constraint_TypeClassInstanceExists(NatNum, (a, ))], }) instance_type_class(Foldable, static_array, methods={ 'sum': stdtypes.static_array_sum, + 'foldl': stdtypes.static_array_foldl, + 'foldr': stdtypes.static_array_foldr, }) PRELUDE_TYPE_CLASSES = { diff --git a/phasm/stdlib/types.py b/phasm/stdlib/types.py index 425dfa0..39c6df3 100644 --- a/phasm/stdlib/types.py +++ b/phasm/stdlib/types.py @@ -4,7 +4,7 @@ stdlib: Standard types that are not wasm primitives from phasm.stdlib import alloc from phasm.type3.routers import TypeVariableLookup from phasm.type3.types import IntType3, Type3 -from phasm.wasmgenerator import Generator, func_wrapper +from phasm.wasmgenerator import Generator, VarType_Base, func_wrapper from phasm.wasmgenerator import VarType_i32 as i32 from phasm.wasmgenerator import VarType_i64 as i64 @@ -1081,9 +1081,17 @@ def f32_f64_demote(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map g.f32.demote_f64() -def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: - assert len(tv_map) == 1 - sa_type, sa_len = next(iter(tv_map.values())) +def static_array_sum(g: Generator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_type = tvn_map['a'] + sa_len = tvn_map['a*'] + assert isinstance(sa_type, Type3) assert isinstance(sa_len, IntType3) @@ -1166,7 +1174,7 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: g.nop(comment='Add array value') g.local.get(sum_adr) g.add_statement(f'{sa_type_mtyp}.load') - sa_type_add_gen(g, {}) + sa_type_add_gen(g, ({}, {}, )) # adr = adr + sa_type_alloc_size # Stack: [sum] -> [sum] @@ -1185,3 +1193,253 @@ def static_array_sum(g: Generator, tv_map: TypeVariableLookup) -> None: g.nop(comment=f'Completed sum for {sa_type.name}[{sa_len.value}]') # End result: [sum] + +def static_array_foldl(g: Generator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_type = tvn_map['a'] + sa_len = tvn_map['a*'] + res_type = tvn_map['b'] + + assert isinstance(sa_type, Type3) + assert isinstance(sa_len, IntType3) + assert isinstance(res_type, Type3) + + if sa_len.value < 1: + raise NotImplementedError('Default value in case foldl is empty') + + # FIXME: We should probably use LOAD_STORE_TYPE_MAP for this? + mtyp_map = { + 'u32': 'i32', + 'u64': 'i64', + 'i32': 'i32', + 'i64': 'i64', + 'f32': 'f32', + 'f64': 'f64', + } + mtyp_f_map: dict[str, type[VarType_Base]] = { + 'i32': i32, + 'i64': i64, + } + + # FIXME: We should probably use calc_alloc_size for this? + type_var_size_map = { + 'u32': 4, + 'u64': 8, + 'i32': 4, + 'i64': 8, + 'f32': 4, + 'f64': 8, + } + + # By default, constructed types are passed as pointers + # FIXME: We don't know what add function to call + sa_type_mtyp = mtyp_map.get(sa_type.name, 'i32') + sa_type_alloc_size = type_var_size_map.get(sa_type.name, 4) + res_type_mtyp = mtyp_map.get(res_type.name, 'i32') + res_type_mtyp_f = mtyp_f_map[res_type_mtyp] + + # Definitions + fold_adr = g.temp_var(i32('fold_adr')) + fold_stop = g.temp_var(i32('fold_stop')) + fold_init = g.temp_var(res_type_mtyp_f('fold_init')) + fold_func = g.temp_var(i32('fold_func')) + + with g.block(params=['i32', res_type_mtyp, 'i32'], result=res_type_mtyp, comment=f'foldl a={sa_type.name} a*={sa_len.value} b={res_type.name}'): + # Stack: [[a, b, b], b, t(a), b] + # t(a) == sa_type[sa_len] + # Stack: [fn*, b, sa*] + + # adr = {address of what's currently on stack} + # Stack: [fn*, b, sa*] -> [fn*, b] + g.local.set(fold_adr) + # Stack: [fn*, b] -> [fn*] + g.local.set(fold_init) + # Stack: [fn*] -> [] + g.local.set(fold_func) + + # stop = adr + ar_len * sa_type_alloc_size + # Stack: [] + g.nop(comment='Calculate address at which to stop looping') + g.local.get(fold_adr) + g.i32.const(sa_len.value * sa_type_alloc_size) + g.i32.add() + g.local.set(fold_stop) + + # Stack: [] -> [b] + g.nop(comment='Get the init value and first array value as starting point') + g.local.get(fold_init) + # Stack: [b] -> [b, *a] + g.local.get(fold_adr) + # Stack: [b] -> [b, a] + g.add_statement(f'{sa_type_mtyp}.load') + g.nop(comment='Call the fold function') + g.local.get(fold_func) + g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})') + + # adr = adr + sa_type_alloc_size + # Stack: [b] -> [b] + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_alloc_size) + g.i32.add() + g.local.set(fold_adr) + + if sa_len.value > 1: + with g.loop(params=[sa_type_mtyp], result=sa_type_mtyp): + # Stack: [b] -> [b, a] + g.nop(comment='Add array value') + g.local.get(fold_adr) + g.add_statement(f'{sa_type_mtyp}.load') + + # Stack [b, a] -> b + g.nop(comment='Call the fold function') + g.local.get(fold_func) + g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})') + + # adr = adr + sa_type_alloc_size + # Stack: [fold] -> [fold] + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_alloc_size) + g.i32.add() + g.local.tee(fold_adr) + + # loop if adr < stop + g.nop(comment='Check if address exceeds array bounds') + g.local.get(fold_stop) + g.i32.lt_u() + g.br_if(0) + # else: just one value, don't need to loop + + # Stack: [b] + +def static_array_foldr(g: Generator, tvl: TypeVariableLookup) -> None: + tv_map, tc_map = tvl + + tvn_map = { + x.name: y + for x, y in tv_map.items() + } + + sa_type = tvn_map['a'] + sa_len = tvn_map['a*'] + res_type = tvn_map['b'] + + assert isinstance(sa_type, Type3) + assert isinstance(sa_len, IntType3) + assert isinstance(res_type, Type3) + + if sa_len.value < 1: + raise NotImplementedError('Default value in case foldl is empty') + + # FIXME: We should probably use LOAD_STORE_TYPE_MAP for this? + mtyp_map = { + 'u32': 'i32', + 'u64': 'i64', + 'i32': 'i32', + 'i64': 'i64', + 'f32': 'f32', + 'f64': 'f64', + } + mtyp_f_map: dict[str, type[VarType_Base]] = { + 'i32': i32, + 'i64': i64, + } + + # FIXME: We should probably use calc_alloc_size for this? + type_var_size_map = { + 'u32': 4, + 'u64': 8, + 'i32': 4, + 'i64': 8, + 'f32': 4, + 'f64': 8, + } + + # By default, constructed types are passed as pointers + # FIXME: We don't know what add function to call + sa_type_mtyp = mtyp_map.get(sa_type.name, 'i32') + sa_type_alloc_size = type_var_size_map.get(sa_type.name, 4) + res_type_mtyp = mtyp_map.get(res_type.name, 'i32') + res_type_mtyp_f = mtyp_f_map[res_type_mtyp] + + # Definitions + fold_adr = g.temp_var(i32('fold_adr')) + fold_stop = g.temp_var(i32('fold_stop')) + fold_init = g.temp_var(res_type_mtyp_f('fold_init')) + fold_func = g.temp_var(i32('fold_func')) + + with g.block(params=['i32', res_type_mtyp, 'i32'], result=res_type_mtyp, comment=f'foldl a={sa_type.name} a*={sa_len.value} b={res_type.name}'): + # Stack: [[a, b, b], b, t(a), b] + # t(a) == sa_type[sa_len] + # Stack: [fn*, b, sa*] + + # adr = {address of what's currently on stack} + # Stack: [fn*, b, sa*] -> [fn*, b] + g.local.set(fold_adr) + # Stack: [fn*, b] -> [fn*] + g.local.set(fold_init) + # Stack: [fn*] -> [] + g.local.set(fold_func) + + # stop = adr + ar_len * sa_type_alloc_size + # Stack: [] + g.nop(comment='Calculate address at which to stop looping') + g.local.get(fold_adr) + g.i32.const(sa_len.value * sa_type_alloc_size) + g.i32.add() + g.local.set(fold_stop) + + # Stack: [] -> [b] + g.nop(comment='Get the init value and first array value as starting point') + g.local.get(fold_init) + # Stack: [b] -> [b, *a] + g.local.get(fold_adr) + # Stack: [b] -> [b, a] + g.add_statement(f'{sa_type_mtyp}.load') + g.nop(comment='Call the fold function') + g.local.get(fold_func) + g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})') + + # adr = adr + sa_type_alloc_size + # Stack: [b] -> [b] + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_alloc_size) + g.i32.add() + g.local.set(fold_adr) + + if sa_len.value > 1: + with g.loop(params=[sa_type_mtyp], result=sa_type_mtyp): + # Stack: [b] -> [b, a] + g.nop(comment='Add array value') + g.local.get(fold_adr) + g.add_statement(f'{sa_type_mtyp}.load') + + # Stack [b, a] -> b + g.nop(comment='Call the fold function') + g.local.get(fold_func) + g.add_statement(f'call_indirect (param {res_type_mtyp} {sa_type_mtyp}) (result {res_type_mtyp})') + + # adr = adr + sa_type_alloc_size + # Stack: [fold] -> [fold] + g.nop(comment='Calculate address of the next value') + g.local.get(fold_adr) + g.i32.const(sa_type_alloc_size) + g.i32.add() + g.local.tee(fold_adr) + + # loop if adr < stop + g.nop(comment='Check if address exceeds array bounds') + g.local.get(fold_stop) + g.i32.lt_u() + g.br_if(0) + # else: just one value, don't need to loop + + # Stack: [b] diff --git a/phasm/type3/constraints.py b/phasm/type3/constraints.py index 4a9541b..4372d50 100644 --- a/phasm/type3/constraints.py +++ b/phasm/type3/constraints.py @@ -6,6 +6,7 @@ These need to be resolved before the program can be compiled. from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from .. import ourlang, prelude +from .functions import FunctionArgument, TypeVariable from .placeholders import PlaceholderForType, Type3OrPlaceholder from .routers import NoRouteForTypeException, TypeApplicationRouter from .typeclasses import Type3Class @@ -158,7 +159,7 @@ class SameTypeConstraint(ConstraintBase): return ( ' == '.join('{t' + str(idx) + '}' for idx in range(len(self.type_list))), { - 't' + str(idx): typ + 't' + str(idx): typ.name if isinstance(typ, ourlang.Function) else typ for idx, typ in enumerate(self.type_list) }, ) @@ -181,7 +182,7 @@ class SameTypeArgumentConstraint(ConstraintBase): self.arg_var = arg_var def check(self) -> CheckResult: - if self.tc_var.resolve_as is None or self.arg_var.resolve_as is None: + if self.tc_var.resolve_as is None: return RequireTypeSubstitutes() tc_typ = self.tc_var.resolve_as @@ -201,13 +202,84 @@ class SameTypeArgumentConstraint(ConstraintBase): # FIXME: This feels sketchy. Shouldn't the type variable # have the exact same number as arguments? if isinstance(tc_typ.application, TypeApplication_TypeInt): - if tc_typ.application.arguments[0] == arg_typ: - return None - - return Error(f'{tc_typ.application.arguments[0]:s} must be {arg_typ:s} instead') + return [SameTypeConstraint( + tc_typ.application.arguments[0], + self.arg_var, + comment=self.comment, + )] raise NotImplementedError(tc_typ, arg_typ) + def human_readable(self) -> HumanReadableRet: + return ( + '{tc_var}` == {arg_var}', + { + 'tc_var': self.tc_var if self.tc_var.resolve_as is None else self.tc_var, + 'arg_var': self.arg_var if self.arg_var.resolve_as is None else self.arg_var, + }, + ) + +class SameFunctionArgumentConstraint(ConstraintBase): + __slots__ = ('type3', 'func_arg', 'type_var_map', ) + + type3: PlaceholderForType + func_arg: FunctionArgument + type_var_map: dict[TypeVariable, PlaceholderForType] + + def __init__(self, type3: PlaceholderForType, func_arg: FunctionArgument, type_var_map: dict[TypeVariable, PlaceholderForType], *, comment: str) -> None: + super().__init__(comment=comment) + + self.type3 = type3 + self.func_arg = func_arg + self.type_var_map = type_var_map + + def check(self) -> CheckResult: + if self.type3.resolve_as is None: + return RequireTypeSubstitutes() + + typ = self.type3.resolve_as + + if isinstance(typ.application, TypeApplication_Nullary): + return Error(f'{typ:s} must be a function instead') + + if not isinstance(typ.application, TypeApplication_TypeStar): + return Error(f'{typ:s} must be a function instead') + + type_var_map = { + x: y.resolve_as + for x, y in self.type_var_map.items() + if y.resolve_as is not None + } + + exp_type_arg_list = [ + tv if isinstance(tv, Type3) else type_var_map[tv] + for tv in self.func_arg.args + if isinstance(tv, Type3) or tv in type_var_map + ] + + print('self.func_arg.args', self.func_arg.args) + print('exp_type_arg_list', exp_type_arg_list) + + if len(exp_type_arg_list) != len(self.func_arg.args): + return RequireTypeSubstitutes() + + return [ + SameTypeConstraint( + typ, + prelude.function(*exp_type_arg_list), + comment=self.comment, + ) + ] + + def human_readable(self) -> HumanReadableRet: + return ( + '{type3} == {func_arg}', + { + 'type3': self.type3, + 'func_arg': self.func_arg.name, + }, + ) + class TupleMatchConstraint(ConstraintBase): __slots__ = ('exp_type', 'args', ) @@ -264,14 +336,10 @@ class MustImplementTypeClassConstraint(ConstraintBase): __slots__ = ('context', 'type_class3', 'types', ) context: Context - type_class3: Union[str, Type3Class] + type_class3: Type3Class types: list[Type3OrPlaceholder] - DATA = { - 'bytes': {'Foldable'}, - } - - def __init__(self, context: Context, type_class3: Union[str, Type3Class], typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None: + def __init__(self, context: Context, type_class3: Type3Class, typ_list: list[Type3OrPlaceholder], comment: Optional[str] = None) -> None: super().__init__(comment=comment) self.context = context @@ -299,13 +367,9 @@ class MustImplementTypeClassConstraint(ConstraintBase): assert len(typ_list) == len(self.types) - if isinstance(self.type_class3, Type3Class): - key = (self.type_class3, tuple(typ_list), ) - if key in self.context.type_class_instances_existing: - return None - else: - if self.type_class3 in self.__class__.DATA.get(typ_list[0].name, set()): - return None + key = (self.type_class3, tuple(typ_list), ) + if key in self.context.type_class_instances_existing: + return None typ_cls_name = self.type_class3 if isinstance(self.type_class3, str) else self.type_class3.name typ_name_list = ' '.join(x.name for x in typ_list) diff --git a/phasm/type3/constraintsgenerator.py b/phasm/type3/constraintsgenerator.py index 01c1549..7f174d5 100644 --- a/phasm/type3/constraintsgenerator.py +++ b/phasm/type3/constraintsgenerator.py @@ -12,12 +12,14 @@ from .constraints import ( Context, LiteralFitsConstraint, MustImplementTypeClassConstraint, + SameFunctionArgumentConstraint, SameTypeArgumentConstraint, SameTypeConstraint, TupleMatchConstraint, ) from .functions import ( Constraint_TypeClassInstanceExists, + FunctionArgument, FunctionSignature, TypeVariable, TypeVariableApplication_Unary, @@ -111,6 +113,33 @@ def _expression_function_call( raise NotImplementedError(constraint) + func_var_map = { + x: PlaceholderForType([]) + for x in signature.args + if isinstance(x, FunctionArgument) + } + + # If some of the function arguments are functions, + # we need to deal with those separately. + for sig_arg in signature.args: + if not isinstance(sig_arg, FunctionArgument): + continue + + # Ensure that for all type variables in the function + # there are also type variables available + for func_arg in sig_arg.args: + if isinstance(func_arg, Type3): + continue + + type_var_map.setdefault(func_arg, PlaceholderForType([])) + + yield SameFunctionArgumentConstraint( + func_var_map[sig_arg], + sig_arg, + type_var_map, + comment=f'Ensure `{sig_arg.name}` matches in {signature}', + ) + # If some of the function arguments are type constructors, # we need to deal with those separately. # That is, given `foo :: t a -> a` we need to ensure @@ -120,6 +149,9 @@ def _expression_function_call( # Not a type variable at all continue + if isinstance(sig_arg, FunctionArgument): + continue + if sig_arg.application.constructor is None: # Not a type variable for a type constructor continue @@ -150,9 +182,20 @@ def _expression_function_call( yield SameTypeConstraint(sig_part, arg_placeholders[arg_expr], comment=comment) continue + if isinstance(sig_part, FunctionArgument): + yield SameTypeConstraint(func_var_map[sig_part], arg_placeholders[arg_expr], comment=comment) + continue + raise NotImplementedError(sig_part) return +def expression_function_reference(ctx: Context, inp: ourlang.FunctionReference, phft: PlaceholderForType) -> ConstraintGenerator: + yield SameTypeConstraint( + prelude.function(*(x.type3 for x in inp.function.posonlyargs), inp.function.returns_type3), + phft, + comment=f'typeOf("{inp.function.name}") == typeOf({inp.function.name})', + ) + def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) -> ConstraintGenerator: if isinstance(inp, ourlang.Constant): yield from constant(ctx, inp, phft) @@ -171,6 +214,10 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) yield from expression_function_call(ctx, inp, phft) return + if isinstance(inp, ourlang.FunctionReference): + yield from expression_function_reference(ctx, inp, phft) + return + if isinstance(inp, ourlang.TupleInstantiation): r_type = [] for arg in inp.elements: @@ -209,19 +256,6 @@ def expression(ctx: Context, inp: ourlang.Expression, phft: PlaceholderForType) comment=f'The type of a struct member reference is the same as the type of struct member {inp.struct_type3.name}.{inp.member}') return - if isinstance(inp, ourlang.Fold): - base_phft = PlaceholderForType([inp.base]) - iter_phft = PlaceholderForType([inp.iter]) - - yield from expression(ctx, inp.base, base_phft) - yield from expression(ctx, inp.iter, iter_phft) - - yield SameTypeConstraint(inp.func.posonlyargs[0].type3, inp.func.returns_type3, base_phft, phft, - comment='foldl :: Foldable t => (b -> a -> b) -> b -> t a -> b') - yield MustImplementTypeClassConstraint(ctx, 'Foldable', [iter_phft]) - - return - raise NotImplementedError(expression, inp) def statement_return(ctx: Context, fun: ourlang.Function, inp: ourlang.StatementReturn) -> ConstraintGenerator: diff --git a/phasm/type3/functions.py b/phasm/type3/functions.py index 1b91948..a49f80b 100644 --- a/phasm/type3/functions.py +++ b/phasm/type3/functions.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Any, Hashable, Iterable, List, Union +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Hashable, Iterable, List if TYPE_CHECKING: from .typeclasses import Type3Class @@ -155,15 +157,29 @@ class TypeVariableContext: def __repr__(self) -> str: return f'TypeVariableContext({self.constraints!r})' +class FunctionArgument: + __slots__ = ('args', 'name', ) + + args: list[Type3 | TypeVariable] + name: str + + def __init__(self, args: list[Type3 | TypeVariable]) -> None: + self.args = args + + self.name = '(' + ' -> '.join(x.name for x in args) + ')' + class FunctionSignature: __slots__ = ('context', 'args', ) context: TypeVariableContext - args: List[Union['Type3', TypeVariable]] + args: List[Type3 | TypeVariable | FunctionArgument] - def __init__(self, context: TypeVariableContext, args: Iterable[Union['Type3', TypeVariable]]) -> None: + def __init__(self, context: TypeVariableContext, args: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]) -> None: self.context = context.__copy__() - self.args = list(args) + self.args = list( + FunctionArgument(x) if isinstance(x, list) else x + for x in args + ) def __str__(self) -> str: return str(self.context) + ' -> '.join(x.name for x in self.args) diff --git a/phasm/type3/routers.py b/phasm/type3/routers.py index c592dac..97d33ad 100644 --- a/phasm/type3/routers.py +++ b/phasm/type3/routers.py @@ -3,6 +3,7 @@ from typing import Any, Callable from .functions import ( TypeConstructorVariable, TypeVariable, + TypeVariableApplication_Nullary, TypeVariableApplication_Unary, ) from .typeclasses import Type3ClassArgs @@ -54,7 +55,10 @@ class TypeApplicationRouter[S, R]: raise NoRouteForTypeException(arg0, typ) -TypeVariableLookup = dict[TypeVariable, tuple[KindArgument, ...]] +TypeVariableLookup = tuple[ + dict[TypeVariable, KindArgument], + dict[TypeConstructorVariable, TypeConstructor_Base[Any]], +] class TypeClassArgsRouter[S, R]: """ @@ -89,11 +93,12 @@ class TypeClassArgsRouter[S, R]: def __call__(self, arg0: S, tv_map: dict[TypeVariable, Type3]) -> R: key: list[Type3 | TypeConstructor_Base[Any]] = [] - arguments: TypeVariableLookup = {} + arguments: TypeVariableLookup = (dict(tv_map), {}, ) for tc_arg in self.args: if isinstance(tc_arg, TypeVariable): key.append(tv_map[tc_arg]) + arguments[0][tc_arg] = tv_map[tc_arg] continue for tvar, typ in tv_map.items(): @@ -102,16 +107,24 @@ class TypeClassArgsRouter[S, R]: continue key.append(typ.application.constructor) + arguments[1][tc_arg] = typ.application.constructor if isinstance(tvar.application, TypeVariableApplication_Unary): # FIXME: This feels sketchy. Shouldn't the type variable # have the exact same number as arguments? if isinstance(typ.application, TypeApplication_TypeInt): - arguments[tvar.application.arguments] = typ.application.arguments + sa_type, sa_len = typ.application.arguments + sa_type_tv = tvar.application.arguments + sa_len_tv = TypeVariable(sa_type_tv.name + '*', TypeVariableApplication_Nullary(None, None)) + + arguments[0][sa_type_tv] = sa_type + arguments[0][sa_len_tv] = sa_len continue raise NotImplementedError(tvar.application, typ.application) + continue + t_helper = self.data.get(tuple(key)) if t_helper is not None: return t_helper(arg0, arguments) diff --git a/phasm/type3/typeclasses.py b/phasm/type3/typeclasses.py index 83d87be..02c6ad0 100644 --- a/phasm/type3/typeclasses.py +++ b/phasm/type3/typeclasses.py @@ -1,4 +1,4 @@ -from typing import Dict, Iterable, List, Mapping, Optional, Union +from typing import Dict, Iterable, List, Mapping, Optional from .functions import ( Constraint_TypeClassInstanceExists, @@ -42,8 +42,8 @@ class Type3Class: self, name: str, args: Type3ClassArgs, - methods: Mapping[str, Iterable[Union[Type3, TypeVariable]]], - operators: Mapping[str, Iterable[Union[Type3, TypeVariable]]], + methods: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]], + operators: Mapping[str, Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]]], inherited_classes: Optional[List['Type3Class']] = None, additional_context: Optional[Mapping[str, Iterable[ConstraintBase]]] = None, ) -> None: @@ -71,19 +71,23 @@ class Type3Class: return self.name def _create_signature( - method_arg_list: Iterable[Type3 | TypeVariable], + method_arg_list: Iterable[Type3 | TypeVariable | list[Type3 | TypeVariable]], type_class3: Type3Class, ) -> FunctionSignature: context = TypeVariableContext() if not isinstance(type_class3.args[0], TypeConstructorVariable): context.constraints.append(Constraint_TypeClassInstanceExists(type_class3, type_class3.args)) - signature_args: list[Type3 | TypeVariable] = [] + signature_args: list[Type3 | TypeVariable | list[Type3 | TypeVariable]] = [] for method_arg in method_arg_list: if isinstance(method_arg, Type3): signature_args.append(method_arg) continue + if isinstance(method_arg, list): + signature_args.append(method_arg) + continue + if isinstance(method_arg, TypeVariable): type_constructor = method_arg.application.constructor if type_constructor is None: diff --git a/phasm/type3/types.py b/phasm/type3/types.py index 7a1b835..a949fb5 100644 --- a/phasm/type3/types.py +++ b/phasm/type3/types.py @@ -239,6 +239,10 @@ class TypeConstructor_Tuple(TypeConstructor_TypeStar): def make_name(self, key: Tuple[Type3, ...]) -> str: return '(' + ', '.join(x.name for x in key) + ', )' +class TypeConstructor_Function(TypeConstructor_TypeStar): + def make_name(self, key: Tuple[Type3, ...]) -> str: + return '(' + ' -> '.join(x.name for x in key) + ')' + class TypeConstructor_Struct(TypeConstructor_Base[tuple[tuple[str, Type3], ...]]): """ Constructs struct types diff --git a/phasm/wasm.py b/phasm/wasm.py index c473e63..9f317fd 100644 --- a/phasm/wasm.py +++ b/phasm/wasm.py @@ -187,14 +187,17 @@ class Module(WatSerializable): def __init__(self) -> None: self.imports: List[Import] = [] self.functions: List[Function] = [] + self.table: dict[int, str] = {} self.memory = ModuleMemory() def to_wat(self) -> str: """ Generates the text version """ - return '(module\n {}\n {}\n {})\n'.format( + return '(module\n {}\n {}\n {}\n {}\n {})\n'.format( '\n '.join(x.to_wat() for x in self.imports), + f'(table {len(self.table)} funcref)', + '\n '.join(f'(elem (i32.const {k}) ${v})' for k, v in self.table.items()), self.memory.to_wat(), '\n '.join(x.to_wat() for x in self.functions), ) diff --git a/phasm/wasmgenerator.py b/phasm/wasmgenerator.py index 9bee38e..faad179 100644 --- a/phasm/wasmgenerator.py +++ b/phasm/wasmgenerator.py @@ -170,11 +170,12 @@ class Generator_Local: self.generator.add_statement('local.tee', variable.name_ref, comment=comment) class GeneratorBlock: - def __init__(self, generator: 'Generator', name: str, params: Iterable[str] = (), result: str | None = None) -> None: + def __init__(self, generator: 'Generator', name: str, params: Iterable[str] = (), result: str | None = None, comment: str | None = None) -> None: self.generator = generator self.name = name self.params = params self.result = result + self.comment = comment def __enter__(self) -> None: stmt = self.name @@ -186,7 +187,7 @@ class GeneratorBlock: if self.result: stmt = f'{stmt} (result {self.result})' - self.generator.add_statement(stmt) + self.generator.add_statement(stmt, comment=self.comment) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if not exc_type: @@ -208,7 +209,7 @@ class Generator: # 2.4.5 Control Instructions self.nop = functools.partial(self.add_statement, 'nop') self.unreachable = functools.partial(self.add_statement, 'unreachable') - # block + self.block = functools.partial(GeneratorBlock, self, 'block') self.loop = functools.partial(GeneratorBlock, self, 'loop') self.if_ = functools.partial(GeneratorBlock, self, 'if') # br diff --git a/tests/integration/test_lang/test_builtins.py b/tests/integration/test_lang/test_builtins.py deleted file mode 100644 index 13400fd..0000000 --- a/tests/integration/test_lang/test_builtins.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -from ..helpers import Suite - - -@pytest.mark.integration_test -def test_foldl_1(): - code_py = """ -def u8_or(l: u8, r: u8) -> u8: - return l | r - -@exported -def testEntry(b: bytes) -> u8: - return foldl(u8_or, 128, b) -""" - suite = Suite(code_py) - - result = suite.run_code(b'') - assert 128 == result.returned_value - - result = suite.run_code(b'\x80') - assert 128 == result.returned_value - - result = suite.run_code(b'\x80\x40') - assert 192 == result.returned_value - - result = suite.run_code(b'\x80\x40\x20\x10') - assert 240 == result.returned_value - - result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01') - assert 255 == result.returned_value - -@pytest.mark.integration_test -def test_foldl_2(): - code_py = """ -def xor(l: u8, r: u8) -> u8: - return l ^ r - -@exported -def testEntry(a: bytes, b: bytes) -> u8: - return foldl(xor, 0, a) ^ foldl(xor, 0, b) -""" - suite = Suite(code_py) - - result = suite.run_code(b'\x55\x0F', b'\x33\x80') - assert 233 == result.returned_value - -@pytest.mark.integration_test -def test_foldl_3(): - code_py = """ -def xor(l: u32, r: u8) -> u32: - return l ^ extend(r) - -@exported -def testEntry(a: bytes) -> u32: - return foldl(xor, 0, a) -""" - suite = Suite(code_py) - - result = suite.run_code(b'\x55\x0F\x33\x80') - assert 233 == result.returned_value diff --git a/tests/integration/test_lang/test_foldable.py b/tests/integration/test_lang/test_foldable.py index c03d6e7..e86ecdf 100644 --- a/tests/integration/test_lang/test_foldable.py +++ b/tests/integration/test_lang/test_foldable.py @@ -36,6 +36,118 @@ def testEntry(x: Foo[4]) -> Foo: with pytest.raises(Type3Exception, match='Missing type class instantation: NatNum Foo'): Suite(code_py).run_code() +@pytest.mark.integration_test +@pytest.mark.parametrize('length', [1, 5, 13]) +@pytest.mark.parametrize('direction', ['foldl', 'foldr']) +def test_foldable_foldl_foldr_size(direction, length): + code_py = f""" +def u64_add(l: u64, r: u64) -> u64: + return l + r + +@exported +def testEntry(b: u64[{length}]) -> u64: + return {direction}(u64_add, 100, b) +""" + suite = Suite(code_py) + + in_put = tuple(range(1, length + 1)) + + result = suite.run_code(in_put) + assert (100 + sum(in_put)) == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('direction', ['foldr']) +def test_foldable_foldl_foldr_compounded_type(direction): + code_py = f""" +def combine_foldl(b: u64, a: (u32, u32, )) -> u64: + return extend(a[0] * a[1]) + b + +def combine_foldr(a: (u32, u32, ), b: u64) -> u64: + return extend(a[0] * a[1]) + b + +@exported +def testEntry(b: (u32, u32)[3]) -> u64: + return {direction}(combine_{direction}, 10000, b) +""" + suite = Suite(code_py) + + result = suite.run_code(((2, 5), (25, 4), (125, 8))) + assert 11110 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('direction, exp_result', [ + ('foldl', -55, ), + ('foldr', -5, ), +]) +def test_foldable_foldl_foldr_result(direction, exp_result): + # See https://stackoverflow.com/a/13280185 + code_py = f""" +def i32_sub(l: i32, r: i32) -> i32: + return l - r + +@exported +def testEntry(b: i32[10]) -> i32: + return {direction}(i32_sub, 0, b) +""" + suite = Suite(code_py) + + result = suite.run_code(tuple(range(1, 11))) + assert exp_result == result.returned_value + +@pytest.mark.integration_test +def test_foldable_foldl_bytes(): + code_py = """ +def u8_or(l: u8, r: u8) -> u8: + return l | r + +@exported +def testEntry(b: bytes) -> u8: + return foldl(u8_or, 128, b) +""" + suite = Suite(code_py) + + result = suite.run_code(b'') + assert 128 == result.returned_value + + result = suite.run_code(b'\x80') + assert 128 == result.returned_value + + result = suite.run_code(b'\x80\x40') + assert 192 == result.returned_value + + result = suite.run_code(b'\x80\x40\x20\x10') + assert 240 == result.returned_value + + result = suite.run_code(b'\x80\x40\x20\x10\x08\x04\x02\x01') + assert 255 == result.returned_value + +@pytest.mark.integration_test +@pytest.mark.parametrize('in_typ', ['i8', 'i8[3]']) +def test_foldable_argument_must_be_a_function(in_typ): + code_py = f""" +@exported +def testEntry(x: {in_typ}, y: i32, z: i64[3]) -> i32: + return foldl(x, y, z) +""" + + r_in_typ = in_typ.replace('[', '\\[').replace(']', '\\]') + + with pytest.raises(Type3Exception, match=f'{r_in_typ} must be a function instead'): + Suite(code_py).run_code() + +@pytest.mark.integration_test +def test_foldable_argument_must_be_right_function(): + code_py = """ +def foo(l: i32, r: i64) -> i64: + return extend(l) + r + +@exported +def testEntry(i: i64, l: i64[3]) -> i64: + return foldr(foo, i, l) +""" + + with pytest.raises(Type3Exception, match=r'\(i64 -> i64 -> i64\) must be \(i32 -> i64 -> i64\) instead'): + Suite(code_py).run_code() @pytest.mark.integration_test def test_foldable_invalid_return_type(): @@ -45,7 +157,7 @@ def testEntry(x: i32[5]) -> f64: return sum(x) """ - with pytest.raises(Type3Exception, match='i32 must be f64 instead'): + with pytest.raises(Type3Exception, match='f64 must be i32 instead'): Suite(code_py).run_code((4, 5, 6, 7, 8, )) @pytest.mark.integration_test