diff --git a/phasmplatform/worker/__main__.py b/phasmplatform/worker/__main__.py index 84e1726..7c3fbb3 100644 --- a/phasmplatform/worker/__main__.py +++ b/phasmplatform/worker/__main__.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple, Union import datetime import functools @@ -19,11 +19,18 @@ from .runners.base import RunnerInterface from .runners.wasmtime import WasmTimeRunner -def runner_thread(runner: RunnerInterface, queue: Queue[MethodCall]) -> None: +class ShuttingDown(): + pass + + +def runner_thread(runner: RunnerInterface, queue: Queue[Union[MethodCall, ShuttingDown]]) -> None: while True: try: call = queue.get(block=True, timeout=1) except Empty: + continue + + if isinstance(call, ShuttingDown): break runner.do_call(call) @@ -55,12 +62,12 @@ class LocalhostRunner(RunnerInterface): class LocalhostServiceDiscovery(ServiceDiscoveryInterface): - services: Dict[str, Tuple[Service, Queue[MethodCall]]] + services: Dict[str, Tuple[Service, Queue[Union[MethodCall, ShuttingDown]]]] def __init__(self) -> None: self.services = {} - def register_service(self, service: Service, queue: Queue[MethodCall]) -> None: + def register_service(self, service: Service, queue: Queue[Union[MethodCall, ShuttingDown]]) -> None: self.services[service.name] = (service, queue, ) def find_service(self, name: str) -> Optional[Service]: @@ -105,9 +112,9 @@ def main() -> int: # TODO: Replace the stuff below with the loading from the example state - localhost_queue: Queue[MethodCall] = Queue() - echo_client_queue: Queue[MethodCall] = Queue() - echo_server_queue: Queue[MethodCall] = Queue() + localhost_queue: Queue[Union[MethodCall, ShuttingDown]] = Queue() + echo_client_queue: Queue[Union[MethodCall, ShuttingDown]] = Queue() + echo_server_queue: Queue[Union[MethodCall, ShuttingDown]] = Queue() service_discovery = LocalhostServiceDiscovery() method_call_router = LocalhostMethodCallRouter(config.worker_config, service_discovery) @@ -148,7 +155,15 @@ def main() -> int: echo_client_thread.start() echo_server_thread.start() - time.sleep(3) + try: + while 1: + time.sleep(1) + except KeyboardInterrupt: + pass + + localhost_queue.put(ShuttingDown()) + echo_client_queue.put(ShuttingDown()) + echo_server_queue.put(ShuttingDown()) return 0