Had to implement both functions as arguments and type placeholders (variables) for type constructors. Had to implement functions as a type as well. Still have to figure out how to pass functions around.
226 lines
6.7 KiB
Python
226 lines
6.7 KiB
Python
"""
Runners that help run WebAssembly code on various interpreters
"""

import ctypes
from typing import Any, Callable, Dict, Iterable, Optional, TextIO, get_type_hints

import wasmtime

from phasm import ourlang, wasm
from phasm.compiler import phasm_compile
from phasm.parser import phasm_parse
from phasm.type3.entry import phasm_type3


# Host functions to expose to the WebAssembly module, keyed by name.
# RunnerWasmtime.func2type derives one wasm parameter per positional
# argument, so a host function may take any number of arguments; the
# parameter list is therefore left open with ``...``.
Imports = Optional[Dict[str, Callable[..., Any]]]


class RunnerBase:
    """
    Base class for the runners

    Takes Phasm source code through parsing, compilation to a WebAssembly
    AST and rendering to WebAssembly text. It also declares the
    interpreter hooks (setup, load, memory access, calls) that concrete
    runners override.
    """

    # Phasm source code, as passed to the constructor
    phasm_code: str
    # Filled in by parse()
    phasm_ast: ourlang.Module
    # Filled in by compile_ast()
    wasm_ast: wasm.Module
    # Filled in by compile_wat()
    wasm_asm: str
    # WebAssembly binary; presumably filled in by a subclass' compile_wasm()
    wasm_bin: bytes

    def __init__(self, phasm_code: str) -> None:
        self.phasm_code = phasm_code

    def dump_phasm_code(self, textio: TextIO) -> None:
        """
        Writes the Phasm source, with line numbers, to textio for debugging
        """
        _dump_code(textio, self.phasm_code)

    def parse(self, verbose: bool = True) -> None:
        """
        Parses phasm_code into phasm_ast, then runs the type3 pass over it

        verbose is passed through to phasm_type3 unchanged.
        """
        module = phasm_parse(self.phasm_code)
        self.phasm_ast = module
        phasm_type3(module, verbose=verbose)

    def compile_ast(self) -> None:
        """
        Compiles phasm_ast into a WebAssembly AST, stored in wasm_ast
        """
        compiled = phasm_compile(self.phasm_ast)
        self.wasm_ast = compiled

    def compile_wat(self) -> None:
        """
        Renders wasm_ast as WebAssembly text, stored in wasm_asm
        """
        wat = self.wasm_ast.to_wat()
        self.wasm_asm = wat

    def dump_wasm_wat(self, textio: TextIO) -> None:
        """
        Writes the generated WebAssembly text, with line numbers, to textio
        """
        _dump_code(textio, self.wasm_asm)

    def compile_wasm(self) -> None:
        """
        Compiles the WebAssembly AST into a WebAssembly binary (subclass hook)
        """
        raise NotImplementedError()

    def interpreter_setup(self) -> None:
        """
        Prepares the interpreter (subclass hook)
        """
        raise NotImplementedError()

    def interpreter_load(self, imports: Imports = None) -> None:
        """
        Loads the compiled code, plus any host function imports, into the
        interpreter (subclass hook)
        """
        raise NotImplementedError()

    def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None:
        """
        Writes data into the interpreter's memory at offset (subclass hook)
        """
        raise NotImplementedError()

    def interpreter_read_memory(self, offset: int, length: int) -> bytes:
        """
        Reads length bytes from the interpreter's memory at offset (subclass hook)
        """
        raise NotImplementedError()

    def interpreter_dump_memory(self, textio: TextIO) -> None:
        """
        Writes the interpreter's memory to textio for debugging (subclass hook)
        """
        raise NotImplementedError()

    def call(self, function: str, *args: Any) -> Any:
        """
        Invokes the named function with args and returns its result (subclass hook)
        """
        raise NotImplementedError()


