phasm/tests/integration/runners.py
2022-09-16 16:43:40 +02:00

330 lines
9.9 KiB
Python

"""
Runners to help run WebAssembly code on various interpreters
"""
from typing import Any, Callable, Dict, Iterable, Optional, TextIO
import ctypes
import io
import warnings
import pywasm.binary
import wasm3
import wasmer
import wasmtime
from phasm.compiler import phasm_compile
from phasm.parser import phasm_parse
from phasm.typer import phasm_type
from phasm import ourlang
from phasm import wasm
class RunnerBase:
    """
    Shared pipeline for running Phasm code on a WebAssembly interpreter

    Stores the Phasm source together with the artefacts produced by each
    stage: the Phasm AST, the WebAssembly AST, the WebAssembly text and
    the WebAssembly binary. The interpreter_* methods and call() raise
    NotImplementedError in this class; the concrete runners in this
    module override them.
    """
    phasm_code: str
    phasm_ast: ourlang.Module
    wasm_ast: wasm.Module
    wasm_asm: str
    wasm_bin: bytes

    def __init__(self, phasm_code: str) -> None:
        self.phasm_code = phasm_code

    def dump_phasm_code(self, textio: TextIO) -> None:
        """
        Writes the Phasm source, with line numbers, to textio
        """
        _dump_code(textio, self.phasm_code)

    def parse(self) -> None:
        """
        Parses the Phasm source into an AST and passes it to phasm_type

        A NotImplementedError raised by phasm_type is reported as a
        warning instead of being propagated.
        """
        parsed = phasm_parse(self.phasm_code)
        self.phasm_ast = parsed

        try:
            phasm_type(parsed)
        except NotImplementedError as exc:
            warnings.warn(f'phasm_type throws an NotImplementedError on this test: {exc}')

    def compile_ast(self) -> None:
        """
        Turns the Phasm AST into a WebAssembly AST
        """
        compiled = phasm_compile(self.phasm_ast)
        self.wasm_ast = compiled

    def compile_wat(self) -> None:
        """
        Renders the WebAssembly AST as WebAssembly text (WAT)
        """
        rendered = self.wasm_ast.to_wat()
        self.wasm_asm = rendered

    def dump_wasm_wat(self, textio: TextIO) -> None:
        """
        Writes the WebAssembly text, with line numbers, to textio
        """
        _dump_code(textio, self.wasm_asm)

    def compile_wasm(self) -> None:
        """
        Assembles the WebAssembly text into a WebAssembly binary
        """
        binary = wasmer.wat2wasm(self.wasm_asm)
        self.wasm_bin = binary

    def interpreter_setup(self) -> None:
        """
        Prepares the interpreter; to be provided by a subclass
        """
        raise NotImplementedError

    def interpreter_load(self, imports: Optional[Dict[str, Callable[[Any], Any]]] = None) -> None:
        """
        Loads the compiled binary into the interpreter; to be provided
        by a subclass
        """
        raise NotImplementedError

    def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None:
        """
        Writes the given byte values into the interpreter's memory,
        starting at offset; to be provided by a subclass
        """
        raise NotImplementedError

    def interpreter_read_memory(self, offset: int, length: int) -> bytes:
        """
        Reads bytes from the interpreter's memory; to be provided by a
        subclass
        """
        raise NotImplementedError

    def interpreter_dump_memory(self, textio: TextIO) -> None:
        """
        Writes a dump of the interpreter's memory to textio; to be
        provided by a subclass
        """
        raise NotImplementedError

    def call(self, function: str, *args: Any) -> Any:
        """
        Invokes the named function with the given arguments and returns
        its result; to be provided by a subclass
        """
        raise NotImplementedError
