Optimise: Remove unused functions
By default, we add a lot of build in functions that may never get called. This commit adds a simple reachability graph algorithm to remove functions that can't be called from outside. Also, unmarks a lot of functions as being exported. It was the default to export - now it's the default to not export. Also, some general cleanup to the wasm statement calls.
This commit is contained in:
parent
84e7c42ea4
commit
d97be81828
@ -5,6 +5,7 @@ Functions for using this module from CLI
|
||||
import sys
|
||||
|
||||
from .compiler import phasm_compile
|
||||
from .optimise.removeunusedfuncs import removeunusedfuncs
|
||||
from .parser import phasm_parse
|
||||
from .type3.entry import phasm_type3
|
||||
|
||||
@ -20,6 +21,7 @@ def main(source: str, sink: str) -> int:
|
||||
our_module = phasm_parse(code_py)
|
||||
phasm_type3(our_module, verbose=False)
|
||||
wasm_module = phasm_compile(our_module)
|
||||
removeunusedfuncs(wasm_module)
|
||||
code_wat = wasm_module.to_wat()
|
||||
|
||||
with open(sink, 'w') as fil:
|
||||
|
||||
@ -325,18 +325,17 @@ def expression(wgn: WasmGenerator, mod: ourlang.Module, inp: ourlang.Expression)
|
||||
assert isinstance(inp.function.type3.application.constructor, TypeConstructor_Function)
|
||||
|
||||
params = [
|
||||
type3(x).to_wat()
|
||||
type3(x)
|
||||
for x in inp.function.type3.application.arguments
|
||||
]
|
||||
|
||||
result = params.pop()
|
||||
|
||||
params_str = ' '.join(params)
|
||||
wgn.add_statement('local.get', '${}'.format(inp.function.name))
|
||||
wgn.add_statement(f'call_indirect (param {params_str}) (result {result})')
|
||||
wgn.call_indirect(params=params, result=result)
|
||||
return
|
||||
|
||||
wgn.add_statement('call', '${}'.format(inp.function.name))
|
||||
wgn.call(inp.function.name)
|
||||
return
|
||||
|
||||
if isinstance(inp, ourlang.FunctionReference):
|
||||
|
||||
5
phasm/optimise/__init__.py
Normal file
5
phasm/optimise/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# This folder contains some basic optimisations
|
||||
# If you really want to optimise your stuff,
|
||||
# checkout https://github.com/WebAssembly/binaryen
|
||||
# It's wasm-opt tool can be run on the wat or wasm files.
|
||||
# And they have a ton of options to optimize.
|
||||
49
phasm/optimise/removeunusedfuncs.py
Normal file
49
phasm/optimise/removeunusedfuncs.py
Normal file
@ -0,0 +1,49 @@
|
||||
from phasm import wasm
|
||||
|
||||
|
||||
def removeunusedfuncs(inp: wasm.Module) -> None:
|
||||
"""
|
||||
Removes functions that aren't exported and aren't
|
||||
called by exported functions
|
||||
"""
|
||||
# First make a handy lookup table
|
||||
function_map = {
|
||||
x.name: x
|
||||
for x in inp.functions
|
||||
}
|
||||
|
||||
# Keep a list of all functions to retain
|
||||
retain_functions = set()
|
||||
|
||||
# Keep a queue (stack) of the funtions we need to check
|
||||
# The exported functions as well as the tabled functions are the
|
||||
# first we know of to keep
|
||||
to_check_functions = [
|
||||
x
|
||||
for x in inp.functions
|
||||
if x.exported_name is not None
|
||||
] + [
|
||||
function_map[x]
|
||||
for x in inp.table.values()
|
||||
]
|
||||
|
||||
while to_check_functions:
|
||||
# Check the next function. If it's on the list, we need to retain it
|
||||
to_check_func = to_check_functions.pop()
|
||||
retain_functions.add(to_check_func)
|
||||
|
||||
# Check all functions calls by this retaining function
|
||||
to_check_functions.extend(
|
||||
func
|
||||
for stmt in to_check_func.statements
|
||||
if isinstance(stmt, wasm.StatementCall)
|
||||
# The function_map can not have the named function if it is an import
|
||||
if (func := function_map.get(stmt.func_name)) is not None
|
||||
and func not in retain_functions
|
||||
)
|
||||
|
||||
inp.functions = [
|
||||
func
|
||||
for func in inp.functions
|
||||
if func in retain_functions
|
||||
]
|
||||
@ -15,7 +15,7 @@ UNALLOC_PTR = ADR_UNALLOC_PTR + 4
|
||||
|
||||
# For memory initialization see phasm.compiler.module_data
|
||||
|
||||
@func_wrapper(exported=False)
|
||||
@func_wrapper()
|
||||
def __find_free_block__(g: Generator, alloc_size: i32) -> i32:
|
||||
# Find out if we've freed any blocks at all so far
|
||||
g.i32.const(ADR_FREE_BLOCK_PTR)
|
||||
@ -32,7 +32,7 @@ def __find_free_block__(g: Generator, alloc_size: i32) -> i32:
|
||||
|
||||
return i32('return') # To satisfy mypy
|
||||
|
||||
@func_wrapper()
|
||||
@func_wrapper(exported=True)
|
||||
def __alloc__(g: Generator, alloc_size: i32) -> i32:
|
||||
result = i32('result')
|
||||
|
||||
|
||||
@ -59,7 +59,7 @@ TYPE_INFO_MAP: Mapping[str, TypeInfo] = {
|
||||
# not memory pointers but table addresses instead.
|
||||
TYPE_INFO_CONSTRUCTED = TypeInfo('t a', WasmTypeInt32, 'i32.load', 'i32.store', 4)
|
||||
|
||||
@func_wrapper()
|
||||
@func_wrapper(exported=True)
|
||||
def __alloc_bytes__(g: Generator, length: i32) -> i32:
|
||||
"""
|
||||
Allocates room for a bytes instance, but does not write
|
||||
@ -604,35 +604,35 @@ def f64_eq_not_equals(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
|
||||
def u8_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_ord_min__')
|
||||
g.call('stdlib.types.__u32_ord_min__')
|
||||
|
||||
def u16_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_ord_min__')
|
||||
g.call('stdlib.types.__u32_ord_min__')
|
||||
|
||||
def u32_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_ord_min__')
|
||||
g.call('stdlib.types.__u32_ord_min__')
|
||||
|
||||
def u64_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u64_ord_min__')
|
||||
g.call('stdlib.types.__u64_ord_min__')
|
||||
|
||||
def i8_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i32_ord_min__')
|
||||
g.call('stdlib.types.__i32_ord_min__')
|
||||
|
||||
def i16_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i32_ord_min__')
|
||||
g.call('stdlib.types.__i32_ord_min__')
|
||||
|
||||
def i32_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i32_ord_min__')
|
||||
g.call('stdlib.types.__i32_ord_min__')
|
||||
|
||||
def i64_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i64_ord_min__')
|
||||
g.call('stdlib.types.__i64_ord_min__')
|
||||
|
||||
def f32_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
@ -644,35 +644,35 @@ def f64_ord_min(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
|
||||
def u8_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_ord_max__')
|
||||
g.call('stdlib.types.__u32_ord_max__')
|
||||
|
||||
def u16_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_ord_max__')
|
||||
g.call('stdlib.types.__u32_ord_max__')
|
||||
|
||||
def u32_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_ord_max__')
|
||||
g.call('stdlib.types.__u32_ord_max__')
|
||||
|
||||
def u64_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u64_ord_max__')
|
||||
g.call('stdlib.types.__u64_ord_max__')
|
||||
|
||||
def i8_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i32_ord_max__')
|
||||
g.call('stdlib.types.__i32_ord_max__')
|
||||
|
||||
def i16_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i32_ord_max__')
|
||||
g.call('stdlib.types.__i32_ord_max__')
|
||||
|
||||
def i32_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i32_ord_max__')
|
||||
g.call('stdlib.types.__i32_ord_max__')
|
||||
|
||||
def i64_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i64_ord_max__')
|
||||
g.call('stdlib.types.__i64_ord_max__')
|
||||
|
||||
def f32_ord_max(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
@ -886,11 +886,11 @@ def u64_bits_logical_shift_right(g: Generator, tv_map: TypeVariableLookup) -> No
|
||||
|
||||
def u8_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u8_rotl__')
|
||||
g.call('stdlib.types.__u8_rotl__')
|
||||
|
||||
def u16_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u16_rotl__')
|
||||
g.call('stdlib.types.__u16_rotl__')
|
||||
|
||||
def u32_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
@ -903,11 +903,11 @@ def u64_bits_rotate_left(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
|
||||
def u8_bits_rotate_right(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u8_rotr__')
|
||||
g.call('stdlib.types.__u8_rotr__')
|
||||
|
||||
def u16_bits_rotate_right(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u16_rotr__')
|
||||
g.call('stdlib.types.__u16_rotr__')
|
||||
|
||||
def u32_bits_rotate_right(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
@ -1150,13 +1150,13 @@ def i64_natnum_arithmic_shift_left(g: Generator, tv_map: TypeVariableLookup) ->
|
||||
|
||||
def f32_natnum_arithmic_shift_left(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_pow2__')
|
||||
g.call('stdlib.types.__u32_pow2__')
|
||||
g.f32.convert_i32_u()
|
||||
g.f32.mul()
|
||||
|
||||
def f64_natnum_arithmic_shift_left(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_pow2__')
|
||||
g.call('stdlib.types.__u32_pow2__')
|
||||
g.f64.convert_i32_u()
|
||||
g.f64.mul()
|
||||
|
||||
@ -1180,13 +1180,13 @@ def i64_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) ->
|
||||
|
||||
def f32_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_pow2__')
|
||||
g.call('stdlib.types.__u32_pow2__')
|
||||
g.f32.convert_i32_u()
|
||||
g.f32.div()
|
||||
|
||||
def f64_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__u32_pow2__')
|
||||
g.call('stdlib.types.__u32_pow2__')
|
||||
g.f64.convert_i32_u()
|
||||
g.f64.div()
|
||||
|
||||
@ -1195,11 +1195,11 @@ def f64_natnum_arithmic_shift_right(g: Generator, tv_map: TypeVariableLookup) ->
|
||||
|
||||
def i32_intnum_abs(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i32_intnum_abs__')
|
||||
g.call('stdlib.types.__i32_intnum_abs__')
|
||||
|
||||
def i64_intnum_abs(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
g.add_statement('call $stdlib.types.__i64_intnum_abs__')
|
||||
g.call('stdlib.types.__i64_intnum_abs__')
|
||||
|
||||
def f32_intnum_abs(g: Generator, tv_map: TypeVariableLookup) -> None:
|
||||
del tv_map
|
||||
|
||||
@ -105,11 +105,17 @@ class Import(WatSerializable):
|
||||
else f' (result {self.result.to_wat()})'
|
||||
)
|
||||
|
||||
class Statement(WatSerializable):
|
||||
class StatementBase(WatSerializable):
|
||||
pass
|
||||
|
||||
class Statement(StatementBase ):
|
||||
"""
|
||||
Represents a Web Assembly statement
|
||||
"""
|
||||
def __init__(self, name: str, *args: str, comment: Optional[str] = None):
|
||||
assert ' ' not in name, 'Please pass argument separately'
|
||||
assert name != 'call', 'Please use StatementCall'
|
||||
|
||||
self.name = name
|
||||
self.args = args
|
||||
self.comment = comment
|
||||
@ -120,6 +126,16 @@ class Statement(WatSerializable):
|
||||
|
||||
return f'{self.name} {args}{comment}'
|
||||
|
||||
class StatementCall(StatementBase ):
|
||||
def __init__(self, func_name: str, comment: str | None = None):
|
||||
self.func_name = func_name
|
||||
self.comment = comment
|
||||
|
||||
def to_wat(self) -> str:
|
||||
comment = f' ;; {self.comment}' if self.comment else ''
|
||||
|
||||
return f'call ${self.func_name} {comment}'
|
||||
|
||||
class Function(WatSerializable):
|
||||
"""
|
||||
Represents a Web Assembly function
|
||||
@ -131,7 +147,7 @@ class Function(WatSerializable):
|
||||
params: Iterable[Param],
|
||||
locals_: Iterable[Param],
|
||||
result: WasmType,
|
||||
statements: Iterable[Statement],
|
||||
statements: Iterable[StatementBase],
|
||||
) -> None:
|
||||
self.name = name
|
||||
self.exported_name = exported_name
|
||||
|
||||
@ -196,16 +196,17 @@ class GeneratorBlock:
|
||||
|
||||
def __enter__(self) -> None:
|
||||
stmt = self.name
|
||||
args: list[str] = []
|
||||
if self.params:
|
||||
stmt = f'{stmt} ' + ' '.join(
|
||||
args.extend(
|
||||
f'(param {typ})' if isinstance(typ, str) else f'(param {typ().to_wat()})'
|
||||
for typ in self.params
|
||||
)
|
||||
if self.result:
|
||||
result = self.result if isinstance(self.result, str) else self.result().to_wat()
|
||||
stmt = f'{stmt} (result {result})'
|
||||
args.append(f'(result {result})')
|
||||
|
||||
self.generator.add_statement(stmt, comment=self.comment)
|
||||
self.generator.add_statement(stmt, *args, comment=self.comment)
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
if not exc_type:
|
||||
@ -213,7 +214,7 @@ class GeneratorBlock:
|
||||
|
||||
class Generator:
|
||||
def __init__(self) -> None:
|
||||
self.statements: List[wasm.Statement] = []
|
||||
self.statements: List[wasm.StatementBase] = []
|
||||
self.locals: Dict[str, VarType_Base] = {}
|
||||
|
||||
self.i32 = Generator_i32(self)
|
||||
@ -238,7 +239,7 @@ class Generator:
|
||||
# br_table
|
||||
self.return_ = functools.partial(self.add_statement, 'return')
|
||||
# call - see below
|
||||
# call_indirect
|
||||
# call_indirect - see below
|
||||
|
||||
def br_if(self, idx: int) -> None:
|
||||
self.add_statement('br_if', f'{idx}')
|
||||
@ -247,17 +248,19 @@ class Generator:
|
||||
if isinstance(function, wasm.Function):
|
||||
function = function.name
|
||||
|
||||
self.add_statement('call', f'${function}')
|
||||
self.statements.append(wasm.StatementCall(function))
|
||||
|
||||
def call_indirect(self, params: Iterable[Type[wasm.WasmType]], result: Type[wasm.WasmType]) -> None:
|
||||
def call_indirect(self, params: Iterable[Type[wasm.WasmType] | wasm.WasmType], result: Type[wasm.WasmType] | wasm.WasmType) -> None:
|
||||
param_str = ' '.join(
|
||||
x().to_wat()
|
||||
(x() if isinstance(x, type) else x).to_wat()
|
||||
for x in params
|
||||
)
|
||||
|
||||
result_str = result().to_wat()
|
||||
if isinstance(result, type):
|
||||
result = result()
|
||||
result_str = result.to_wat()
|
||||
|
||||
self.add_statement(f'call_indirect (param {param_str}) (result {result_str})')
|
||||
self.add_statement('call_indirect', f'(param {param_str})', f'(result {result_str})')
|
||||
|
||||
def add_statement(self, name: str, *args: str, comment: Optional[str] = None) -> None:
|
||||
self.statements.append(wasm.Statement(name, *args, comment=comment))
|
||||
@ -297,7 +300,7 @@ class Generator:
|
||||
def temp_var_u8(self, infix: str) -> VarType_u8:
|
||||
return self.temp_var(VarType_u8(infix))
|
||||
|
||||
def func_wrapper(exported: bool = True) -> Callable[[Any], wasm.Function]:
|
||||
def func_wrapper(exported: bool = False) -> Callable[[Any], wasm.Function]:
|
||||
"""
|
||||
This wrapper will execute the function and return
|
||||
a wasm Function with the generated Statements
|
||||
|
||||
@ -100,6 +100,7 @@ class Suite:
|
||||
|
||||
runner.parse(verbose=verbose)
|
||||
runner.compile_ast()
|
||||
runner.optimise_wasm_ast()
|
||||
runner.compile_wat()
|
||||
|
||||
if verbose:
|
||||
|
||||
@ -8,6 +8,7 @@ import wasmtime
|
||||
|
||||
from phasm import ourlang, wasm
|
||||
from phasm.compiler import phasm_compile
|
||||
from phasm.optimise.removeunusedfuncs import removeunusedfuncs
|
||||
from phasm.parser import phasm_parse
|
||||
from phasm.type3.entry import phasm_type3
|
||||
|
||||
@ -45,6 +46,12 @@ class RunnerBase:
|
||||
"""
|
||||
self.wasm_ast = phasm_compile(self.phasm_ast)
|
||||
|
||||
def optimise_wasm_ast(self) -> None:
|
||||
"""
|
||||
Optimises the WebAssembly AST
|
||||
"""
|
||||
removeunusedfuncs(self.wasm_ast)
|
||||
|
||||
def compile_wat(self) -> None:
|
||||
"""
|
||||
Compiles the WebAssembly AST into WebAssembly Assembly code
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user