Skip to content
Snippets Groups Projects
Commit 9e25fd1d authored by Florian Schröder's avatar Florian Schröder
Browse files

Refactor websocket handling and improve test module

Websocket handling was significantly refactored to improve management of active websocket connections. Now, each connection is assigned to a concurrently running task. A timeout component was introduced within `recv_messages` function, which handles failed message receipt better by catching `CancelledError` and `TimeoutError` effectively. This enhance reliability of message receipt process.

Furthermore, a `terminate` method was added which efficiently closes all active websocket connections. In the test module, print statements were introduced to trace the connection process in the 'else' block, and the 'send' function call was moved into this block as well.
parent ceb15451
No related branches found
No related tags found
No related merge requests found
Pipeline #45036 failed
......@@ -39,11 +39,11 @@ class TestModule(Module, ModuleInfo):
async def initialize(self):
if self.config.connect_serve == "serve":
await self.websocket.create("my", "localhost", self.handle_receive, 8765)
await asyncio.sleep(0.5)
await self.websocket.send("my", "0")
else:
await asyncio.sleep(0.1)
await self.websocket.connect("my", "ws://localhost:8765", self.handle_receive)
await asyncio.sleep(0.1)
await self.websocket.send("my", "0")
async def handle_receive(self, ref, message):
try:
......
import asyncio
import traceback
from asyncio.exceptions import CancelledError, TimeoutError
from functools import partial
from typing import Type, Callable, Awaitable, Any
from async_timeout import timeout
import websockets
from semantic_version import Version, SimpleSpec
......@@ -30,20 +32,23 @@ class WebsocketExtension(Extension):
return WebsocketExtensionSetup
async def create(self, ref: str, url: str, callback: Callable[[str, Any], Awaitable], port: int):
websocket = await websockets.serve(partial(self.recv_messages, callback=callback, ref=ref), url, port)
self.active_websockets[ref] = (True, websocket)
await websockets.serve(partial(self.recv_messages, callback=callback, ref=ref, active_websockets=self.active_websockets), url, port)
async def connect(self, ref: str, uri: str, callback: Callable[[str, Any], Awaitable]):
websocket = await websockets.connect(uri)
self.active_websockets[ref] = (False, websocket)
await self.recv_messages(websocket=websocket, callback=callback, ref=ref)
asyncio.create_task(self.recv_messages(websocket=websocket, callback=callback, ref=ref, active_websockets=self.active_websockets))
@staticmethod
async def recv_messages(websocket, callback: Callable[[str, Any], Awaitable], ref: str):
async def recv_messages(websocket, callback: Callable[[str, Any], Awaitable], ref: str, active_websockets):
active_websockets[ref] = websocket
try:
while True:
message = await websocket.recv()
await callback(ref, message)
try:
async with timeout(0.1):
message = await websocket.recv()
await callback(ref, message)
except (CancelledError, TimeoutError):
await asyncio.sleep(0.05)
except Exception as e:
traceback.print_exception(e)
finally:
......@@ -51,13 +56,12 @@ class WebsocketExtension(Extension):
async def send(self, ref: str, message: str):
if ref in self.active_websockets:
is_server, websocket = self.active_websockets[ref]
if is_server:
for protocol in websocket.websockets:
await protocol.send(message)
else:
await websocket.send(message)
websocket = self.active_websockets[ref]
await websocket.send(message)
def terminate(self, *args, **kwargs):
for websocket in self.active_websockets.values():
websocket.close()
class WebsocketExtensionSetup(ExtensionSetup):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment