diff --git a/phasmplatform/common/value.py b/phasmplatform/common/value.py new file mode 100644 index 0000000..1ba44ea --- /dev/null +++ b/phasmplatform/common/value.py @@ -0,0 +1,26 @@ +from typing import Any, Generic, TypeVar + +T = TypeVar('T') + + +class BaseValue(Generic[T]): + __slots__ = ('data', ) + + data: T + + def __init__(self, data: T) -> None: + self.data = data + + def __eq__(self, other: Any) -> bool: + return self.__class__ is other.__class__ and self.data == other.data + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({repr(self.data)})' + + +class UntypedValue(BaseValue[Any]): + pass + + +class BytesValue(BaseValue[bytes]): + pass diff --git a/phasmplatform/worker/__main__.py b/phasmplatform/worker/__main__.py index 0788256..5d4e366 100644 --- a/phasmplatform/worker/__main__.py +++ b/phasmplatform/worker/__main__.py @@ -1,7 +1,10 @@ +from typing import Any + import sys from phasmplatform.common.config import from_toml from phasmplatform.common.router import StdOutRouter +from phasmplatform.common.value import BaseValue, BytesValue from .runners.base import BaseRunner from .runners.wasmtime import WasmTimeRunner @@ -27,6 +30,15 @@ def main() -> int: foo.handle_message(namespace, topic, kind, body) + inp = BytesValue(b'Hello, world!') + print('inp', inp) + + def on_respond(out: BaseValue[Any]) -> None: + print('out', out) + assert out == inp + + foo.do_call('echo', [inp], on_respond) + return 0 diff --git a/phasmplatform/worker/runners/wasmtime.py b/phasmplatform/worker/runners/wasmtime.py index 6e937f8..45de585 100644 --- a/phasmplatform/worker/runners/wasmtime.py +++ b/phasmplatform/worker/runners/wasmtime.py @@ -1,9 +1,12 @@ +from typing import Any, Callable, List, Union + import ctypes import struct import wasmtime from phasmplatform.common.router import BaseRouter +from phasmplatform.common.value import BaseValue, BytesValue, UntypedValue from .base import BaseRunner @@ -63,6 +66,21 @@ class WasmTimeRunner(BaseRunner): return raw[ptr + 4:ptr + 4 + length] + def convert_value(self, val: BaseValue[Any]) -> Union[int, float]: + if isinstance(val, BytesValue): + return self.alloc_bytes(val.data) + + raise NotImplementedError(val) + + def do_call(self, method_name: str, args: List[BaseValue[Any]], callback: Callable[[BaseValue[Any]], None]) -> None: + method = self.exports[method_name] + assert isinstance(method, wasmtime.Func) + + act_args = [self.convert_value(x) for x in args] + result = method(self.store, *act_args) + + callback(UntypedValue(result)) # TODO: This returns a bytes pointer, but we can't detect that in advance + def handle_message(self, namespace: bytes, topic: bytes, kind: bytes, body: bytes) -> None: namespace_ptr = self.alloc_bytes(namespace) topic_ptr = self.alloc_bytes(topic)