162 lines
5.4 KiB
Python
162 lines
5.4 KiB
Python
from typing import Any, Callable, Dict, List, Union
|
|
|
|
import ctypes
|
|
import functools
|
|
import struct
|
|
from queue import Empty, Queue
|
|
|
|
import wasmtime
|
|
|
|
from phasmplatform.common import valuetype
|
|
from phasmplatform.common.container import Container
|
|
from phasmplatform.common.method import Method
|
|
from phasmplatform.common.methodcall import MethodCall, MethodCallError, MethodNotFoundError
|
|
from phasmplatform.common.router import MethodCallRouterInterface
|
|
from phasmplatform.common.value import Value
|
|
from phasmplatform.common.valuetype import ValueType
|
|
|
|
from .base import BaseRunner, WasmValue
|
|
|
|
|
|
class WasmTimeRunner(BaseRunner):
    """Run a container's WebAssembly image on the wasmtime engine.

    Exported wasm functions are invoked through :meth:`do_call`. The imports
    declared by the container image are wired to :meth:`send_service_call`,
    which forwards them to other services via the method call router.
    """

    __slots__ = ('store', 'module', 'instance', 'exports')

    # Length header that stdlib.types.__alloc_bytes__ places in front of every
    # bytes allocation (a little-endian u32 holding the payload length).
    _BYTES_HEADER = struct.Struct('<I')

    # Seconds to wait for the router to deliver the result of a remote call.
    _REMOTE_CALL_TIMEOUT = 10

    def __init__(self, method_call_router: MethodCallRouterInterface, container: Container, container_log: Callable[[str], None]) -> None:
        """Load, compile and instantiate the container's wasm image.

        Every ``(service, method)`` pair in ``container.image.imports`` is
        offered to the module as the import ``"<service>.<method name>"``,
        backed by :meth:`send_service_call`. ``prelude.log_bytes`` is always
        offered and routes to :meth:`log_bytes`.

        :raises KeyError: if the module declares an import that is not offered.
        """
        super().__init__(method_call_router, container, container_log)

        with open(f'./{container.image.path}', 'rb') as fil:  # TODO: ImageLoader?
            wasm_bin = fil.read()

        self.store = wasmtime.Store()
        self.module = wasmtime.Module(self.store.engine, wasm_bin)

        import_map: Dict[str, wasmtime.Func] = {
            f'{imprt_service}.{imprt_method.name}': wasmtime.Func(
                self.store,
                build_func_type(imprt_method),
                functools.partial(self.send_service_call, imprt_service, imprt_method),
            )
            for (imprt_service, imprt_method, ) in container.image.imports
        }

        import_map['prelude.log_bytes'] = wasmtime.Func(
            self.store,
            wasmtime.FuncType([wasmtime.ValType.i32()], []),
            self.log_bytes,  # a bound method is already a plain callable
        )

        # wasmtime.Instance matches imports by position, so the list must
        # follow the order in which the module declares its imports.
        imports: List[wasmtime.Func] = []
        for imprt in self.module.imports:
            import_name = f'{imprt.module}.{imprt.name}'
            try:
                imports.append(import_map[import_name])
            except KeyError:
                raise KeyError(
                    f'{container.image.path} requires import {import_name!r}, which is not provided'
                ) from None

        self.instance = wasmtime.Instance(self.store, self.module, imports)
        self.exports = self.instance.exports(self.store)

    def alloc_bytes(self, data: bytes) -> int:
        """Copy ``data`` into a fresh guest allocation and return its pointer.

        Space is reserved by calling the module's exported
        ``stdlib.types.__alloc_bytes__`` with ``len(data)``. The payload is
        then written directly after the length header at the returned pointer.

        :raises ValueError: if the allocation does not lie inside linear memory.
        """
        memory = self.exports['memory']
        assert isinstance(memory, wasmtime.Memory)  # type hint

        alloc_func = self.exports['stdlib.types.__alloc_bytes__']
        assert isinstance(alloc_func, wasmtime.Func)  # type hint

        payload = bytes(data)
        ptr = alloc_func(self.store, len(payload))
        assert isinstance(ptr, int)  # type hint

        # Query the memory only after allocating: __alloc_bytes__ may grow
        # linear memory, which changes data_len and may move data_ptr.
        data_len = memory.data_len(self.store)
        start = ptr + self._BYTES_HEADER.size
        end = start + len(payload)
        if ptr < 0 or end > data_len:
            raise ValueError(
                f'Allocation of {len(payload)} bytes at {ptr} does not fit in linear memory ({data_len} bytes)'
            )

        base = ctypes.addressof(memory.data_ptr(self.store).contents)
        ctypes.memmove(base + start, payload, len(payload))

        return ptr

    def read_bytes(self, ptr: int) -> bytes:
        """Return the payload of the length-prefixed bytes object at ``ptr``.

        :raises ValueError: if the header or payload lies outside linear memory.
        """
        memory = self.exports['memory']
        assert isinstance(memory, wasmtime.Memory)  # type hint

        data_len = memory.data_len(self.store)
        header_size = self._BYTES_HEADER.size
        if ptr < 0 or ptr + header_size > data_len:
            raise ValueError(f'Bytes header at {ptr} is outside linear memory ({data_len} bytes)')

        # Copy only what is needed instead of snapshotting all of linear memory.
        base = ctypes.addressof(memory.data_ptr(self.store).contents)
        length, = self._BYTES_HEADER.unpack(ctypes.string_at(base + ptr, header_size))

        start = ptr + header_size
        if start + length > data_len:
            raise ValueError(
                f'Bytes payload of {length} bytes at {ptr} is outside linear memory ({data_len} bytes)'
            )

        return ctypes.string_at(base + start, length)

    def do_call(self, call: MethodCall) -> Union[Value, MethodCallError]:
        """Invoke the exported wasm function named by ``call.method.name``.

        Returns :class:`MethodNotFoundError` when there is no such export,
        or when the export is not a function (e.g. ``memory``).
        """
        try:
            wasm_method = self.exports[call.method.name]
        except KeyError:
            return MethodNotFoundError()

        if not isinstance(wasm_method, wasmtime.Func):
            return MethodNotFoundError()

        act_args = [self.value_to_wasm(x) for x in call.args]
        result = wasm_method(self.store, *act_args)
        assert result is None or isinstance(result, (int, float, ))  # type hint

        return self.value_from_wasm(call.method.return_type, result)

    def send_service_call(self, service: str, method: Method, *args: Any) -> WasmValue:
        """Forward an imported call from the guest to ``service`` and wait for it.

        The call is routed through the method call router. This thread then
        blocks for up to ``_REMOTE_CALL_TIMEOUT`` seconds for the result.

        :raises Exception: if the remote call reports an error, or if no
            result arrives before the timeout.
        """
        assert len(method.arg_types) == len(args)  # type hint

        call_args = [
            self.value_from_wasm(arg_type, arg)
            for arg_type, arg in zip(method.arg_types, args)
        ]

        # Errors travel through the queue as well, so they are raised here in
        # the waiting thread. If they were raised inside the router's callback
        # (possibly on another thread), they would be lost and surface only as
        # a timeout.
        results: Queue[Union[Value, MethodCallError]] = Queue(maxsize=1)

        def on_result(res: Union[Value, MethodCallError]) -> None:
            results.put(res)

        call = MethodCall(method, call_args)
        self.method_call_router.route_call(service, call, self.container, on_result)

        try:
            res = results.get(block=True, timeout=self._REMOTE_CALL_TIMEOUT)
        except Empty:
            raise Exception('Did not receive value from remote call') from None  # TODO

        if not isinstance(res, Value):
            raise Exception(res)

        return self.value_to_wasm(res)

    def log_bytes(self, data_ptr: int) -> None:
        """Host implementation of ``prelude.log_bytes``: log a guest bytes value.

        ``data_ptr`` is the guest pointer to the bytes object. The payload is
        logged as its Python ``repr``.
        """
        value = self.value_from_wasm(valuetype.bytes, data_ptr)

        # NOTE(review): container_log is stored by BaseRunner and indexed with
        # [0] here; presumably a 1-tuple wrapping the callable. Confirm there.
        self.container_log[0](repr(value.data))