class RunnerPywasm(RunnerBase):
    """
    Implements a runner for pywasm

    See https://pypi.org/project/pywasm/
    """
    module: pywasm.binary.Module
    runtime: pywasm.Runtime

    def interpreter_setup(self) -> None:
        """
        Sets up the interpreter
        """
        # Nothing to set up
        pass

    def interpreter_load(self, imports: Optional[Dict[str, Callable[[Any], Any]]] = None) -> None:
        """
        Parses the compiled binary and creates a runtime for it

        Raises NotImplementedError when imports are given; this runner
        does not support them.
        """
        if imports is not None:
            raise NotImplementedError

        bytesio = io.BytesIO(self.wasm_bin)
        self.module = pywasm.binary.Module.from_reader(bytesio)
        self.runtime = pywasm.Runtime(self.module, {}, None)

    def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None:
        """
        Writes the given byte values into the first memory, starting at offset
        """
        memory = self.runtime.store.memory_list[0].data
        for idx, byt in enumerate(data):
            memory[offset + idx] = byt

    def interpreter_read_memory(self, offset: int, length: int) -> bytes:
        """
        Returns length bytes from the first memory, starting at offset
        """
        memory = self.runtime.store.memory_list[0].data
        # The slice end is offset + length; using length on its own
        # would treat it as an absolute end index.
        return bytes(memory[offset:offset + length])

    def interpreter_dump_memory(self, textio: TextIO) -> None:
        """
        Writes a hex dump of the first memory to textio
        """
        _dump_memory(textio, self.runtime.store.memory_list[0].data)

    def call(self, function: str, *args: Any) -> Any:
        """
        Executes the named function with the given arguments and returns
        the result
        """
        return self.runtime.exec(function, list(args))
class RunnerPywasm3(RunnerBase):
    """
    Implements a runner for pywasm3

    See https://pypi.org/project/pywasm3/
    """
    env: wasm3.Environment
    rtime: wasm3.Runtime
    mod: wasm3.Module

    def interpreter_setup(self) -> None:
        """
        Creates the wasm3 environment and a runtime within it
        """
        self.env = wasm3.Environment()
        self.rtime = self.env.new_runtime(1024 * 1024)

    def interpreter_load(self, imports: Optional[Dict[str, Callable[[Any], Any]]] = None) -> None:
        """
        Parses the compiled binary and loads it into the runtime

        Raises NotImplementedError when imports are given; this runner
        does not support them.
        """
        if imports is not None:
            raise NotImplementedError

        self.mod = self.env.parse_module(self.wasm_bin)
        self.rtime.load(self.mod)

    def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None:
        """
        Writes the given byte values into memory 0, starting at offset
        """
        memory = self.rtime.get_memory(0)

        for idx, byt in enumerate(data):
            memory[offset + idx] = byt # type: ignore

    def interpreter_read_memory(self, offset: int, length: int) -> bytes:
        """
        Returns length bytes from memory 0, starting at offset
        """
        memory = self.rtime.get_memory(0)
        # The slice end is offset + length; using length on its own
        # would treat it as an absolute end index.
        return memory[offset:offset + length].tobytes()

    def interpreter_dump_memory(self, textio: TextIO) -> None:
        """
        Writes a hex dump of memory 0 to textio
        """
        _dump_memory(textio, self.rtime.get_memory(0))

    def call(self, function: str, *args: Any) -> Any:
        """
        Looks up the named function in the runtime, calls it with the
        given arguments and returns the result
        """
        return self.rtime.find_function(function)(*args)
class RunnerWasmtime(RunnerBase):
    """
    Implements a runner for wasmtime

    See https://pypi.org/project/wasmtime/
    """
    store: wasmtime.Store
    module: wasmtime.Module
    instance: wasmtime.Instance

    def interpreter_setup(self) -> None:
        """
        Creates the wasmtime store
        """
        self.store = wasmtime.Store()

    def interpreter_load(self, imports: Optional[Dict[str, Callable[[Any], Any]]] = None) -> None:
        """
        Compiles the binary into a module and instantiates it

        Raises NotImplementedError when imports are given; this runner
        does not support them.
        """
        if imports is not None:
            raise NotImplementedError

        self.module = wasmtime.Module(self.store.engine, self.wasm_bin)
        self.instance = wasmtime.Instance(self.store, self.module, [])

    def _memory(self) -> wasmtime.Memory:
        """
        Returns the instance's export named 'memory'
        """
        exports = self.instance.exports(self.store)
        memory = exports['memory']
        assert isinstance(memory, wasmtime.Memory) # type hint
        return memory

    def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None:
        """
        Writes the given byte values into the exported memory, starting
        at offset

        Asserts that every written index lies within the memory.
        """
        memory = self._memory()

        data_ptr = memory.data_ptr(self.store)
        data_len = memory.data_len(self.store)

        idx = offset
        for byt in data:
            assert idx < data_len
            data_ptr[idx] = ctypes.c_ubyte(byt)
            idx += 1

    def interpreter_read_memory(self, offset: int, length: int) -> bytes:
        """
        Returns length bytes from the exported memory, starting at offset
        """
        memory = self._memory()

        data_ptr = memory.data_ptr(self.store)
        data_len = memory.data_len(self.store)

        raw = ctypes.string_at(data_ptr, data_len)

        # The slice end is offset + length; using length on its own
        # would treat it as an absolute end index.
        return raw[offset:offset + length]

    def interpreter_dump_memory(self, textio: TextIO) -> None:
        """
        Writes a hex dump of the exported memory to textio
        """
        memory = self._memory()

        data_ptr = memory.data_ptr(self.store)
        data_len = memory.data_len(self.store)

        _dump_memory(textio, ctypes.string_at(data_ptr, data_len))

    def call(self, function: str, *args: Any) -> Any:
        """
        Calls the exported function with the given arguments and returns
        the result
        """
        exports = self.instance.exports(self.store)
        func = exports[function]
        assert isinstance(func, wasmtime.Func)

        return func(self.store, *args)