class RunnerWasmtime(RunnerBase):
    """
    Implements a runner for wasmtime

    See https://pypi.org/project/wasmtime/
    """

    # Created by interpreter_setup()
    store: wasmtime.Store
    # Created by interpreter_load()
    module: wasmtime.Module
    # Created by interpreter_load()
    instance: wasmtime.Instance

    @classmethod
    def func2type(cls, func: Callable[..., Any]) -> wasmtime.FuncType:
        """
        Builds the wasmtime function type for a Python host function

        The type is derived from the function's annotations. ``int`` maps to
        i32 and ``float`` maps to f32, for the parameters and for the return
        value. A ``None`` return annotation means the function returns
        nothing. String annotations (for example under
        ``from __future__ import annotations``) are resolved with
        typing.get_type_hints first, so they are handled like real types.

        Raises NotImplementedError if a positional parameter or the return
        value is missing an annotation or has an unsupported one.
        """
        hints = get_type_hints(func)

        def to_valtype(pytype: Any) -> Optional[wasmtime.ValType]:
            # Returns None for any Python type without a wasm mapping
            if pytype is int:
                return wasmtime.ValType.i32()
            if pytype is float:
                return wasmtime.ValType.f32()
            return None

        params: list[wasmtime.ValType] = []
        code = func.__code__
        # The first co_argcount entries of co_varnames are the positional parameters
        for varname in code.co_varnames[:code.co_argcount]:
            valtype = to_valtype(hints.get(varname))
            if valtype is None:
                raise NotImplementedError('Parameter type', varname, hints.get(varname))
            params.append(valtype)

        if 'return' not in hints:
            raise NotImplementedError('Return type', 'missing annotation')

        ret = hints['return']
        results: list[wasmtime.ValType] = []
        # get_type_hints turns a `-> None` annotation into NoneType
        if ret is not None and ret is not type(None):
            valtype = to_valtype(ret)
            if valtype is None:
                raise NotImplementedError('Return type', ret)
            results.append(valtype)

        return wasmtime.FuncType(params, results)

    def interpreter_setup(self) -> None:
        """
        Creates the wasmtime store that holds the module and the host functions
        """
        self.store = wasmtime.Store()

    def interpreter_load(self, imports: Imports = None) -> None:
        """
        Compiles wasm_asm and instantiates it, supplying the host functions

        interpreter_setup() must have been called first, because this method
        uses self.store.
        """
        self.module = wasmtime.Module(self.store.engine, self.wasm_asm)

        functions: list[wasmtime.Func] = []
        if imports is not None:
            functions = [
                wasmtime.Func(self.store, self.func2type(func), func)
                for func in self._order_imports(imports)
            ]

        self.instance = wasmtime.Instance(self.store, self.module, functions)

    def _order_imports(self, imports: Dict[str, Callable[..., Any]]) -> list[Callable[..., Any]]:
        """
        Lines up the host functions with the imports the module declares

        wasmtime.Instance matches imports by position, not by name. If every
        import the module declares has an entry of the same name in imports,
        the functions are returned in the module's declaration order.
        Otherwise the dict's own insertion order is used.
        """
        names = [imp.name for imp in self.module.imports]
        if names and all(name in imports for name in names):
            return [imports[name] for name in names]
        return list(imports.values())

    def _memory_region(self) -> tuple[int, int]:
        """
        Returns the base address and the size in bytes of the exported 'memory'

        Both values are fetched fresh on every call, because growing the
        memory may move or resize it.
        """
        memory = self.instance.exports(self.store)['memory']
        assert isinstance(memory, wasmtime.Memory)  # type hint

        address = ctypes.cast(memory.data_ptr(self.store), ctypes.c_void_p).value or 0
        return address, memory.data_len(self.store)

    def interpreter_write_memory(self, offset: int, data: Iterable[int]) -> None:
        """
        Writes the bytes in data into the exported memory, starting at offset

        Values are reduced modulo 256, so -1 is written as 0xFF.

        Raises IndexError, before anything is written, if the write would
        fall outside the memory.
        """
        address, size = self._memory_region()

        # ctypes.c_ubyte used to wrap values modulo 256; keep that behaviour
        payload = bytes(byt & 0xFF for byt in data)

        # Validate up front: a negative offset would write before the buffer,
        # and a failure halfway through would leave a partial write behind
        if offset < 0 or offset + len(payload) > size:
            raise IndexError(
                f'Cannot write {len(payload)} byte(s) at offset {offset}; '
                f'memory is {size} byte(s)'
            )

        if payload:
            ctypes.memmove(address + offset, payload, len(payload))

    def interpreter_read_memory(self, offset: int, length: int) -> bytes:
        """
        Reads up to length bytes from the exported memory, starting at offset

        Follows Python slicing semantics (memory[offset:offset + length]),
        but copies only the requested bytes instead of the whole memory.
        """
        address, size = self._memory_region()

        start, stop, _ = slice(offset, offset + length).indices(size)
        if stop <= start:
            return b''
        return ctypes.string_at(address + start, stop - start)

    def interpreter_dump_memory(self, textio: TextIO) -> None:
        """
        Writes a hex dump of the whole exported memory to textio
        """
        address, size = self._memory_region()
        _dump_memory(textio, ctypes.string_at(address, size) if size else b'')

    def call(self, function: str, *args: Any) -> Any:
        """
        Calls the exported function named function with args and returns its result
        """
        func = self.instance.exports(self.store)[function]
        assert isinstance(func, wasmtime.Func)  # type hint

        return func(self.store, *args)


def _dump_memory(textio: TextIO, mem: bytes) -> None:
|
|
line_width = 16
|
|
|
|
prev_line = None
|
|
skip = False
|
|
for idx in range(0, len(mem), line_width):
|
|
line = ''
|
|
for idx2 in range(0, line_width):
|
|
line += f'{mem[idx + idx2]:02X}'
|
|
if idx2 % 2 == 1:
|
|
line += ' '
|
|
|
|
if prev_line == line:
|
|
if not skip:
|
|
textio.write('**\n')
|
|
skip = True
|
|
else:
|
|
textio.write(f'{idx:08x} {line}\n')
|
|
|
|
prev_line = line
|
|
|
|
def _dump_code(textio: TextIO, text: str) -> None:
|
|
line_list = text.split('\n')
|
|
line_no_width = len(str(len(line_list)))
|
|
for line_no, line_txt in enumerate(line_list):
|
|
textio.write('{} {}\n'.format(
|
|
str(line_no + 1).zfill(line_no_width),
|
|
line_txt,
|
|
))
|