Optimise: Remove unused functions

By default, we add a lot of built-in functions that may
never get called.

This commit adds a simple reachability graph algorithm
to remove functions that can't be called from outside.

Also, unmarks a lot of functions as being exported. It
was the default to export - now it's the default to not
export.

Also, some general cleanup to the wasm statement calls.
This commit is contained in:
Johan B.W. de Vries 2025-05-25 16:39:25 +02:00
parent 84e7c42ea4
commit d97be81828
10 changed files with 128 additions and 46 deletions

View File

@ -5,6 +5,7 @@ Functions for using this module from CLI
import sys
from .compiler import phasm_compile
from .optimise.removeunusedfuncs import removeunusedfuncs
from .parser import phasm_parse
from .type3.entry import phasm_type3
@ -20,6 +21,7 @@ def main(source: str, sink: str) -> int:
our_module = phasm_parse(code_py)
phasm_type3(our_module, verbose=False)
wasm_module = phasm_compile(our_module)
removeunusedfuncs(wasm_module)
code_wat = wasm_module.to_wat()
with open(sink, 'w') as fil:

View File

@ -325,18 +325,17 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
params = [
type3(x).to_wat()
type3(x)
for x in inp.function.type3.application.arguments
]
result = params.pop()
params_str = ' '.join(params)
wgn.add_statement('local.get', '${}'.format(inp.function.name))
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
wgn.call_indirect(params=params, result=result)
return
wgn.add_statement('call', '${}'.format(inp.function.name))
wgn.call(inp.function.name)
return
if isinstance(inp, ourlang.FunctionReference):

View File

@ -0,0 +1,5 @@
# This folder contains some basic optimisations.
# If you really want to optimise your stuff,
# check out https://github.com/WebAssembly/binaryen
# Its wasm-opt tool can be run on the wat or wasm files.
# And they have a ton of options to optimize.

View File

@ -0,0 +1,49 @@
from phasm import wasm
def removeunusedfuncs(inp: wasm.Module) -> None:
    """
    Removes functions that aren't exported and aren't
    called by exported functions

    Starting from the exported functions and the functions named in
    the module's table, every direct call statement is followed and
    only the functions reached that way are kept.

    The module is modified in place: `inp.functions` is replaced by
    a filtered list that keeps the original ordering.

    Raises KeyError if the table names a function that is not in
    `inp.functions`.
    """
    # First make a handy lookup table
    function_map = {
        x.name: x
        for x in inp.functions
    }

    # Keep a set of all functions to retain
    retain_functions = set()

    # Keep a queue (stack) of the functions we need to check
    # The exported functions as well as the tabled functions are the
    # first we know of to keep
    to_check_functions = [
        x
        for x in inp.functions
        if x.exported_name is not None
    ] + [
        function_map[x]
        for x in inp.table.values()
    ]

    while to_check_functions:
        # Check the next function. If it's on the stack, we need to retain it
        to_check_func = to_check_functions.pop()

        # A function can be pushed more than once (e.g. it is called twice,
        # or by several retained functions) before it is first popped.
        # Its calls only need to be scanned once.
        if to_check_func in retain_functions:
            continue

        retain_functions.add(to_check_func)

        # Check all function calls made by this retained function
        to_check_functions.extend(
            func
            for stmt in to_check_func.statements
            if isinstance(stmt, wasm.StatementCall)
            # The function_map can not have the named function if it is an import
            if (func := function_map.get(stmt.func_name)) is not None
            and func not in retain_functions
        )

    inp.functions = [
        func
        for func in inp.functions
        if func in retain_functions
    ]

View File

@ -15,7 +15,7 @@ UNALLOC_PTR = ADR_UNALLOC_PTR + 4
# For memory initialization see phasm.compiler.module_data
@func_wrapper(exported=False)
@func_wrapper()
def __find_free_block__(g: Generator, alloc_size: i32) -> i32:
# Find out if we've freed any blocks at all so far
g.i32.const(ADR_FREE_BLOCK_PTR)
@ -32,7 +32,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32:
return i32('return') # To satisfy mypy
@func_wrapper()
@func_wrapper(exported=True)
def __alloc__(g: Generator, alloc_size: i32) -> i32:
result = i32('result')

View File

