From bcfdaa48e6e7b73f295fe5042de0cf8e9a608dc8 Mon Sep 17 00:00:00 2001
From: percyjw-2 <joris.wachsmuth@gmx.de>
Date: Sun, 17 Jul 2022 22:27:07 +0200
Subject: [PATCH] implemented ConnectionManager

---
 .gitignore                       |   3 +-
 swarm/ConnectionManager.py       | 179 ++++++++++++++++++++++---------
 swarm/__init__.py                |   7 +-
 tests/test_connection_manager.py |   9 +-
 4 files changed, 139 insertions(+), 59 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1382b14..7b6ccd6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,5 @@
 .tox
 .pytest_cache
 swarm.egg-info
-venv
\ No newline at end of file
+venv
+build
\ No newline at end of file
diff --git a/swarm/ConnectionManager.py b/swarm/ConnectionManager.py
index cfcb67b..a724bb8 100644
--- a/swarm/ConnectionManager.py
+++ b/swarm/ConnectionManager.py
@@ -4,11 +4,7 @@ import time
 import threading
 from enum import Enum
 import re
-
-
-class NetworkAddressType(Enum):
-    IPV6 = 0
-    IPV4 = 1
+from functools import partial
 
 
 class StandardMessages(Enum):
@@ -17,7 +13,6 @@ class StandardMessages(Enum):
     HEARTBEAT = "heartbeat"
     GET_ADDRESSES = "addresses"
     GET_MASTER = "master"
-    GET_LOGIN_TIME = "login_time"
 
 
 class MessageToBig(Exception):
@@ -28,6 +23,77 @@ class InvalidIPString(Exception):
     pass
 
 
