diff --git a/phasm/__main__.py b/phasm/__main__.py index 8afa522..a4e6a31 100644 --- a/phasm/__main__.py +++ b/phasm/__main__.py @@ -5,6 +5,7 @@ Functions for using this module from CLI import sys from .compiler import phasm_compile +from .optimise.removeunusedfuncs import removeunusedfuncs from .parser import phasm_parse from .type3.entry import phasm_type3 @@ -20,6 +21,7 @@ def main(source: str, sink: str) -> int: our_module = phasm_parse(code_py) phasm_type3(our_module, verbose=False) wasm_module = phasm_compile(our_module) + removeunusedfuncs(wasm_module) code_wat = wasm_module.to_wat() with open(sink, 'w') as fil: diff --git a/phasm/compiler.py b/phasm/compiler.py index 79beba3..e2df411 100644 --- a/phasm/compiler.py +++ b/phasm/compiler.py @@ -325,18 +325,17 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression) assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function) params = [ - type3(x).to_wat() + type3(x) for x in inp.function.type3.application.arguments ] result = params.pop() - params_str = ' '.join(params) wgn.add_statement('local.get', '${}'.format(inp.function.name)) - wgn.add_statement(f'call_indirect (param {params_str}) (result {result})') + wgn.call_indirect(params=params, result=result) return - wgn.add_statement('call', '${}'.format(inp.function.name)) + wgn.call(inp.function.name) return if isinstance(inp, ourlang.FunctionReference): diff --git a/phasm/optimise/__init__.py b/phasm/optimise/__init__.py new file mode 100644 index 0000000..8bcc333 --- /dev/null +++ b/phasm/optimise/__init__.py @@ -0,0 +1,5 @@ +# This folder contains some basic optimisations +# If you really want to optimise your stuff, +# check out https://github.com/WebAssembly/binaryen +# Its wasm-opt tool can be run on the wat or wasm files. +# And they have a ton of options to optimize. 
diff --git a/phasm/optimise/removeunusedfuncs.py b/phasm/optimise/removeunusedfuncs.py new file mode 100644 index 0000000..e08c3ff --- /dev/null +++ b/phasm/optimise/removeunusedfuncs.py @@ -0,0 +1,49 @@ +from phasm import wasm + + +def removeunusedfuncs(inp: wasm.Module) -> None: + """ + Removes functions that aren't exported, aren't in the table, and aren't + called (directly or indirectly) by any of those functions + """ + # First make a handy lookup table + function_map = { + x.name: x + for x in inp.functions + } + + # Keep a set of all functions to retain + retain_functions = set() + + # Keep a queue (stack) of the functions we need to check + # The exported functions as well as the tabled functions are the + # first we know of to keep + to_check_functions = [ + x + for x in inp.functions + if x.exported_name is not None + ] + [ + function_map[x] + for x in inp.table.values() + ] + + while to_check_functions: + # Take the next function. Anything that reaches the stack must be retained + to_check_func = to_check_functions.pop() + retain_functions.add(to_check_func) + + # Queue all functions called by this retained function + to_check_functions.extend( + func + for stmt in to_check_func.statements + if isinstance(stmt, wasm.StatementCall) + # The function_map may not contain the named function if it is an import + if (func := function_map.get(stmt.func_name)) is not None + and func not in retain_functions + ) + + inp.functions = [ + func + for func in inp.functions + if func in retain_functions + ] diff --git a/phasm/stdlib/alloc.py b/phasm/stdlib/alloc.py index 4473d13..43f7675 100644 --- a/phasm/stdlib/alloc.py +++ b/phasm/stdlib/alloc.py @@ -15,7 +15,7 @@ UNALLOC_PTR = ADR_UNALLOC_PTR + 4 # For memory initialization see phasm.compiler.module_data -@func_wrapper(exported=False) +@func_wrapper() def __find_free_block__(g: Generator, alloc_size: i32) -> i32: # Find out if we've freed any blocks at all so far g.i32.const(ADR_FREE_BLOCK_PTR) @@ -32,7 +32,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32: 
return i32('return') # To satisfy mypy -@func_wrapper() +@func_wrapper(exported=True) def __alloc__(g: Generator, alloc_size: i32) -> i32: result = i32('result') diff --git a/phasm/stdlib/types.py b/phasm/stdlib/types.py index f8961d1..38c5c98 100644 --- a/phasm/stdlib/types.py +++ b/phasm/stdlib/types.py @@ -59,7 +59,7 @@ TYPE_INFO_MAP: Mapping[str, TypeInfo] = { # not memory pointers but table addresses instead. TYPE_INFO_CONSTRUCTED = TypeInfo('t a', WasmTypeInt32, 'i32.load', 'i32.store', 4) -@func_wrapper() +@func_wrapper(exported=True) def __alloc_bytes__(g: Generator, length: i32) -> i32: """ Allocates room for a bytes instance, but does not write @@ -604,35 +604,35 @@ def f64_eq_not_equals(g: Generator, tv_map: TypeVariableLookup) -> None: def u8_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_ord_min__') + g.call('stdlib.types.__u32_ord_min__') def u16_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_ord_min__') + g.call('stdlib.types.__u32_ord_min__') def u32_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_ord_min__') + g.call('stdlib.types.__u32_ord_min__') def u64_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u64_ord_min__') + g.call('stdlib.types.__u64_ord_min__') def i8_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i32_ord_min__') + g.call('stdlib.types.__i32_ord_min__') def i16_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i32_ord_min__') + g.call('stdlib.types.__i32_ord_min__') def i32_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i32_ord_min__') + g.call('stdlib.types.__i32_ord_min__') def i64_ord_min(g: 
Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i64_ord_min__') + g.call('stdlib.types.__i64_ord_min__') def f32_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map @@ -644,35 +644,35 @@ def f64_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None: def u8_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_ord_max__') + g.call('stdlib.types.__u32_ord_max__') def u16_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_ord_max__') + g.call('stdlib.types.__u32_ord_max__') def u32_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_ord_max__') + g.call('stdlib.types.__u32_ord_max__') def u64_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u64_ord_max__') + g.call('stdlib.types.__u64_ord_max__') def i8_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i32_ord_max__') + g.call('stdlib.types.__i32_ord_max__') def i16_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i32_ord_max__') + g.call('stdlib.types.__i32_ord_max__') def i32_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i32_ord_max__') + g.call('stdlib.types.__i32_ord_max__') def i64_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i64_ord_max__') + g.call('stdlib.types.__i64_ord_max__') def f32_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map @@ -886,11 +886,11 @@ def u64_bits_logical_shift_right(g: Generator, tv_map: TypeVariableLookup) -> No def u8_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - 
g.add_statement('call $stdlib.types.__u8_rotl__') + g.call('stdlib.types.__u8_rotl__') def u16_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u16_rotl__') + g.call('stdlib.types.__u16_rotl__') def u32_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map @@ -903,11 +903,11 @@ def u64_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None: def u8_bits_rotate_right(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u8_rotr__') + g.call('stdlib.types.__u8_rotr__') def u16_bits_rotate_right(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u16_rotr__') + g.call('stdlib.types.__u16_rotr__') def u32_bits_rotate_right(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map @@ -1150,13 +1150,13 @@ def i64_natnum_arithmic_shift_left(g: Generator, tv_map: TypeVariableLookup) -> def f32_natnum_arithmic_shift_left(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_pow2__') + g.call('stdlib.types.__u32_pow2__') g.f32.convert_i32_u() g.f32.mul() def f64_natnum_arithmic_shift_left(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_pow2__') + g.call('stdlib.types.__u32_pow2__') g.f64.convert_i32_u() g.f64.mul() @@ -1180,13 +1180,13 @@ def i64_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) -> def f32_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_pow2__') + g.call('stdlib.types.__u32_pow2__') g.f32.convert_i32_u() g.f32.div() def f64_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__u32_pow2__') + g.call('stdlib.types.__u32_pow2__') g.f64.convert_i32_u() g.f64.div() @@ 
-1195,11 +1195,11 @@ def f64_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) -> def i32_intnum_abs(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i32_intnum_abs__') + g.call('stdlib.types.__i32_intnum_abs__') def i64_intnum_abs(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map - g.add_statement('call $stdlib.types.__i64_intnum_abs__') + g.call('stdlib.types.__i64_intnum_abs__') def f32_intnum_abs(g: Generator, tv_map: TypeVariableLookup) -> None: del tv_map diff --git a/phasm/wasm.py b/phasm/wasm.py index 2750ee3..0af277c 100644 --- a/phasm/wasm.py +++ b/phasm/wasm.py @@ -105,11 +105,17 @@ class Import(WatSerializable): else f' (result {self.result.to_wat()})' ) -class Statement(WatSerializable): +class StatementBase(WatSerializable): + pass + +class Statement(StatementBase ): """ Represents a Web Assembly statement """ def __init__(self, name: str, *args: str, comment: Optional[str] = None): + assert ' ' not in name, 'Please pass argument separately' + assert name != 'call', 'Please use StatementCall' + self.name = name self.args = args self.comment = comment @@ -120,6 +126,16 @@ class Statement(WatSerializable): return f'{self.name} {args}{comment}' +class StatementCall(StatementBase ): + def __init__(self, func_name: str, comment: str | None = None): + self.func_name = func_name + self.comment = comment + + def to_wat(self) -> str: + comment = f' ;; {self.comment}' if self.comment else '' + + return f'call ${self.func_name} {comment}' + class Function(WatSerializable): """ Represents a Web Assembly function @@ -131,7 +147,7 @@ class Function(WatSerializable): params: Iterable[Param], locals_: Iterable[Param], result: WasmType, - statements: Iterable[Statement], + statements: Iterable[StatementBase], ) -> None: self.name = name self.exported_name = exported_name diff --git a/phasm/wasmgenerator.py b/phasm/wasmgenerator.py index 2560e0a..4337fb1 100644 --- 
a/phasm/wasmgenerator.py +++ b/phasm/wasmgenerator.py @@ -196,16 +196,17 @@ class GeneratorBlock: def __enter__(self) -> None: stmt = self.name + args: list[str] = [] if self.params: - stmt = f'{stmt} ' + ' '.join( + args.extend( f'(param {typ})' if isinstance(typ, str) else f'(param {typ().to_wat()})' for typ in self.params ) if self.result: result = self.result if isinstance(self.result, str) else self.result().to_wat() - stmt = f'{stmt} (result {result})' + args.append(f'(result {result})') - self.generator.add_statement(stmt, comment=self.comment) + self.generator.add_statement(stmt, *args, comment=self.comment) def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if not exc_type: @@ -213,7 +214,7 @@ class GeneratorBlock: class Generator: def __init__(self) -> None: - self.statements: List[wasm.Statement] = [] + self.statements: List[wasm.StatementBase] = [] self.locals: Dict[str, VarType_Base] = {} self.i32 = Generator_i32(self) @@ -238,7 +239,7 @@ class Generator: # br_table self.return_ = functools.partial(self.add_statement, 'return') # call - see below - # call_indirect + # call_indirect - see below def br_if(self, idx: int) -> None: self.add_statement('br_if', f'{idx}') @@ -247,17 +248,19 @@ class Generator: if isinstance(function, wasm.Function): function = function.name - self.add_statement('call', f'${function}') + self.statements.append(wasm.StatementCall(function)) - def call_indirect(self, params: Iterable[Type[wasm.WasmType]], result: Type[wasm.WasmType]) -> None: + def call_indirect(self, params: Iterable[Type[wasm.WasmType] | wasm.WasmType], result: Type[wasm.WasmType] | wasm.WasmType) -> None: param_str = ' '.join( - x().to_wat() + (x() if isinstance(x, type) else x).to_wat() for x in params ) - result_str = result().to_wat() + if isinstance(result, type): + result = result() + result_str = result.to_wat() - self.add_statement(f'call_indirect (param {param_str}) (result {result_str})') + self.add_statement('call_indirect', 
f'(param {param_str})', f'(result {result_str})') def add_statement(self, name: str, *args: str, comment: Optional[str] = None) -> None: self.statements.append(wasm.Statement(name, *args, comment=comment)) @@ -297,7 +300,7 @@ class Generator: def temp_var_u8(self, infix: str) -> VarType_u8: return self.temp_var(VarType_u8(infix)) -def func_wrapper(exported: bool = True) -> Callable[[Any], wasm.Function]: +def func_wrapper(exported: bool = False) -> Callable[[Any], wasm.Function]: """ This wrapper will execute the function and return a wasm Function with the generated Statements diff --git a/tests/integration/helpers.py b/tests/integration/helpers.py index 907aa1c..192ee21 100644 --- a/tests/integration/helpers.py +++ b/tests/integration/helpers.py @@ -100,6 +100,7 @@ class Suite: runner.parse(verbose=verbose) runner.compile_ast() + runner.optimise_wasm_ast() runner.compile_wat() if verbose: diff --git a/tests/integration/runners.py b/tests/integration/runners.py index 3957233..60835bd 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -8,6 +8,7 @@ import wasmtime from phasm import ourlang, wasm from phasm.compiler import phasm_compile +from phasm.optimise.removeunusedfuncs import removeunusedfuncs from phasm.parser import phasm_parse from phasm.type3.entry import phasm_type3 @@ -45,6 +46,12 @@ class RunnerBase: """ self.wasm_ast = phasm_compile(self.phasm_ast) + def optimise_wasm_ast(self) -> None: + """ + Optimises the WebAssembly AST + """ + removeunusedfuncs(self.wasm_ast) + def compile_wat(self) -> None: """ Compiles the WebAssembly AST into WebAssembly Assembly code