From 11aa86cd495d89a225ec7c3c78be254b90e7c0e0 Mon Sep 17 00:00:00 2001
From: percyjw-2 <joris.wachsmuth@gmx.de>
Date: Sun, 24 Jul 2022 14:25:54 +0200
Subject: [PATCH] Sockets are now properly shutdown Master decision-making is
 now fixed ConnectionManager has now ContextManager support

---
 example/ExampleUsage.py          | 23 ++++++-----
 swarm/ConnectionManager.py       | 69 +++++++++++++++++++++++---------
 tests/test_connection_manager.py | 10 ++---
 3 files changed, 66 insertions(+), 36 deletions(-)

diff --git a/example/ExampleUsage.py b/example/ExampleUsage.py
index e57c864..1535b0c 100644
--- a/example/ExampleUsage.py
+++ b/example/ExampleUsage.py
@@ -27,15 +27,18 @@ if __name__ == "__main__":
                 exit(-2)
             ip_source, port_source = parse_ip(ips[index])
 
-            connMan = ConnectionManager(addr=ip_source, port=port_source)
-            connMan.connect(list(ips))
-            try:
-                while True:
-                    print("\033[H\033[J", end="")
-                    print(f"current master: {connMan.get_current_master()}")
-                    print(f"Is Client Master? {connMan.get_current_master() == ips[index]}")
-                    sleep(1)
-            except KeyboardInterrupt:
-                connMan.disconnect()
+            connManInit = ConnectionManager(addr=ip_source, port=port_source, ip_list=ips)
+
+            with connManInit as connMan:
+                try:
+                    while True:
+                        print("\033[H\033[J", end="")
+                        print(f"current master: {connMan.get_current_master()}")
+                        print(f"client address: {ips[index]}")
+                        print(f"Is Client Master? {connMan.get_current_master() == (ip_source, port_source)}")
+                        print(f"connected IPs: {connMan.get_current_addresses()}")
+                        sleep(1)
+                except KeyboardInterrupt:
+                    print("exiting...")
         except yaml.YAMLError as exc:
             print(exc)
diff --git a/swarm/ConnectionManager.py b/swarm/ConnectionManager.py
index 0fb0a52..7d7499f 100644
--- a/swarm/ConnectionManager.py
+++ b/swarm/ConnectionManager.py
@@ -5,6 +5,7 @@ import threading
 from enum import Enum
 import re
 from functools import partial
+from typing import List, Callable, Optional, Union, Tuple
 
 try:
     from time import time_ns
@@ -33,13 +34,17 @@ class InvalidIPString(Exception):
     pass
 
 
+class NotInContextManagerMode(Exception):
+    pass
+
+
 class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
     def __init__(self, connection_manager, *args, **kwargs):
         self.connection_manager = connection_manager
         super().__init__(*args, **kwargs)
 
     # match statement is not available in Python 3.6 :(
-    def announce(self, launch_time, addr):
+    def announce(self, launch_time: str, addr: str):
         address_parsed = _string_to_ip_and_port(addr)
         self.connection_manager.sockets[address_parsed] = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self.connection_manager.sockets[address_parsed].connect(address_parsed)
@@ -57,7 +62,6 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
     def update_launch_time(self, launch_time, addr):
         address_parsed = _string_to_ip_and_port(addr)
         self.connection_manager.connectedIPs[address_parsed] = int(launch_time)
-        self.send_message(StandardMessages.ACKNOWLEDGED.value)
 
     def heartbeat(self):
         self.send_message(StandardMessages.ACKNOWLEDGED.value)
@@ -80,7 +84,11 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
 
     def handle(self):
         while not self.connection_manager.stop_socketserver:
-            msg_recvd = str(self.request.recv(self.connection_manager._buffer_size), "utf-8").lower()
+            msg_recvd = ""
+            try:
+                msg_recvd = str(self.request.recv(self.connection_manager._buffer_size), "utf-8").lower()
+            except ConnectionError:
+                pass
             if not msg_recvd:
                 break
             msg_split = msg_recvd.split(":")
@@ -101,12 +109,12 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
             else:
                 self.default_case(msg_recvd)
 
-    def send_message(self, message):
+    def send_message(self, message: Union[str, int]):
         message = str(message).encode("utf-8")
         self.request.sendall(message)
 
 
-def _string_to_ip_and_port(message):
+def _string_to_ip_and_port(message: str) -> Tuple[str, int]:
     valid_ipv4 = re.compile(r"^(\d?\d?\d.){3}\d?\d?\d:(\d?){4}\d$")
     valid_ipv6 = re.compile(r"^([a-f\d:]+:+)+[a-f\d]+:(\d?){4}\d$")
     valid_address = re.compile(r"^(localhost)|(\*+.\*):(\d?){4}\d$")
@@ -124,11 +132,15 @@ class ConnectionManager:
                  addr="localhost",
                  port=6969,
                  heartbeat=1,
-                 buffer_size=1024):
+                 buffer_size=1024,
+                 ip_list=None):
+        if ip_list is None:
+            ip_list = []
         self._addr = addr
         self._port = port
         self._heartbeat = heartbeat
         self._buffer_size = buffer_size
+        self._ip_list: List[str] = ip_list
 
         self.creation_time = time_ns()
 
@@ -142,12 +154,20 @@ class ConnectionManager:
         self.socketServer = None
         self.listeners = []
 
+    def __enter__(self):
+        if self._ip_list is None:
+            raise NotInContextManagerMode("An IP List needs to be provided to the Constructor to use this class "
+                                          "with the 'with' keyword.")
+        self.connect(self._ip_list)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.disconnect()
+
     def __del__(self):
         self.disconnect()