@ -59,7 +59,7 @@ TYPE_INFO_MAP: Mapping[str, TypeInfo] = {
# not memory pointers but table addresses instead.
TYPE_INFO_CONSTRUCTED = TypeInfo('t a', WasmTypeInt32, 'i32.load', 'i32.store', 4)
@func_wrapper()
@func_wrapper(exported=True)
def __alloc_bytes__(g: Generator, length: i32) -> i32:
"""
Allocates room for a bytes instance, but does not write
@ -604,35 +604,35 @@ def f64_eq_not_equals(g: Generator, tv_map: TypeVariableLookup) -> None:
def u8_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_ord_min__')
g.call('stdlib.types.__u32_ord_min__')
def u16_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_ord_min__')
g.call('stdlib.types.__u32_ord_min__')
def u32_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_ord_min__')
g.call('stdlib.types.__u32_ord_min__')
def u64_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u64_ord_min__')
g.call('stdlib.types.__u64_ord_min__')
def i8_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i32_ord_min__')
g.call('stdlib.types.__i32_ord_min__')
def i16_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i32_ord_min__')
g.call('stdlib.types.__i32_ord_min__')
def i32_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i32_ord_min__')
g.call('stdlib.types.__i32_ord_min__')
def i64_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i64_ord_min__')
g.call('stdlib.types.__i64_ord_min__')
def f32_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
@ -644,35 +644,35 @@ def f64_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
def u8_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_ord_max__')
g.call('stdlib.types.__u32_ord_max__')
def u16_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_ord_max__')
g.call('stdlib.types.__u32_ord_max__')
def u32_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_ord_max__')
g.call('stdlib.types.__u32_ord_max__')
def u64_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u64_ord_max__')
g.call('stdlib.types.__u64_ord_max__')
def i8_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i32_ord_max__')
g.call('stdlib.types.__i32_ord_max__')
def i16_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i32_ord_max__')
g.call('stdlib.types.__i32_ord_max__')
def i32_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i32_ord_max__')
g.call('stdlib.types.__i32_ord_max__')
def i64_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i64_ord_max__')
g.call('stdlib.types.__i64_ord_max__')
def f32_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
@ -886,11 +886,11 @@ def u64_bits_logical_shift_right(g: Generator, tv_map: TypeVariableLookup) -> No
def u8_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u8_rotl__')
g.call('stdlib.types.__u8_rotl__')
def u16_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u16_rotl__')
g.call('stdlib.types.__u16_rotl__')
def u32_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
@ -903,11 +903,11 @@ def u64_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None:
def u8_bits_rotate_right(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u8_rotr__')
g.call('stdlib.types.__u8_rotr__')
def u16_bits_rotate_right(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u16_rotr__')
g.call('stdlib.types.__u16_rotr__')
def u32_bits_rotate_right(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
@ -1150,13 +1150,13 @@ def i64_natnum_arithmic_shift_left(g: Generator, tv_map: TypeVariableLookup) ->
def f32_natnum_arithmic_shift_left(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_pow2__')
g.call('stdlib.types.__u32_pow2__')
g.f32.convert_i32_u()
g.f32.mul()
def f64_natnum_arithmic_shift_left(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_pow2__')
g.call('stdlib.types.__u32_pow2__')
g.f64.convert_i32_u()
g.f64.mul()
@ -1180,13 +1180,13 @@ def i64_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) ->
def f32_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_pow2__')
g.call('stdlib.types.__u32_pow2__')
g.f32.convert_i32_u()
g.f32.div()
def f64_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__u32_pow2__')
g.call('stdlib.types.__u32_pow2__')
g.f64.convert_i32_u()
g.f64.div()
@ -1195,11 +1195,11 @@ def f64_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) ->
def i32_intnum_abs(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i32_intnum_abs__')
g.call('stdlib.types.__i32_intnum_abs__')
def i64_intnum_abs(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map
g.add_statement('call $stdlib.types.__i64_intnum_abs__')
g.call('stdlib.types.__i64_intnum_abs__')
def f32_intnum_abs(g: Generator, tv_map: TypeVariableLookup) -> None:
del tv_map

View File

@ -105,11 +105,17 @@ class Import(WatSerializable):
else f' (result {self.result.to_wat()})'
)
class Statement(WatSerializable):
class StatementBase(WatSerializable):
    """
    Common base class for the statements that make up a function body

    Both the generic Statement and the dedicated StatementCall derive
    from this, so a statement list can hold either kind.
    """
class Statement(StatementBase ):
"""
Represents a Web Assembly statement
"""
def __init__(self, name: str, *args: str, comment: Optional[str] = None):
assert ' ' not in name, 'Please pass argument separately'
assert name != 'call', 'Please use StatementCall'
self.name = name
self.args = args
self.comment = comment
@ -120,6 +126,16 @@ class Statement(WatSerializable):
return f'{self.name} {args}{comment}'
class StatementCall(StatementBase):
    """
    Represents a Web Assembly direct `call` statement

    The name of the called function is kept as an attribute, so code
    that inspects statements can read it without parsing WAT text.
    """
    def __init__(self, func_name: str, comment: str | None = None):
        # Name of the called function, without the leading `$`;
        # to_wat() adds that prefix when rendering.
        self.func_name = func_name
        # Optional trailing `;; ...` comment for the rendered line
        self.comment = comment

    def to_wat(self) -> str:
        """
        Renders the statement as WAT text: `call $<func_name>`,
        followed by ` ;; <comment>` when a comment is set.
        """
        comment = f' ;; {self.comment}' if self.comment else ''

        # No separator before {comment}: it already starts with a space,
        # and without a comment we don't want trailing whitespace.
        return f'call ${self.func_name}{comment}'
class Function(WatSerializable):
"""
Represents a Web Assembly function
@ -131,7 +147,7 @@ class Function(WatSerializable):
params: Iterable[Param],
locals_: Iterable[Param],
result: WasmType,
statements: Iterable[Statement],
statements: Iterable[StatementBase],
) -> None:
self.name = name
self.exported_name = exported_name

