diff --git a/.gitignore b/.gitignore index 1382b14f3fcd802a5eeddd5e12ebaf1b26e0f592..7b6ccd6d05e59aee279f96eecd99825c87b14f98 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ .tox .pytest_cache swarm.egg-info -venv \ No newline at end of file +venv +build \ No newline at end of file diff --git a/swarm/ConnectionManager.py b/swarm/ConnectionManager.py index cfcb67b413433485bbeb63732835c1a0850ddb5f..a724bb8237c6172c68165b47d97817a45d2ab27f 100644 --- a/swarm/ConnectionManager.py +++ b/swarm/ConnectionManager.py @@ -4,11 +4,7 @@ import time import threading from enum import Enum import re - - -class NetworkAddressType(Enum): - IPV6 = 0 - IPV4 = 1 +from functools import partial class StandardMessages(Enum): @@ -17,7 +13,6 @@ class StandardMessages(Enum): HEARTBEAT = "heartbeat" GET_ADDRESSES = "addresses" GET_MASTER = "master" - GET_LOGIN_TIME = "login_time" class MessageToBig(Exception): @@ -28,6 +23,77 @@ class InvalidIPString(Exception): pass +class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler): + def __init__(self, connection_manager, *args, **kwargs): + self.connection_manager = connection_manager + super().__init__(*args, **kwargs) + + # match statement is not available in Python 3.6 :( + def announce(self, launch_time, addr): + address_parsed = _string_to_ip_and_port(addr) + self.connection_manager.sockets[address_parsed] = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.connection_manager.sockets[address_parsed].connect(address_parsed) + self.connection_manager.connectedIPs[address_parsed] = int(launch_time) + self.send_message(str(self.connection_manager.creation_time)) + + def heartbeat(self): + self.send_message(StandardMessages.ACKNOWLEDGED.value) + + def get_addresses(self): + address_string =\ + ",".join([addr[0] + ":" + str(addr[1]) for addr in self.connection_manager.connectedIPs.keys()]) + self.send_message(address_string) + + def get_master(self): + master_addr = self.connection_manager.master_addr + addr_str = master_addr[0] + ":" + str(master_addr[1]) + self.send_message(addr_str) + + def default_case(self, message): + for func in self.connection_manager.listeners: + return_msg = func(message) + if return_msg is not None: + self.send_message(return_msg) + + def handle(self): + while not self.connection_manager.stop_socketserver: + msg_recvd = str(self.request.recv(self.connection_manager.buffer_size), "utf-8").lower() + if not msg_recvd: + break + msg_split = msg_recvd.split(":") + cmd = msg_split[0] + msg = ":".join(msg_split[1:len(msg_split)]) + msg_args = msg.split(",") + + if cmd == StandardMessages.ANNOUNCE.value: + self.announce(*msg_args) + elif cmd == StandardMessages.HEARTBEAT.value: + self.heartbeat() + elif cmd == StandardMessages.GET_MASTER.value: + self.get_master() + elif cmd == StandardMessages.GET_ADDRESSES.value: + self.get_addresses() + else: + self.default_case(msg_recvd) + + def send_message(self, message): + message = str(message).encode("utf-8") + self.request.sendall(message) + + +def _string_to_ip_and_port(message): + valid_ipv4 = re.compile(r"^(\d?\d?\d.){3}\d?\d?\d:(\d?){4}\d$") + valid_ipv6 = re.compile(r"^([a-f\d:]+:+)+[a-f\d]+:(\d?){4}\d$") + valid_address = re.compile(r"^(localhost)|(\*+.\*):(\d?){4}\d$") + if (not valid_ipv4.match(message)) and (not valid_ipv6.match(message)) and (not valid_address.match(message)): + raise InvalidIPString(f"'{message}' is not an valid ip address") + msg_split = message.split(":") + port = msg_split[-1] + ip = ":".join(msg_split[0:-1]) + port = int(port) + return ip, port + + class ConnectionManager: def __init__(self, addr="localhost", @@ -41,93 +107,100 @@ class ConnectionManager: self.creation_time = time.time_ns() - self.socket = None + self.sockets = {} self.connectedIPs = {} self.socketServerThread = None self.heartbeatThread = None - self.master_addr = None - - @staticmethod - def _launch_socket_server(address, request_handler=socketserver.BaseRequestHandler): - with socketserver.UDPServer(address, request_handler) as server: + self.master_addr = (addr, port) + self.stop_heartbeat = False + self.stop_socketserver = False + self.socketServer = None + + def __del__(self): + for sock in self.sockets.values(): + sock.close() + + def _launch_socket_server(self, address, request_handler=socketserver.BaseRequestHandler): + with socketserver.ThreadingTCPServer(address, request_handler) as server: + self.socketServer = server server.serve_forever() def _launch_heartbeat(self): - heartbeat_socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - while True: + while not self.stop_heartbeat: to_be_deleted = [] for address in self.connectedIPs.keys(): - heartbeat_socket.sendto(bytes(StandardMessages.HEARTBEAT.value, "utf-8"), address) try: - received = str(heartbeat_socket.recv(self.buffer_size), "utf-8") + self.sockets[address].sendall(bytes(StandardMessages.HEARTBEAT.value, "utf-8")) + received = str(self.sockets[address].recv(self.buffer_size), "utf-8") if received.lower() != StandardMessages.ACKNOWLEDGED.value: to_be_deleted.append(address) - except ConnectionResetError: + except ConnectionAbortedError: to_be_deleted.append(address) for address in to_be_deleted: self.connectedIPs.pop(address) + self.sockets[address].close() + self.sockets.pop(address) if self.master_addr not in self.connectedIPs.keys() and self.master_addr != (self.addr, self.port): - self.master_addr = max(self.connectedIPs, key=self.connectedIPs.get) + if len(self.connectedIPs) > 0: + self.master_addr = (self.addr, self.port) + else: + self.master_addr = max(self.connectedIPs, key=self.connectedIPs.get) time.sleep(self.heartbeat) def _connect_to_clients(self, ip_list): for (ip, port) in ip_list: - self.socket.sendto(bytes(StandardMessages.ANNOUNCE.value, "utf-8"), (ip, port)) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: - received = str(self.socket.recv(self.buffer_size), "utf-8") - self.connectedIPs[(ip, port)] = received - except ConnectionResetError: + sock.connect((ip, port)) + self.sockets[(ip, port)] = sock + sock.sendall(bytes( + StandardMessages.ANNOUNCE.value + ":" + + str(self.creation_time) + "," + self.addr + ":" + str(self.port), + "utf-8" + )) + self.connectedIPs[(ip, port)] = str(sock.recv(self.buffer_size), "utf-8") + except ConnectionRefusedError: pass - @staticmethod - def _string_to_ip_and_port(message): - valid_ipv4 = re.compile(r"^(\d?\d?\d.){3}\d?\d?\d:(\d?){4}\d$") - valid_ipv6 = re.compile(r"^([a-f\d:]+:+)+[a-f\d]+(\d?){4}\d$") - if (not valid_ipv4.match(message)) and (not valid_ipv6.match(message)): - raise InvalidIPString - msg_split = message.split(":") - port = msg_split[-1] - ip = message.replace(":" + port, "") - port = int(port) - return ip, port - def connect(self, ip_list): - if self.socket is not None: + if len(self.sockets) > 0: return - self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - self._connect_to_clients(ip_list) - if len(self.connectedIPs) > 0: - received_msg = self.send_message(StandardMessages.GET_MASTER, next(iter(self.connectedIPs.keys()))) - self.master_addr = self._string_to_ip_and_port(received_msg) - received_msg = self.send_message(StandardMessages.GET_ADDRESSES, self.master_addr) - network_ips = [self._string_to_ip_and_port(addr) for addr in received_msg.split(",")] - network_ips = [addr for addr in network_ips if addr not in self.connectedIPs] - self._connect_to_clients(network_ips) - else: - self.master_addr = (self.addr, self.port) - self.socketServerThread = threading.Thread(target=self._launch_socket_server, args=((self.addr, self.port),)) + request_handler = partial(ConnectionManagerTCPHandler, self) + self.socketServerThread =\ + threading.Thread(target=self._launch_socket_server, args=((self.addr, self.port), request_handler)) self.socketServerThread.daemon = True self.socketServerThread.start() self.heartbeatThread = threading.Thread(target=self._launch_heartbeat) self.heartbeatThread.daemon = True self.heartbeatThread.start() + self._connect_to_clients(ip_list) + if len(self.connectedIPs) > 0: + received_msg = self.send_message(StandardMessages.GET_MASTER.value, next(iter(self.connectedIPs.keys()))) + self.master_addr = _string_to_ip_and_port(received_msg) + received_msg = self.send_message(StandardMessages.GET_ADDRESSES.value, self.master_addr) + network_ips = [_string_to_ip_and_port(addr) for addr in received_msg.split(",") if addr != ""] + network_ips = [addr for addr in network_ips if addr not in self.connectedIPs] + network_ips = [addr for addr in network_ips if addr != (self.addr, self.port)] + self._connect_to_clients(network_ips) def disconnect(self): - if self.socket is None: + if len(self.sockets) == 0: return if self.socketServerThread.is_alive(): - self.socketServerThread.terminate() + self.stop_socketserver = True + self.socketServer.shutdown() if self.heartbeatThread.is_alive(): - self.heartbeatThread.terminate() + self.stop_heartbeat = True def send_message(self, message, address): - if self.socket is None: + print(f"{message} is sent to {address}") + if self.sockets.get(address) is None: return data = bytes(message, "utf-8") if len(data) > self.buffer_size: raise MessageToBig - self.socket.sendto(data, address) - return str(self.socket.recv(self.buffer_size), "utf-8") + self.sockets[address].sendall(data) + return str(self.sockets[address].recv(self.buffer_size), "utf-8") def get_current_addresses(self): return self.connectedIPs.keys() diff --git a/swarm/__init__.py b/swarm/__init__.py index 278f67bbad2971ce21d54706a5f18265d09d8e21..5f7d96d02195942168f424a4a82e6db60b4396ab 100644 --- a/swarm/__init__.py +++ b/swarm/__init__.py @@ -5,4 +5,9 @@ Python library __version__ = "0.0.1" __author__ = 'Joris Wachsmuth' -from .ConnectionManager import * +from .ConnectionManager import\ + ConnectionManager,\ + StandardMessages,\ + InvalidIPString,\ + MessageToBig,\ + _string_to_ip_and_port diff --git a/tests/test_connection_manager.py b/tests/test_connection_manager.py index b749fcb2b93cc53426afa4bf9cb8927bbc630c87..eda6b6bc8f471bd21a25ae23a18ed40dde19f020 100644 --- a/tests/test_connection_manager.py +++ b/tests/test_connection_manager.py @@ -2,28 +2,29 @@ from unittest import TestCase import swarm import random +import time from ipaddress import IPv4Address, IPv6Address -class Test(TestCase): +class TestStatics(TestCase): def test_ipv4_str_parsing(self): for i in range(1000): addr_str = str(IPv4Address(random.getrandbits(32))) port = random.randint(1, 65535) - (ip, port_e) = swarm.ConnectionManager._string_to_ip_and_port(addr_str + ":" + str(port)) + (ip, port_e) = swarm._string_to_ip_and_port(addr_str + ":" + str(port)) self.assertEqual((ip, port_e), (addr_str, port)) def test_ipv6_str_parsing(self): for i in range(1000): addr_str = str(IPv6Address(random.getrandbits(128))) port = random.randint(1, 65535) - (ip, port_e) = swarm.ConnectionManager._string_to_ip_and_port(addr_str + ":" + str(port)) + (ip, port_e) = swarm._string_to_ip_and_port(addr_str + ":" + str(port)) self.assertEqual((ip, port_e), (addr_str, port)) def test_invalid_str_parsing(self): invalid_sting = "invalid" try: - swarm.ConnectionManager._string_to_ip_and_port(invalid_sting) + swarm._string_to_ip_and_port(invalid_sting) except swarm.InvalidIPString: return self.assertTrue(False)