-        for sock in self.sockets.values():
-            sock.close()
 
-    def _launch_socket_server(self, address, request_handler=socketserver.BaseRequestHandler):
+    def _launch_socket_server(self, address: Tuple[str, int], request_handler=socketserver.BaseRequestHandler):
         with socketserver.ThreadingTCPServer(address, request_handler) as server:
             self.socketServer = server
             server.serve_forever()
@@ -161,20 +181,24 @@ class ConnectionManager:
                     received = str(self.sockets[address].recv(self._buffer_size), "utf-8")
                     if received.lower() != StandardMessages.ACKNOWLEDGED.value:
                         to_be_deleted.append(address)
-                except ConnectionAbortedError:
+                except ConnectionError:
                     to_be_deleted.append(address)
             for address in to_be_deleted:
                 self.connectedIPs.pop(address)
                 self.sockets[address].close()
                 self.sockets.pop(address)
             if self.master_addr not in self.connectedIPs.keys() and self.master_addr != (self._addr, self._port):
-                if len(self.connectedIPs) > 0:
+                if len(self.connectedIPs) == 0:
                     self.master_addr = (self._addr, self._port)
                 else:
-                    self.master_addr = min(self.connectedIPs, key=self.connectedIPs.get)
+                    master_candidate = min(self.connectedIPs, key=self.connectedIPs.get)
+                    if self.connectedIPs[master_candidate] < self.creation_time:
+                        self.master_addr = master_candidate
+                    else:
+                        self.master_addr = (self._addr, self._port)
             time.sleep(self._heartbeat)
 
-    def _connect_to_clients(self, ip_list):
+    def _connect_to_clients(self, ip_list: List[Tuple[str, int]]):
         changed_start_time = False
         for (ip, port) in ip_list:
             sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -202,7 +226,7 @@ class ConnectionManager:
                     "utf-8"
                 ))
 
-    def connect(self, ip_list):
+    def connect(self, ip_list: List[str]):
         if len(self.sockets) > 0:
             return
         request_handler = partial(ConnectionManagerTCPHandler, self)
@@ -215,6 +239,7 @@ class ConnectionManager:
         self.heartbeatThread.daemon = True
         self.heartbeatThread.start()
         ip_list = [_string_to_ip_and_port(addr) for addr in ip_list]
+        ip_list = [addr for addr in ip_list if addr != (self._addr, self._port)]
         self._connect_to_clients(ip_list)
         if len(self.connectedIPs) > 0:
             received_msg = self.send_message(StandardMessages.GET_MASTER.value, next(iter(self.connectedIPs.keys())))
@@ -228,13 +253,17 @@ class ConnectionManager:
     def disconnect(self):
         # if len(self.sockets) == 0:
         #     return
+        if self.heartbeatThread is not None:
+            self.stop_heartbeat = True
+        for sock in list(self.sockets.values()):
+            sock.shutdown(socket.SHUT_RDWR)
+            sock.close()
+        self.sockets.clear()
         if self.socketServerThread is not None:
             self.stop_socketserver = True
             self.socketServer.shutdown()
-        if self.heartbeatThread is not None:
-            self.stop_heartbeat = True
 
-    def send_message(self, message, address):
+    def send_message(self, message: str, address: Tuple[str, int]):
         if self.sockets.get(address) is None:
             return
         data = bytes(message, "utf-8")
@@ -244,7 +273,7 @@ class ConnectionManager:
         return str(self.sockets[address].recv(self._buffer_size), "utf-8")
 
     def get_current_addresses(self):
-        return self.connectedIPs.keys()
+        return list(self.connectedIPs.keys())
 
     def get_current_master(self):
         return self.master_addr
@@ -252,8 +281,8 @@ class ConnectionManager:
     def get_ip(self):
         return self._addr, self._port
 
-    def add_listener(self, function):
+    def add_listener(self, function: Callable[[str], Optional[str]]):
         self.listeners.append(function)
 
-    def remove_listener(self, function):
+    def remove_listener(self, function: Callable[[str], Optional[str]]):
         self.listeners.remove(function)
diff --git a/tests/test_connection_manager.py b/tests/test_connection_manager.py
index 82f115e..86a11e7 100644
--- a/tests/test_connection_manager.py
+++ b/tests/test_connection_manager.py
@@ -3,6 +3,7 @@ from unittest import TestCase
 import swarm
 import random
 from ipaddress import IPv4Address, IPv6Address
+from typing import List
 
 
 class TestStatics(TestCase):
@@ -30,22 +31,19 @@ class TestStatics(TestCase):
 
 
 class TestConnections(TestCase):
-    # ConnectionManager cannot be tested (test client hangs after successful test)
-    """
     def test_initial_connection(self):
         conn_mans = {}
         for i in range(10):
-            conn_mans[("localhost", 1000 + i)] = swarm.ConnectionManager(port=1000 + i)
+            conn_mans[("localhost:" + str(1000 + i))] = swarm.ConnectionManager(port=1000 + i)
 
         addresses = list(conn_mans.keys())
-        to_connect = []
+        to_connect: List[str] = []
         for i in range(len(addresses)):
             conn_mans[addresses[i]].connect(to_connect)
             to_connect.append(addresses[i])
 
-        master = addresses[0]
+        master = swarm._string_to_ip_and_port(addresses[0])
         for manager in conn_mans.values():
             self.assertEqual(master, manager.get_current_master())
         for manager in conn_mans.values():
             manager.disconnect()
-    """
-- 
GitLab