|
|
|
|
|
|
def build_func_type(method: Method) -> wasmtime.FuncType:
    """Translate a platform method signature into a wasmtime function type.

    A ``valuetype.none`` return type becomes an empty result list. Every other
    return type, and every argument type, is converted with
    :func:`build_wasm_type`. Arguments are never expected to be
    ``valuetype.none``.
    """
    result_types: List[wasmtime.ValType] = (
        []
        if method.return_type is valuetype.none
        else [build_wasm_type(method.return_type)]
    )

    param_types: List[wasmtime.ValType] = []
    for param_type in method.arg_types:
        assert param_type is not valuetype.none  # type hint
        param_types.append(build_wasm_type(param_type))

    return wasmtime.FuncType(param_types, result_types)
|
|
|
|
|
|
def build_wasm_type(value_type: ValueType) -> wasmtime.ValType:
    """Return the wasm value type used to pass ``value_type`` across the boundary.

    :raises NotImplementedError: for any value type without a wasm mapping.
    """
    # Both supported types travel as a plain i32:
    # - u32: wasm has no unsigned types; signed-ness is in the operands.
    # - bytes: passed as a pointer.
    i32_backed = (valuetype.u32, valuetype.bytes)
    if not any(value_type is candidate for candidate in i32_backed):
        raise NotImplementedError

    return wasmtime.ValType.i32()
|