class RunnerWasmer(RunnerBase):
    """
    Implements a runner for wasmer

    See https://pypi.org/project/wasmer/
    """

    # pylint: disable=E1101

    store: wasmer.Store
    module: wasmer.Module
    instance: wasmer.Instance

    def interpreter_setup(self) -> None:
        """
        Creates the wasmer store
        """
        self.store = wasmer.Store()

    def interpreter_load(self, imports: Optional[Dict[str, Callable[[Any], Any]]] = None) -> None:
        """
        Compiles the binary into a module and instantiates it

        When imports are given, each callable is wrapped in a
        wasmer.Function and registered under the 'imports' namespace.
        """
        import_object = wasmer.ImportObject()
        if imports:
            import_object.register('imports', {
                name: wasmer.Function(self.store, func)
                for name, func in imports.items()
            })

        self.module = wasmer.Module(self.store, self.wasm_bin)
        self.instance = wasmer.Instance(self.module, import_object)

    def _memory(self) -> wasmer.Memory:
        """
        Returns the instance's export named 'memory'
        """
        exports = self.instance.exports
        memory = getattr(exports, 'memory')
        assert isinstance(memory, wasmer.Memory)
        return memory

    def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None:
        """
        Writes the given byte values into the exported memory, starting
        at offset
        """
        # The view already starts at offset, so it is indexed from 0
        view = self._memory().uint8_view(offset)

        for idx, byt in enumerate(data):
            view[idx] = byt

    def interpreter_read_memory(self, offset: int, length: int) -> bytes:
        """
        Returns length bytes from the exported memory, starting at offset
        """
        # The view already starts at offset, so the slice must start at
        # 0; slicing from offset again would apply the offset twice.
        view = self._memory().uint8_view(offset)
        return bytes(view[0:length])

    def interpreter_dump_memory(self, textio: TextIO) -> None:
        """
        Writes a hex dump of the exported memory to textio
        """
        view = self._memory().uint8_view()

        _dump_memory(textio, view) # type: ignore

    def call(self, function: str, *args: Any) -> Any:
        """
        Calls the exported function with the given arguments and returns
        the result
        """
        exports = self.instance.exports
        func = getattr(exports, function)

        return func(*args)
def _dump_memory(textio: TextIO, mem: bytes) -> None:
    """
    Writes a hex dump of mem to textio, 16 bytes per line

    Each line starts with the offset as 8 lowercase hex digits, followed
    by the bytes as uppercase hex, with a space after every second byte.
    A run of lines identical to the one before it is collapsed into a
    single '**' line. The final line may hold fewer than 16 bytes when
    len(mem) is not a multiple of 16.
    """
    line_width = 16

    prev_line = None
    skip = False
    total = len(mem)

    for idx in range(0, total, line_width):
        # Clamp the final line so a trailing partial line does not
        # index past the end of mem
        parts = []
        for idx2 in range(min(line_width, total - idx)):
            parts.append(f'{mem[idx + idx2]:02X}')
            if idx2 % 2 == 1:
                parts.append(' ')
        line = ''.join(parts)

        if prev_line == line:
            if not skip:
                textio.write('**\n')
                skip = True
        else:
            textio.write(f'{idx:08x} {line}\n')
            # A differing line ends the run; re-arm the marker so the
            # next run of repeated lines is reported as well
            skip = False

        prev_line = line
def _dump_code(textio: TextIO, text: str) -> None:
    """
    Writes text to textio with every line prefixed by its line number

    Line numbers start at 1 and are zero-padded to the width of the
    highest line number, followed by a single space.
    """
    rows = text.split('\n')
    width = len(str(len(rows)))

    for number, content in enumerate(rows, start=1):
        prefix = str(number).zfill(width)
        textio.write(f'{prefix} {content}\n')