View File

@ -196,16 +196,17 @@ class GeneratorBlock:
def __enter__(self) -> None:
stmt = self.name
args: list[str] = []
if self.params:
stmt = f'{stmt} ' + ' '.join(
args.extend(
f'(param {typ})' if isinstance(typ, str) else f'(param {typ().to_wat()})'
for typ in self.params
)
if self.result:
result = self.result if isinstance(self.result, str) else self.result().to_wat()
stmt = f'{stmt} (result {result})'
args.append(f'(result {result})')
self.generator.add_statement(stmt, comment=self.comment)
self.generator.add_statement(stmt, *args, comment=self.comment)
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if not exc_type:
@ -213,7 +214,7 @@ class GeneratorBlock:
class Generator:
def __init__(self) -> None:
self.statements: List[wasm.Statement] = []
self.statements: List[wasm.StatementBase] = []
self.locals: Dict[str, VarType_Base] = {}
self.i32 = Generator_i32(self)
@ -238,7 +239,7 @@ class Generator:
# br_table
self.return_ = functools.partial(self.add_statement, 'return')
# call - see below
# call_indirect
# call_indirect - see below
def br_if(self, idx: int) -> None:
self.add_statement('br_if', f'{idx}')
@ -247,17 +248,19 @@ class Generator:
if isinstance(function, wasm.Function):
function = function.name
self.add_statement('call', f'${function}')
self.statements.append(wasm.StatementCall(function))
def call_indirect(self, params: Iterable[Type[wasm.WasmType]], result: Type[wasm.WasmType]) -> None:
def call_indirect(self, params: Iterable[Type[wasm.WasmType] | wasm.WasmType], result: Type[wasm.WasmType] | wasm.WasmType) -> None:
param_str = ' '.join(
x().to_wat()
(x() if isinstance(x, type) else x).to_wat()
for x in params
)
result_str = result().to_wat()
if isinstance(result, type):
result = result()
result_str = result.to_wat()
self.add_statement(f'call_indirect (param {param_str}) (result {result_str})')
self.add_statement('call_indirect', f'(param {param_str})', f'(result {result_str})')
def add_statement(self, name: str, *args: str, comment: Optional[str] = None) -> None:
self.statements.append(wasm.Statement(name, *args, comment=comment))
@ -297,7 +300,7 @@ class Generator:
def temp_var_u8(self, infix: str) -> VarType_u8:
return self.temp_var(VarType_u8(infix))
def func_wrapper(exported: bool = True) -> Callable[[Any], wasm.Function]:
def func_wrapper(exported: bool = False) -> Callable[[Any], wasm.Function]:
"""
This wrapper will execute the function and return
a wasm Function with the generated Statements

View File

@ -100,6 +100,7 @@ class Suite:
runner.parse(verbose=verbose)
runner.compile_ast()
runner.optimise_wasm_ast()
runner.compile_wat()
if verbose:

View File

@ -8,6 +8,7 @@ import wasmtime
from phasm import ourlang, wasm
from phasm.compiler import phasm_compile
from phasm.optimise.removeunusedfuncs import removeunusedfuncs
from phasm.parser import phasm_parse
from phasm.type3.entry import phasm_type3
@ -45,6 +46,12 @@ class RunnerBase:
"""
self.wasm_ast = phasm_compile(self.phasm_ast)
def optimise_wasm_ast(self) -> None:
    """
    Runs the optimisation passes over the WebAssembly AST

    Currently this only strips the unused functions; the AST
    stored on this runner is modified in place.
    """
    module = self.wasm_ast
    removeunusedfuncs(module)
def compile_wat(self) -> None:
"""
Compiles the WebAssembly AST into WebAssembly Assembly code