+class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
+    def __init__(self, connection_manager, *args, **kwargs):
+        self.connection_manager = connection_manager
+        super().__init__(*args, **kwargs)
+
+    # match statement is not available in Python 3.6 :(
+    def announce(self, launch_time, addr):
+        address_parsed = _string_to_ip_and_port(addr)
+        self.connection_manager.sockets[address_parsed] = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.connection_manager.sockets[address_parsed].connect(address_parsed)
+        self.connection_manager.connectedIPs[address_parsed] = int(launch_time)
+        self.send_message(str(self.connection_manager.creation_time))
+
+    def heartbeat(self):
+        self.send_message(StandardMessages.ACKNOWLEDGED.value)
+
+    def get_addresses(self):
+        address_string =\
+            ",".join([addr[0] + ":" + str(addr[1]) for addr in self.connection_manager.connectedIPs.keys()])
+        self.send_message(address_string)
+
+    def get_master(self):
+        master_addr = self.connection_manager.master_addr
+        addr_str = master_addr[0] + ":" + str(master_addr[1])
+        self.send_message(addr_str)
+
+    def default_case(self, message):
+        for func in self.connection_manager.listeners:
+            return_msg = func(message)
+            if return_msg is not None:
+                self.send_message(return_msg)
+
+    def handle(self):
+        while not self.connection_manager.stop_socketserver:
+            msg_recvd = str(self.request.recv(self.connection_manager.buffer_size), "utf-8").lower()
+            if not msg_recvd:
+                break
+            msg_split = msg_recvd.split(":")
+            cmd = msg_split[0]
+            msg = ":".join(msg_split[1:len(msg_split)])
+            msg_args = msg.split(",")
+
+            if cmd == StandardMessages.ANNOUNCE.value:
+                self.announce(*msg_args)
+            elif cmd == StandardMessages.HEARTBEAT.value:
+                self.heartbeat()
+            elif cmd == StandardMessages.GET_MASTER.value:
+                self.get_master()
+            elif cmd == StandardMessages.GET_ADDRESSES.value:
+                self.get_addresses()
+            else:
+                self.default_case(msg_recvd)
+
+    def send_message(self, message):
+        message = str(message).encode("utf-8")
+        self.request.sendall(message)
+
+
+def _string_to_ip_and_port(message):
+    valid_ipv4 = re.compile(r"^(\d?\d?\d.){3}\d?\d?\d:(\d?){4}\d$")
+    valid_ipv6 = re.compile(r"^([a-f\d:]+:+)+[a-f\d]+:(\d?){4}\d$")
+    valid_address = re.compile(r"^(localhost)|(\*+.\*):(\d?){4}\d$")
+    if (not valid_ipv4.match(message)) and (not valid_ipv6.match(message)) and (not valid_address.match(message)):
+        raise InvalidIPString(f"'{message}' is not an valid ip address")
+    msg_split = message.split(":")
+    port = msg_split[-1]
+    ip = ":".join(msg_split[0:-1])
+    port = int(port)
+    return ip, port
+
+
 class ConnectionManager:
     def __init__(self,
                  addr="localhost",
@@ -41,93 +107,100 @@ class ConnectionManager:
 
         self.creation_time = time.time_ns()
 
-        self.socket = None
+        self.sockets = {}
         self.connectedIPs = {}
         self.socketServerThread = None
         self.heartbeatThread = None
-        self.master_addr = None
-
-    @staticmethod
-    def _launch_socket_server(address, request_handler=socketserver.BaseRequestHandler):
-        with socketserver.UDPServer(address, request_handler) as server:
+        self.master_addr = (addr, port)
+        self.stop_heartbeat = False
+        self.stop_socketserver = False
+        self.socketServer = None
+
+    def __del__(self):
+        for sock in self.sockets.values():
+            sock.close()
+
+    def _launch_socket_server(self, address, request_handler=socketserver.BaseRequestHandler):
+        with socketserver.ThreadingTCPServer(address, request_handler) as server:
+            self.socketServer = server
             server.serve_forever()
 
     def _launch_heartbeat(self):
-        heartbeat_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        while True:
+        while not self.stop_heartbeat:
             to_be_deleted = []
             for address in self.connectedIPs.keys():
-                heartbeat_socket.sendto(bytes(StandardMessages.HEARTBEAT.value, "utf-8"), address)
                 try:
-                    received = str(heartbeat_socket.recv(self.buffer_size), "utf-8")
+                    self.sockets[address].sendall(bytes(StandardMessages.HEARTBEAT.value, "utf-8"))
+                    received = str(self.sockets[address].recv(self.buffer_size), "utf-8")
                     if received.lower() != StandardMessages.ACKNOWLEDGED.value:
                         to_be_deleted.append(address)
-                except ConnectionResetError:
+                except ConnectionAbortedError:
                     to_be_deleted.append(address)
             for address in to_be_deleted:
                 self.connectedIPs.pop(address)
+                self.sockets[address].close()
+                self.sockets.pop(address)
             if self.master_addr not in self.connectedIPs.keys() and self.master_addr != (self.addr, self.port):
-                self.master_addr = max(self.connectedIPs, key=self.connectedIPs.get)
+                if len(self.connectedIPs) > 0:
+                    self.master_addr = (self.addr, self.port)
+                else:
+                    self.master_addr = max(self.connectedIPs, key=self.connectedIPs.get)
             time.sleep(self.heartbeat)
 
     def _connect_to_clients(self, ip_list):
         for (ip, port) in ip_list:
-            self.socket.sendto(bytes(StandardMessages.ANNOUNCE.value, "utf-8"), (ip, port))
+            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
             try:
-                received = str(self.socket.recv(self.buffer_size), "utf-8")
-                self.connectedIPs[(ip, port)] = received
-            except ConnectionResetError:
+                sock.connect((ip, port))
+                self.sockets[(ip, port)] = sock
+                sock.sendall(bytes(
+                    StandardMessages.ANNOUNCE.value + ":" +
+                    str(self.creation_time) + "," + self.addr + ":" + str(self.port),
+                    "utf-8"
+                ))
+                self.connectedIPs[(ip, port)] = str(sock.recv(self.buffer_size), "utf-8")
+            except ConnectionRefusedError:
                 pass
 
-    @staticmethod
-    def _string_to_ip_and_port(message):
-        valid_ipv4 = re.compile(r"^(\d?\d?\d.){3}\d?\d?\d:(\d?){4}\d$")
-        valid_ipv6 = re.compile(r"^([a-f\d:]+:+)+[a-f\d]+(\d?){4}\d$")
-        if (not valid_ipv4.match(message)) and (not valid_ipv6.match(message)):
-            raise InvalidIPString
-        msg_split = message.split(":")
-        port = msg_split[-1]
-        ip = message.replace(":" + port, "")
-        port = int(port)
-        return ip, port
-
     def connect(self, ip_list):
-        if self.socket is not None:
+        if len(self.sockets) > 0:
             return
-        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-        self._connect_to_clients(ip_list)
-        if len(self.connectedIPs) > 0:
-            received_msg = self.send_message(StandardMessages.GET_MASTER, next(iter(self.connectedIPs.keys())))
-            self.master_addr = self._string_to_ip_and_port(received_msg)
-            received_msg = self.send_message(StandardMessages.GET_ADDRESSES, self.master_addr)
-            network_ips = [self._string_to_ip_and_port(addr) for addr in received_msg.split(",")]
-            network_ips = [addr for addr in network_ips if addr not in self.connectedIPs]
-            self._connect_to_clients(network_ips)
-        else:
-            self.master_addr = (self.addr, self.port)
-        self.socketServerThread = threading.Thread(target=self._launch_socket_server, args=((self.addr, self.port),))
+        request_handler = partial(ConnectionManagerTCPHandler, self)
+        self.socketServerThread =\
+            threading.Thread(target=self._launch_socket_server, args=((self.addr, self.port), request_handler))
         self.socketServerThread.daemon = True
         self.socketServerThread.start()
         self.heartbeatThread = threading.Thread(target=self._launch_heartbeat)
         self.heartbeatThread.daemon = True
         self.heartbeatThread.start()
+        self._connect_to_clients(ip_list)
+        if len(self.connectedIPs) > 0:
+            received_msg = self.send_message(StandardMessages.GET_MASTER.value, next(iter(self.connectedIPs.keys())))
+            self.master_addr = _string_to_ip_and_port(received_msg)
+            received_msg = self.send_message(StandardMessages.GET_ADDRESSES.value, self.master_addr)
+            network_ips = [_string_to_ip_and_port(addr) for addr in received_msg.split(",") if addr != ""]
+            network_ips = [addr for addr in network_ips if addr not in self.connectedIPs]
+            network_ips = [addr for addr in network_ips if addr != (self.addr, self.port)]
+            self._connect_to_clients(network_ips)
 
     def disconnect(self):
-        if self.socket is None:
+        if len(self.sockets) == 0:
             return
         if self.socketServerThread.is_alive():
-            self.socketServerThread.terminate()
+            self.stop_socketserver = True
+            self.socketServer.shutdown()
         if self.heartbeatThread.is_alive():
-            self.heartbeatThread.terminate()
+            self.stop_heartbeat = True
 
     def send_message(self, message, address):
-        if self.socket is None:
+        print(f"{message} is sent to {address}")
+        if self.sockets.get(address) is None:
             return
         data = bytes(message, "utf-8")
         if len(data) > self.buffer_size:
             raise MessageToBig
-        self.socket.sendto(data, address)
-        return str(self.socket.recv(self.buffer_size), "utf-8")
+        self.sockets[address].sendall(data)
+        return str(self.sockets[address].recv(self.buffer_size), "utf-8")
 
     def get_current_addresses(self):
         return self.connectedIPs.keys()
diff --git a/swarm/__init__.py b/swarm/__init__.py
index 278f67b..5f7d96d 100644
--- a/swarm/__init__.py
+++ b/swarm/__init__.py
@@ -5,4 +5,9 @@ Python library
 __version__ = "0.0.1"
 __author__ = 'Joris Wachsmuth'
 
-from .ConnectionManager import *
+from .ConnectionManager import\
+    ConnectionManager,\
+    StandardMessages,\
+    InvalidIPString,\
+    MessageToBig,\
+    _string_to_ip_and_port
diff --git a/tests/test_connection_manager.py b/tests/test_connection_manager.py
index b749fcb..eda6b6b 100644
--- a/tests/test_connection_manager.py
+++ b/tests/test_connection_manager.py
@@ -2,28 +2,29 @@ from unittest import TestCase
 
 import swarm
 import random
+import time
 from ipaddress import IPv4Address, IPv6Address
 
 
-class Test(TestCase):
+class TestStatics(TestCase):
     def test_ipv4_str_parsing(self):
         for i in range(1000):
             addr_str = str(IPv4Address(random.getrandbits(32)))
             port = random.randint(1, 65535)
-            (ip, port_e) = swarm.ConnectionManager._string_to_ip_and_port(addr_str + ":" + str(port))
+            (ip, port_e) = swarm._string_to_ip_and_port(addr_str + ":" + str(port))
             self.assertEqual((ip, port_e), (addr_str, port))
 
     def test_ipv6_str_parsing(self):
         for i in range(1000):
             addr_str = str(IPv6Address(random.getrandbits(128)))
             port = random.randint(1, 65535)
-            (ip, port_e) = swarm.ConnectionManager._string_to_ip_and_port(addr_str + ":" + str(port))
+            (ip, port_e) = swarm._string_to_ip_and_port(addr_str + ":" + str(port))
             self.assertEqual((ip, port_e), (addr_str, port))
 
     def test_invalid_str_parsing(self):
         invalid_sting = "invalid"
         try:
-            swarm.ConnectionManager._string_to_ip_and_port(invalid_sting)
+            swarm._string_to_ip_and_port(invalid_sting)
         except swarm.InvalidIPString:
             return
         self.assertTrue(False)
-- 
GitLab