From 11aa86cd495d89a225ec7c3c78be254b90e7c0e0 Mon Sep 17 00:00:00 2001 From: percyjw-2 <joris.wachsmuth@gmx.de> Date: Sun, 24 Jul 2022 14:25:54 +0200 Subject: [PATCH] Sockets are now properly shutdown Master decision-making is now fixed ConnectionManager has now ContextManager support --- example/ExampleUsage.py | 23 ++++++----- swarm/ConnectionManager.py | 69 +++++++++++++++++++++++--------- tests/test_connection_manager.py | 10 ++--- 3 files changed, 66 insertions(+), 36 deletions(-) diff --git a/example/ExampleUsage.py b/example/ExampleUsage.py index e57c864..1535b0c 100644 --- a/example/ExampleUsage.py +++ b/example/ExampleUsage.py @@ -27,15 +27,18 @@ if __name__ == "__main__": exit(-2) ip_source, port_source = parse_ip(ips[index]) - connMan = ConnectionManager(addr=ip_source, port=port_source) - connMan.connect(list(ips)) - try: - while True: - print("\033[H\033[J", end="") - print(f"current master: {connMan.get_current_master()}") - print(f"Is Client Master? {connMan.get_current_master() == ips[index]}") - sleep(1) - except KeyboardInterrupt: - connMan.disconnect() + connManInit = ConnectionManager(addr=ip_source, port=port_source, ip_list=ips) + + with connManInit as connMan: + try: + while True: + print("\033[H\033[J", end="") + print(f"current master: {connMan.get_current_master()}") + print(f"client address: {ips[index]}") + print(f"Is Client Master? {connMan.get_current_master() == (ip_source, port_source)}") + print(f"connected IPs: {connMan.get_current_addresses()}") + sleep(1) + except KeyboardInterrupt: + print("exiting...") except yaml.YAMLError as exc: print(exc) diff --git a/swarm/ConnectionManager.py b/swarm/ConnectionManager.py index 0fb0a52..7d7499f 100644 --- a/swarm/ConnectionManager.py +++ b/swarm/ConnectionManager.py @@ -5,6 +5,7 @@ import threading from enum import Enum import re from functools import partial +from typing import List, Callable, Optional, Union, Tuple try: from time import time_ns @@ -33,13 +34,17 @@ class InvalidIPString(Exception): pass +class NotInContextManagerMode(Exception): + pass + + class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler): def __init__(self, connection_manager, *args, **kwargs): self.connection_manager = connection_manager super().__init__(*args, **kwargs) # match statement is not available in Python 3.6 :( - def announce(self, launch_time, addr): + def announce(self, launch_time: str, addr: str): address_parsed = _string_to_ip_and_port(addr) self.connection_manager.sockets[address_parsed] = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.connection_manager.sockets[address_parsed].connect(address_parsed) @@ -57,7 +62,6 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler): def update_launch_time(self, launch_time, addr): address_parsed = _string_to_ip_and_port(addr) self.connection_manager.connectedIPs[address_parsed] = int(launch_time) - self.send_message(StandardMessages.ACKNOWLEDGED.value) def heartbeat(self): self.send_message(StandardMessages.ACKNOWLEDGED.value) @@ -80,7 +84,11 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler): def handle(self): while not self.connection_manager.stop_socketserver: - msg_recvd = str(self.request.recv(self.connection_manager._buffer_size), "utf-8").lower() + msg_recvd = "" + try: + msg_recvd = str(self.request.recv(self.connection_manager._buffer_size), "utf-8").lower() + except ConnectionError: + pass if not msg_recvd: break msg_split = msg_recvd.split(":") @@ -101,12 +109,12 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler): else: self.default_case(msg_recvd) - def send_message(self, message): + def send_message(self, message: Union[str, int]): message = str(message).encode("utf-8") self.request.sendall(message) -def _string_to_ip_and_port(message): +def _string_to_ip_and_port(message: str) -> Tuple[str, int]: valid_ipv4 = re.compile(r"^(\d?\d?\d.){3}\d?\d?\d:(\d?){4}\d$") valid_ipv6 = re.compile(r"^([a-f\d:]+:+)+[a-f\d]+:(\d?){4}\d$") valid_address = re.compile(r"^(localhost)|(\*+.\*):(\d?){4}\d$") @@ -124,11 +132,15 @@ class ConnectionManager: addr="localhost", port=6969, heartbeat=1, - buffer_size=1024): + buffer_size=1024, + ip_list=None): + if ip_list is None: + ip_list = [] self._addr = addr self._port = port self._heartbeat = heartbeat self._buffer_size = buffer_size + self._ip_list: List[str] = ip_list self.creation_time = time_ns() @@ -142,12 +154,20 @@ class ConnectionManager: self.socketServer = None self.listeners = [] + def __enter__(self): + if self._ip_list is None: + raise NotInContextManagerMode("An IP List needs to be provided to the Constructor to use this class " + "with the 'with' keyword.") + self.connect(self._ip_list) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.disconnect() + def __del__(self): self.disconnect() - for sock in self.sockets.values(): - sock.close() - def _launch_socket_server(self, address, request_handler=socketserver.BaseRequestHandler): + def _launch_socket_server(self, address: Tuple[str, int], request_handler=socketserver.BaseRequestHandler): with socketserver.ThreadingTCPServer(address, request_handler) as server: self.socketServer = server server.serve_forever() @@ -161,20 +181,24 @@ class ConnectionManager: received = str(self.sockets[address].recv(self._buffer_size), "utf-8") if received.lower() != StandardMessages.ACKNOWLEDGED.value: to_be_deleted.append(address) - except ConnectionAbortedError: + except ConnectionError: to_be_deleted.append(address) for address in to_be_deleted: self.connectedIPs.pop(address) self.sockets[address].close() self.sockets.pop(address) if self.master_addr not in self.connectedIPs.keys() and self.master_addr != (self._addr, self._port): - if len(self.connectedIPs) > 0: + if len(self.connectedIPs) == 0: self.master_addr = (self._addr, self._port) else: - self.master_addr = min(self.connectedIPs, key=self.connectedIPs.get) + master_candidate = min(self.connectedIPs, key=self.connectedIPs.get) + if self.connectedIPs[master_candidate] < self.creation_time: + self.master_addr = master_candidate + else: + self.master_addr = (self._addr, self._port) time.sleep(self._heartbeat) - def _connect_to_clients(self, ip_list): + def _connect_to_clients(self, ip_list: List[Tuple[str, int]]): changed_start_time = False for (ip, port) in ip_list: sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -202,7 +226,7 @@ class ConnectionManager: "utf-8" )) - def connect(self, ip_list): + def connect(self, ip_list: List[str]): if len(self.sockets) > 0: return request_handler = partial(ConnectionManagerTCPHandler, self) @@ -215,6 +239,7 @@ class ConnectionManager: self.heartbeatThread.daemon = True self.heartbeatThread.start() ip_list = [_string_to_ip_and_port(addr) for addr in ip_list] + ip_list = [addr for addr in ip_list if addr != (self._addr, self._port)] self._connect_to_clients(ip_list) if len(self.connectedIPs) > 0: received_msg = self.send_message(StandardMessages.GET_MASTER.value, next(iter(self.connectedIPs.keys()))) @@ -228,13 +253,17 @@ class ConnectionManager: def disconnect(self): # if len(self.sockets) == 0: # return + if self.heartbeatThread is not None: + self.stop_heartbeat = True + for sock in list(self.sockets.values()): + sock.shutdown(socket.SHUT_RDWR) + sock.close() + self.sockets.clear() if self.socketServerThread is not None: self.stop_socketserver = True self.socketServer.shutdown() - if self.heartbeatThread is not None: - self.stop_heartbeat = True - def send_message(self, message, address): + def send_message(self, message: str, address: Tuple[str, int]): if self.sockets.get(address) is None: return data = bytes(message, "utf-8") @@ -244,7 +273,7 @@ class ConnectionManager: return str(self.sockets[address].recv(self._buffer_size), "utf-8") def get_current_addresses(self): - return self.connectedIPs.keys() + return list(self.connectedIPs.keys()) def get_current_master(self): return self.master_addr @@ -252,8 +281,8 @@ class ConnectionManager: def get_ip(self): return self._addr, self._port - def add_listener(self, function): + def add_listener(self, function: Callable[[str], Optional[str]]): self.listeners.append(function) - def remove_listener(self, function): + def remove_listener(self, function: Callable[[str], Optional[str]]): self.listeners.remove(function) diff --git a/tests/test_connection_manager.py b/tests/test_connection_manager.py index 82f115e..86a11e7 100644 --- a/tests/test_connection_manager.py +++ b/tests/test_connection_manager.py @@ -3,6 +3,7 @@ from unittest import TestCase import swarm import random from ipaddress import IPv4Address, IPv6Address +from typing import List class TestStatics(TestCase): @@ -30,22 +31,19 @@ class TestStatics(TestCase): class TestConnections(TestCase): - # ConnectionManager cannot be tested (test client hangs after successful test) - """ def test_initial_connection(self): conn_mans = {} for i in range(10): - conn_mans[("localhost", 1000 + i)] = swarm.ConnectionManager(port=1000 + i) + conn_mans[("localhost:" + str(1000 + i))] = swarm.ConnectionManager(port=1000 + i) addresses = list(conn_mans.keys()) - to_connect = [] + to_connect: List[str] = [] for i in range(len(addresses)): conn_mans[addresses[i]].connect(to_connect) to_connect.append(addresses[i]) - master = addresses[0] + master = swarm._string_to_ip_and_port(addresses[0]) for manager in conn_mans.values(): self.assertEqual(master, manager.get_current_master()) for manager in conn_mans.values(): manager.disconnect() - """ -- GitLab