158 lines
5.1 KiB
Python
158 lines
5.1 KiB
Python
import ctypes
import functools
import struct
from queue import Empty, Queue
from typing import Any, List, Optional, Tuple

import wasmtime

from phasmplatform.common import valuetype
from phasmplatform.common.exceptions import PhashPlatformServiceNotFound, PhashPlatformServiceMethodNotFound
from phasmplatform.common.method import Method, MethodCall, MethodCallError, MethodNotFoundError
from phasmplatform.common.router import MethodCallRouterInterface
from phasmplatform.common.service import Service, ServiceDiscoveryInterface
from phasmplatform.common.value import Value
from phasmplatform.common.valuetype import ValueType

from .base import BaseRunner, WasmValue
|
|
|
|
|
|
class WasmTimeRunner(BaseRunner):
    """
    Runner that executes a WebAssembly module using wasmtime.

    Every function the module imports is resolved at construction time:
    the import's module name selects a service via ``service_discovery``,
    and the import's name selects a method on that service. Each import is
    bound to a host function that forwards the call through the method call
    router (see ``send_service_call``).
    """

    __slots__ = ('store', 'module', 'instance', 'exports')

    # Seconds send_service_call waits for a routed call to produce a result.
    SERVICE_CALL_TIMEOUT_SECONDS = 10

    def __init__(
        self,
        service_discovery: ServiceDiscoveryInterface,
        method_call_router: MethodCallRouterInterface,
        wasm_bin: bytes,
    ) -> None:
        """
        Compile ``wasm_bin`` and instantiate it with service-backed imports.

        :param service_discovery: Used to look up the service named by each
            import's module name.
        :param method_call_router: Forwarded to BaseRunner; used to route the
            calls the module makes to its imports.
        :param wasm_bin: The WebAssembly module bytes.
        :raises PhashPlatformServiceNotFound: An import names a service that
            service discovery does not know.
        :raises PhashPlatformServiceMethodNotFound: The service exists but does
            not provide the imported method.
        """
        super().__init__(service_discovery, method_call_router)

        self.store = wasmtime.Store()
        self.module = wasmtime.Module(self.store.engine, wasm_bin)

        # wasmtime.Instance takes imports positionally, so build them in the
        # same order as self.module.imports.
        imports: List[wasmtime.Func] = []
        for imprt in self.module.imports:
            service = service_discovery.find_service(imprt.module)
            if service is None:
                raise PhashPlatformServiceNotFound(
                    f'Dependent service "{imprt.module}" not found; could not provide "{imprt.name}"'
                )

            assert imprt.name is not None  # type hint

            method = service.methods.get(imprt.name)
            if method is None:
                raise PhashPlatformServiceMethodNotFound(
                    f'Dependent service "{imprt.module}" found, but it does not provide "{imprt.name}"'
                )

            func = wasmtime.Func(
                self.store,
                build_func_type(method),
                functools.partial(self.send_service_call, service, method),
            )
            imports.append(func)

        self.instance = wasmtime.Instance(self.store, self.module, imports)
        self.exports = self.instance.exports(self.store)

    def alloc_bytes(self, data: bytes) -> int:
        """
        Copy ``data`` into a buffer allocated inside the wasm instance.

        The buffer comes from the module's exported
        ``stdlib.types.__alloc_bytes__`` function, called with ``len(data)``.
        The returned pointer addresses a 4-byte header (read back by
        read_bytes as a little-endian length); the payload is written
        directly after that header.

        :param data: Bytes to copy into the exported ``memory``.
        :return: The pointer returned by ``__alloc_bytes__``.
        :raises ValueError: If the allocated buffer does not lie within memory.
        """
        memory = self.exports['memory']
        assert isinstance(memory, wasmtime.Memory)  # type hint

        alloc_bytes = self.exports['stdlib.types.__alloc_bytes__']
        assert isinstance(alloc_bytes, wasmtime.Func)  # type hint

        ptr = alloc_bytes(self.store, len(data))
        assert isinstance(ptr, int)  # type hint

        # Query the memory only *after* allocating: the allocator may grow the
        # memory, which changes data_len and may move data_ptr.
        data_ptr = memory.data_ptr(self.store)
        data_len = memory.data_len(self.store)

        start = ptr + 4  # Skip the header written by __alloc_bytes__
        # A negative ptr (signed i32) or an overrun would write outside the
        # memory buffer; refuse instead of corrupting host memory.
        if ptr < 0 or start + len(data) > data_len:
            raise ValueError(
                f'Buffer at {ptr} for {len(data)} bytes lies outside wasm memory of {data_len} bytes'
            )

        if data:
            # One bulk copy instead of a Python-level per-byte loop.
            ctypes.memmove(ctypes.addressof(data_ptr.contents) + start, data, len(data))

        return ptr

    def read_bytes(self, ptr: int) -> bytes:
        """
        Read a length-prefixed byte string from the exported ``memory``.

        ``ptr`` addresses a 4-byte little-endian length header (the layout
        produced by ``__alloc_bytes__``), immediately followed by the payload.

        :param ptr: Offset of the header within memory.
        :return: A copy of the payload bytes.
        :raises ValueError: If the header or the payload extends past memory.
        """
        memory = self.exports['memory']
        assert isinstance(memory, wasmtime.Memory)  # type hint

        data_ptr = memory.data_ptr(self.store)
        data_len = memory.data_len(self.store)

        if ptr < 0 or ptr + 4 > data_len:
            raise ValueError(f'Byte string header at {ptr} lies outside wasm memory of {data_len} bytes')

        # Copy only the header and the payload, not the whole linear memory.
        base = ctypes.addressof(data_ptr.contents)
        length, = struct.unpack('<I', ctypes.string_at(base + ptr, 4))  # Header prefixed by __alloc_bytes__

        if ptr + 4 + length > data_len:
            raise ValueError(
                f'Byte string at {ptr} claims {length} bytes, which exceeds wasm memory of {data_len} bytes'
            )

        return ctypes.string_at(base + ptr + 4, length)

    def do_call(self, call: MethodCall) -> None:
        """
        Invoke the exported wasm function named ``call.method.name``.

        Arguments are converted with value_to_wasm, and the result with
        value_from_wasm before it is handed to ``call.on_success``. If the
        module has no exported *function* of that name, ``call.on_error`` is
        called with MethodNotFoundError instead.
        """
        try:
            wasm_method = self.exports[call.method.name]
        except KeyError:
            call.on_error(MethodNotFoundError())
            return

        # An export of that name may exist but not be a function (for example
        # the 'memory' export); report it as not found instead of failing an
        # assertion inside the runner.
        if not isinstance(wasm_method, wasmtime.Func):
            call.on_error(MethodNotFoundError())
            return

        act_args = [self.value_to_wasm(x) for x in call.args]
        result = wasm_method(self.store, *act_args)
        assert result is None or isinstance(result, (int, float, ))  # type hint

        call.on_success(self.value_from_wasm(call.method.return_type, result))

    def send_service_call(self, service: Service, method: Method, *args: Any) -> WasmValue:
        """
        Host-side body of an imported function: route a call and wait for it.

        The raw wasm arguments are decoded with value_from_wasm, using the
        value types declared in ``method.args``. The call is sent through the
        method call router, and this thread then blocks until the call reports
        success or failure, or until SERVICE_CALL_TIMEOUT_SECONDS elapse.

        :param service: Target service for the call.
        :param method: Method being called; its ``args`` give the value types
            used to decode ``args``.
        :param args: Raw wasm arguments as received from the module.
        :return: The call's result converted back with value_to_wasm.
        :raises RuntimeError: The routed call reported an error.
        :raises TimeoutError: No result arrived within the timeout.
        """
        assert len(method.args) == len(args)  # type hint

        call_args = [
            self.value_from_wasm(x.value_type, y)
            for x, y in zip(method.args, args)
        ]

        # Receives (value, None) on success or (None, error) on failure, so the
        # waiter below is woken by either outcome.
        outcome: Queue[Tuple[Optional[Value], Optional[MethodCallError]]] = Queue(maxsize=1)

        def on_success(val: Value) -> None:
            outcome.put((val, None))

        def on_error(err: MethodCallError) -> None:
            print('Error while calling', service, method, args)
            # Wake the waiting caller now instead of letting it hit the timeout.
            outcome.put((None, err))

        call = MethodCall(method, call_args, on_success, on_error)
        self.method_call_router.send_call(service, call)

        try:
            value, error = outcome.get(block=True, timeout=self.SERVICE_CALL_TIMEOUT_SECONDS)
        except Empty:
            raise TimeoutError(
                f'No response from service {service!r} for method {method!r} '
                f'within {self.SERVICE_CALL_TIMEOUT_SECONDS} seconds'
            ) from None

        if error is not None:
            raise RuntimeError(f'Call to method {method!r} on service {service!r} failed: {error!r}')

        assert value is not None  # type hint

        return self.value_to_wasm(value)
|
|
|
|
|
|
def build_func_type(method: Method) -> wasmtime.FuncType:
    """
    Derive the wasmtime signature of a host function that implements ``method``.

    The result list is empty when the method returns ``valuetype.none``.
    Otherwise it holds the single wasm type of the return value. Each
    declared argument contributes one parameter, mapped through
    build_wasm_type.
    """
    result_types: List[wasmtime.ValType] = (
        []
        if method.return_type is valuetype.none
        else [build_wasm_type(method.return_type)]
    )

    param_types: List[wasmtime.ValType] = []
    for argument in method.args:
        # Arguments may not use the 'none' value type.
        assert argument.value_type is not valuetype.none  # type hint
        param_types.append(build_wasm_type(argument.value_type))

    return wasmtime.FuncType(param_types, result_types)
|
|
|
|
|
|
def build_wasm_type(value_type: ValueType) -> wasmtime.ValType:
    """
    Map a platform ValueType to the wasm value type used to pass it.

    Only ``valuetype.bytes`` is supported so far.

    :raises NotImplementedError: For value types without a wasm mapping; the
        message names the offending type.
    """
    if value_type is valuetype.bytes:
        # Bytes live in linear memory; only the i32 pointer crosses the boundary.
        return wasmtime.ValType.i32()

    raise NotImplementedError(f'No wasm type mapping for value type {value_type!r}')
|