diff --git a/phasm/transformers/__init__.py b/phasm/transformers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phasm/transformers/wasm/__init__.py b/phasm/transformers/wasm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/phasm/transformers/wasm/phasmplatform/__init__.py b/phasm/transformers/wasm/phasmplatform/__init__.py new file mode 100644 index 0000000..c15dfaa --- /dev/null +++ b/phasm/transformers/wasm/phasmplatform/__init__.py @@ -0,0 +1,6 @@ +from . import splitoncall + +from phasm.wasm import Module + +def transform(module: Module) -> None: + splitoncall.transform(module) diff --git a/phasm/transformers/wasm/phasmplatform/splitoncall.py b/phasm/transformers/wasm/phasmplatform/splitoncall.py new file mode 100644 index 0000000..2a2a587 --- /dev/null +++ b/phasm/transformers/wasm/phasmplatform/splitoncall.py @@ -0,0 +1,63 @@ +from typing import List + +from phasm.wasm import ( + Function, Module, Statement +) + +def is_imported(module: Module, name: str) -> bool: + for imprt in module.imports: + if name == imprt.intname: + return True + + return False + +def split_on_call_function(module: Module, function: Function) -> List[Function]: + function_list = [] + statement_list = [] + idx = 0 + while function.statements: + stmt = function.statements.pop(0) + if not stmt.name == 'call' or is_imported(module, stmt.args[0][1:]): + statement_list.append(stmt) + continue + + function_list.append(Function( + f'{function.name}.{idx}', + None, + function.params, + function.locals, + function.result, + statement_list + [stmt] + [Statement('call', f'${function.name}.{idx + 1}')] + )) + statement_list = [] + idx += 1 + + function_list.append(Function( + f'{function.name}.{idx}', + None, + function.params, + function.locals, + function.result, + statement_list + )) + + if function.exported_name: + function_list.append(Function( + function.name + '.e', + function.exported_name, + function.params, + [], + function.result, + [ + Statement('local.get', '$' + x[0]) + for x in function.params + ] + [Statement('call', f'${function.name}.0')] + )) + + return function_list + +def transform(module: Module) -> None: + new_functions = [] + for func in module.functions: + new_functions.extend(split_on_call_function(module, func)) + module.functions = new_functions diff --git a/phasm/wasm.py b/phasm/wasm.py index 7c5a982..5635f66 100644 --- a/phasm/wasm.py +++ b/phasm/wasm.py @@ -15,6 +15,12 @@ class WatSerializable: """ raise NotImplementedError(self, 'to_wat') + def alloc_size(self) -> int: + """ + Returns how many bytes a variable of this type takes up in memory + """ + raise NotImplementedError(self, 'alloc_size') + class WasmType(WatSerializable): """ Type base class @@ -36,6 +42,9 @@ class WasmTypeInt32(WasmType): def to_wat(self) -> str: return 'i32' + def alloc_size(self) -> int: + return 4 + class WasmTypeInt64(WasmType): """ i64 value @@ -52,6 +61,9 @@ class WasmTypeFloat32(WasmType): def to_wat(self) -> str: return 'f32' + def alloc_size(self) -> int: + return 4 + class WasmTypeFloat64(WasmType): """ f64 value @@ -179,6 +191,34 @@ class ModuleMemory(WatSerializable): '(export "memory" (memory 0))\n' ) +class TableElement(WatSerializable): + """ + Represents a Web Assembly table element + """ + def __init__(self, offset: int, args: Iterable[str]) -> None: + self.offset = offset + self.args = [*args] + + def to_wat(self) -> str: + args = ' '.join(self.args) + + return f'(elem (i32.const {self.offset}) {args})' + +class Table(WatSerializable): + """ + Represents a Web Assembly table + """ + def __init__(self, size: int, typ: str, elements: List[TableElement]) -> None: + self.size = size + self.type = typ + self.elements = [*elements] + + def to_wat(self) -> str: + return ( + f'(table {self.size} {self.type})\n ' + + '\n '.join(x.to_wat() for x in self.elements) + ) + class Module(WatSerializable): """ Represents a Web Assembly module @@ -187,13 +227,15 @@ class Module(WatSerializable): self.imports: List[Import] = [] self.functions: List[Function] = [] self.memory = ModuleMemory() + self.tables: List[Table] = [] def to_wat(self) -> str: """ Generates the text version """ - return '(module\n {}\n {}\n {})\n'.format( + return '(module\n {}\n {}\n {}\n {})\n'.format( '\n '.join(x.to_wat() for x in self.imports), self.memory.to_wat(), + '\n '.join(x.to_wat() for x in self.tables), '\n '.join(x.to_wat() for x in self.functions), ) diff --git a/stubs/wasm3.pyi b/stubs/wasm3.pyi index 216412c..bdce0bf 100644 --- a/stubs/wasm3.pyi +++ b/stubs/wasm3.pyi @@ -3,6 +3,9 @@ from typing import Any, Callable class Module: ... + def link_function(self, module_name: str, function_name: str, function: Callable[[Any], Any]) -> None: + ... + class Runtime: ... diff --git a/tests/integration/runners.py b/tests/integration/runners.py index 91a29fd..d4085bc 100644 --- a/tests/integration/runners.py +++ b/tests/integration/runners.py @@ -152,12 +152,13 @@ class RunnerPywasm3(RunnerBase): self.rtime = self.env.new_runtime(1024 * 1024) def interpreter_load(self, imports: Optional[Dict[str, Callable[[Any], Any]]] = None) -> None: - if imports is not None: - raise NotImplementedError - self.mod = self.env.parse_module(self.wasm_bin) self.rtime.load(self.mod) + if imports is not None: + for key, val in imports.items(): + self.mod.link_function('imports', key, val) + def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None: memory = self.rtime.get_memory(0) diff --git a/tests/integration/test_transformers/__init__.py b/tests/integration/test_transformers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_transformers/test_wasm/__init__.py b/tests/integration/test_transformers/test_wasm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/integration/test_transformers/test_wasm/test_phasmplatform.py b/tests/integration/test_transformers/test_wasm/test_phasmplatform.py new file mode 100644 index 0000000..e7f75eb --- /dev/null +++ b/tests/integration/test_transformers/test_wasm/test_phasmplatform.py @@ -0,0 +1,164 @@ +import sys + +from phasm import wasm + +from ...runners import RunnerPywasm, RunnerPywasm3, RunnerWasmtime, RunnerWasmer + +from phasm.transformers.wasm import phasmplatform as sut + +none = wasm.WasmTypeNone() +f32 = wasm.WasmTypeFloat32() + +def run(module, func, *, imports = None): + # runner = RunnerPywasm('-') + runner = RunnerPywasm3('-') + # runner = RunnerWasmtime('-') + runner.wasm_ast = module + runner.compile_wat() + runner.dump_wasm_wat(sys.stderr) + runner.compile_wasm() + runner.interpreter_setup() + runner.interpreter_load(imports) + return runner.call(func) + + +def return_f32_30_0() -> float: + return 30.0 + + +def make_return_f32_30_0_import(): + return wasm.Import('imports', 'return_f32_30_0', 'return_f32_30_0', [], f32) + + +def make_func1(): + """ + func1: No params, just a return value + + In Haskell: + func1 :: f32 + func1 = 30 + """ + return wasm.Function('func1', 'func1', [], [], f32, [ + wasm.Statement('f32.const', '30'), + wasm.Statement('return'), + ]) + +def make_func2(): + """ + func2: Calls imported method + + In Haskell: + return_f32_30_0 :: IO(f32) + + func2 :: IO(f32) + func2 = return_f32_30_0 + """ + return wasm.Function('func2', 'func2', [], [], f32, [ + wasm.Statement('call', '$return_f32_30_0'), + wasm.Statement('return'), + ]) + + +def test_const_return_func(): + module = wasm.Module() + module.functions.append(make_func1()) + + assert 30 == run(module, 'func1') + + sut.transform(module) + + assert 30 == run(module, 'func1') + +def test_call_imported_method(): + module = wasm.Module() + module.imports.append(make_return_f32_30_0_import()) + module.functions.append(make_func2()) + + assert 30 == run(module, 'func2', imports={'return_f32_30_0': return_f32_30_0}) + + sut.transform(module) + + assert 30 == run(module, 'func2', imports={'return_f32_30_0': return_f32_30_0}) + +def test_parameters(): + assert 0 + + # func1: No params, no return value + # func1 = func4 10 20 + func1_statements = [ + wasm.Statement('f32.const', '30'), + wasm.Statement('f32.const', '8'), + wasm.Statement('call', '$func3'), + wasm.Statement('f32.const', '0'), + wasm.Statement('call', '$log'), + ] + func1 = wasm.Function('func1', 'func1', [], [], none, func1_statements) + + + # func2: No params, with return value + # func2 = remote_value_get + func2_statements = [ + wasm.Statement('call', '$remote_value_get'), + ] + func2 = wasm.Function('func2', 'func2', [], [], f32, func2_statements) + + # With params, no return value + # func3 left right = (func4 left right) + func2 + func3_statements = [ + wasm.Statement('local.get', '$left'), + wasm.Statement('local.get', '$right'), + wasm.Statement('call', '$func4'), + wasm.Statement('call', '$func2'), + wasm.Statement('f32.add'), + wasm.Statement('call', '$log'), + ] + func3 = wasm.Function('func3', 'func3', [ + ('left', f32, ), + ('right', f32, ), + ], [], none, func3_statements) + + # func4: With params, and return value + # func4 l r = l * 2 + r + func4_statements = [ + wasm.Statement('local.get', '$left'), + wasm.Statement('f32.const', '2'), + wasm.Statement('f32.mul'), + wasm.Statement('local.get', '$right'), + wasm.Statement('f32.add'), + wasm.Statement('return'), + ] + func4 = wasm.Function('func4', None, [ + ('left', f32, ), + ('right', f32, ), + ], [], f32, func4_statements) + + module = wasm.Module() + module.imports.append(wasm.Import('imports', 'log', 'log', [('a', f32, )], none)) + module.imports.append(wasm.Import('imports', 'remote_value_get', 'remote_value_get', [], f32)) + module.functions.append(func1) + module.functions.append(func2) + module.functions.append(func3) + module.functions.append(func4) + + def my_remote_value_get() -> float: + return 19.0 + + log = [] + def my_log(a: float) -> None: + log.append(a) + + imports = {'log': my_log, 'remote_value_get': my_remote_value_get} + + run(module, 'func1', imports=imports) + + assert [87, 0] == log + + sut.transform(module) + + run(module, 'func1', imports=imports) + + assert 0 + + assert [87, 0] == log + +# def test_locals():