Skip to content
Snippets Groups Projects
Commit ecd0f26c authored by Joris Wachsmuth's avatar Joris Wachsmuth
Browse files

Added Example and refinded master selection

parent bcfdaa48
No related branches found
No related tags found
No related merge requests found
Pipeline #20039 passed
#!/usr/bin/python3
from swarm import ConnectionManager
import sys
from time import sleep
import yaml
def parse_ip(ip_str):
split_str = ip_str.split(":")
port = split_str[-1]
ip = ":".join(split_str[0:-1])
port = int(port)
return ip, port
if __name__ == "__main__":
if len(sys.argv) != 2:
exit(-1)
with open("ip_list.yaml", "r") as stream:
try:
data = yaml.safe_load(stream)
ips = data['ips']
index = int(sys.argv[1])
if index < 0 or index >= len(ips):
exit(-2)
ip_source, port_source = parse_ip(ips[index])
connMan = ConnectionManager(addr=ip_source, port=port_source)
connMan.connect(list(ips))
try:
while True:
print("\033[H\033[J", end="")
print(f"current master: {connMan.get_current_master()}")
print(f"Is Client Master? {connMan.get_current_master() == ips[index]}")
sleep(1)
except KeyboardInterrupt:
connMan.disconnect()
except yaml.YAMLError as exc:
print(exc)
ips:
- localhost:1000
- localhost:1001
- localhost:1002
- localhost:1003
#!/usr/bin/env bash
python -m venv ./venv
source ./venv/bin/activate
pip install --upgrade setuptools wheel pyyaml
pip install ../
...@@ -6,9 +6,19 @@ from enum import Enum ...@@ -6,9 +6,19 @@ from enum import Enum
import re import re
from functools import partial from functools import partial
try:
from time import time_ns
except ImportError:
from datetime import datetime
def time_ns():
now = datetime.now()
return int(now.timestamp() * 1e9)
class StandardMessages(Enum): class StandardMessages(Enum):
ANNOUNCE = "announce" ANNOUNCE = "announce"
UPDATE_LAUNCH = "update"
ACKNOWLEDGED = "acknowledged" ACKNOWLEDGED = "acknowledged"
HEARTBEAT = "heartbeat" HEARTBEAT = "heartbeat"
GET_ADDRESSES = "addresses" GET_ADDRESSES = "addresses"
...@@ -33,8 +43,21 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler): ...@@ -33,8 +43,21 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
address_parsed = _string_to_ip_and_port(addr) address_parsed = _string_to_ip_and_port(addr)
self.connection_manager.sockets[address_parsed] = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.connection_manager.sockets[address_parsed] = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.connection_manager.sockets[address_parsed].connect(address_parsed) self.connection_manager.sockets[address_parsed].connect(address_parsed)
announced_launch_time = int(launch_time)
updated_launch_time = False
while announced_launch_time in self.connection_manager.connectedIPs.values():
updated_launch_time = True
announced_launch_time = announced_launch_time + 1
self.connection_manager.connectedIPs[address_parsed] = announced_launch_time
updated_time = ""
if updated_launch_time:
updated_time = "," + str(announced_launch_time)
self.send_message(str(self.connection_manager.creation_time) + updated_time)
def update_launch_time(self, launch_time, addr):
address_parsed = _string_to_ip_and_port(addr)
self.connection_manager.connectedIPs[address_parsed] = int(launch_time) self.connection_manager.connectedIPs[address_parsed] = int(launch_time)
self.send_message(str(self.connection_manager.creation_time)) self.send_message(StandardMessages.ACKNOWLEDGED.value)
def heartbeat(self): def heartbeat(self):
self.send_message(StandardMessages.ACKNOWLEDGED.value) self.send_message(StandardMessages.ACKNOWLEDGED.value)
...@@ -57,7 +80,7 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler): ...@@ -57,7 +80,7 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
def handle(self): def handle(self):
while not self.connection_manager.stop_socketserver: while not self.connection_manager.stop_socketserver:
msg_recvd = str(self.request.recv(self.connection_manager.buffer_size), "utf-8").lower() msg_recvd = str(self.request.recv(self.connection_manager._buffer_size), "utf-8").lower()
if not msg_recvd: if not msg_recvd:
break break
msg_split = msg_recvd.split(":") msg_split = msg_recvd.split(":")
...@@ -67,6 +90,8 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler): ...@@ -67,6 +90,8 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
if cmd == StandardMessages.ANNOUNCE.value: if cmd == StandardMessages.ANNOUNCE.value:
self.announce(*msg_args) self.announce(*msg_args)
elif cmd == StandardMessages.UPDATE_LAUNCH.value:
self.update_launch_time(*msg_args)
elif cmd == StandardMessages.HEARTBEAT.value: elif cmd == StandardMessages.HEARTBEAT.value:
self.heartbeat() self.heartbeat()
elif cmd == StandardMessages.GET_MASTER.value: elif cmd == StandardMessages.GET_MASTER.value:
...@@ -100,12 +125,12 @@ class ConnectionManager: ...@@ -100,12 +125,12 @@ class ConnectionManager:
port=6969, port=6969,
heartbeat=1, heartbeat=1,
buffer_size=1024): buffer_size=1024):
self.addr = addr self._addr = addr
self.port = port self._port = port
self.heartbeat = heartbeat self._heartbeat = heartbeat
self.buffer_size = buffer_size self._buffer_size = buffer_size
self.creation_time = time.time_ns() self.creation_time = time_ns()
self.sockets = {} self.sockets = {}
self.connectedIPs = {} self.connectedIPs = {}
...@@ -115,8 +140,10 @@ class ConnectionManager: ...@@ -115,8 +140,10 @@ class ConnectionManager:
self.stop_heartbeat = False self.stop_heartbeat = False
self.stop_socketserver = False self.stop_socketserver = False
self.socketServer = None self.socketServer = None
self.listeners = []
def __del__(self): def __del__(self):
self.disconnect()
for sock in self.sockets.values(): for sock in self.sockets.values():
sock.close() sock.close()
...@@ -131,7 +158,7 @@ class ConnectionManager: ...@@ -131,7 +158,7 @@ class ConnectionManager:
for address in self.connectedIPs.keys(): for address in self.connectedIPs.keys():
try: try:
self.sockets[address].sendall(bytes(StandardMessages.HEARTBEAT.value, "utf-8")) self.sockets[address].sendall(bytes(StandardMessages.HEARTBEAT.value, "utf-8"))
received = str(self.sockets[address].recv(self.buffer_size), "utf-8") received = str(self.sockets[address].recv(self._buffer_size), "utf-8")
if received.lower() != StandardMessages.ACKNOWLEDGED.value: if received.lower() != StandardMessages.ACKNOWLEDGED.value:
to_be_deleted.append(address) to_be_deleted.append(address)
except ConnectionAbortedError: except ConnectionAbortedError:
...@@ -140,14 +167,15 @@ class ConnectionManager: ...@@ -140,14 +167,15 @@ class ConnectionManager:
self.connectedIPs.pop(address) self.connectedIPs.pop(address)
self.sockets[address].close() self.sockets[address].close()
self.sockets.pop(address) self.sockets.pop(address)
if self.master_addr not in self.connectedIPs.keys() and self.master_addr != (self.addr, self.port): if self.master_addr not in self.connectedIPs.keys() and self.master_addr != (self._addr, self._port):
if len(self.connectedIPs) > 0: if len(self.connectedIPs) > 0:
self.master_addr = (self.addr, self.port) self.master_addr = (self._addr, self._port)
else: else:
self.master_addr = max(self.connectedIPs, key=self.connectedIPs.get) self.master_addr = min(self.connectedIPs, key=self.connectedIPs.get)
time.sleep(self.heartbeat) time.sleep(self._heartbeat)
def _connect_to_clients(self, ip_list): def _connect_to_clients(self, ip_list):
changed_start_time = False
for (ip, port) in ip_list: for (ip, port) in ip_list:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try: try:
...@@ -155,24 +183,38 @@ class ConnectionManager: ...@@ -155,24 +183,38 @@ class ConnectionManager:
self.sockets[(ip, port)] = sock self.sockets[(ip, port)] = sock
sock.sendall(bytes( sock.sendall(bytes(
StandardMessages.ANNOUNCE.value + ":" + StandardMessages.ANNOUNCE.value + ":" +
str(self.creation_time) + "," + self.addr + ":" + str(self.port), str(self.creation_time) + "," + self._addr + ":" + str(self._port),
"utf-8" "utf-8"
)) ))
self.connectedIPs[(ip, port)] = str(sock.recv(self.buffer_size), "utf-8") recv_msg = str(sock.recv(self._buffer_size), "utf-8")
recv_msgs = recv_msg.split(",")
self.connectedIPs[(ip, port)] = recv_msgs[0]
if len(recv_msgs) > 1:
changed_start_time = True
self.creation_time = recv_msgs[1]
except ConnectionRefusedError: except ConnectionRefusedError:
pass pass
if changed_start_time:
for sock in self.sockets.values():
sock.sendall(bytes(
StandardMessages.UPDATE_LAUNCH.value + ":" +
str(self.creation_time) + "," + self._addr + ":" + str(self._port),
"utf-8"
))
def connect(self, ip_list): def connect(self, ip_list):
if len(self.sockets) > 0: if len(self.sockets) > 0:
return return
request_handler = partial(ConnectionManagerTCPHandler, self) request_handler = partial(ConnectionManagerTCPHandler, self)
self.socketServerThread =\ self.socketServerThread =\
threading.Thread(target=self._launch_socket_server, args=((self.addr, self.port), request_handler)) threading.Thread(target=self._launch_socket_server, args=((self._addr, self._port), request_handler))
self.socketServerThread.daemon = True self.socketServerThread.daemon = True
self.socketServerThread.start() self.socketServerThread.start()
self.heartbeatThread = threading.Thread(target=self._launch_heartbeat) self.heartbeatThread = threading.Thread(target=self._launch_heartbeat)
self.heartbeatThread.daemon = True self.heartbeatThread.daemon = True
self.heartbeatThread.start() self.heartbeatThread.start()
ip_list = [_string_to_ip_and_port(addr) for addr in ip_list]
self._connect_to_clients(ip_list) self._connect_to_clients(ip_list)
if len(self.connectedIPs) > 0: if len(self.connectedIPs) > 0:
received_msg = self.send_message(StandardMessages.GET_MASTER.value, next(iter(self.connectedIPs.keys()))) received_msg = self.send_message(StandardMessages.GET_MASTER.value, next(iter(self.connectedIPs.keys())))
...@@ -180,27 +222,26 @@ class ConnectionManager: ...@@ -180,27 +222,26 @@ class ConnectionManager:
received_msg = self.send_message(StandardMessages.GET_ADDRESSES.value, self.master_addr) received_msg = self.send_message(StandardMessages.GET_ADDRESSES.value, self.master_addr)
network_ips = [_string_to_ip_and_port(addr) for addr in received_msg.split(",") if addr != ""] network_ips = [_string_to_ip_and_port(addr) for addr in received_msg.split(",") if addr != ""]
network_ips = [addr for addr in network_ips if addr not in self.connectedIPs] network_ips = [addr for addr in network_ips if addr not in self.connectedIPs]
network_ips = [addr for addr in network_ips if addr != (self.addr, self.port)] network_ips = [addr for addr in network_ips if addr != (self._addr, self._port)]
self._connect_to_clients(network_ips) self._connect_to_clients(network_ips)
def disconnect(self): def disconnect(self):
if len(self.sockets) == 0: # if len(self.sockets) == 0:
return # return
if self.socketServerThread.is_alive(): if self.socketServerThread is not None:
self.stop_socketserver = True self.stop_socketserver = True
self.socketServer.shutdown() self.socketServer.shutdown()
if self.heartbeatThread.is_alive(): if self.heartbeatThread is not None:
self.stop_heartbeat = True self.stop_heartbeat = True
def send_message(self, message, address): def send_message(self, message, address):
print(f"{message} is sent to {address}")
if self.sockets.get(address) is None: if self.sockets.get(address) is None:
return return
data = bytes(message, "utf-8") data = bytes(message, "utf-8")
if len(data) > self.buffer_size: if len(data) > self._buffer_size:
raise MessageToBig raise MessageToBig
self.sockets[address].sendall(data) self.sockets[address].sendall(data)
return str(self.sockets[address].recv(self.buffer_size), "utf-8") return str(self.sockets[address].recv(self._buffer_size), "utf-8")
def get_current_addresses(self): def get_current_addresses(self):
return self.connectedIPs.keys() return self.connectedIPs.keys()
...@@ -209,4 +250,10 @@ class ConnectionManager: ...@@ -209,4 +250,10 @@ class ConnectionManager:
return self.master_addr return self.master_addr
def get_ip(self): def get_ip(self):
return self.addr, self.port return self._addr, self._port
def add_listener(self, function):
self.listeners.append(function)
def remove_listener(self, function):
self.listeners.remove(function)
...@@ -2,7 +2,6 @@ from unittest import TestCase ...@@ -2,7 +2,6 @@ from unittest import TestCase
import swarm import swarm
import random import random
import time
from ipaddress import IPv4Address, IPv6Address from ipaddress import IPv4Address, IPv6Address
...@@ -28,3 +27,25 @@ class TestStatics(TestCase): ...@@ -28,3 +27,25 @@ class TestStatics(TestCase):
except swarm.InvalidIPString: except swarm.InvalidIPString:
return return
self.assertTrue(False) self.assertTrue(False)
class TestConnections(TestCase):
# ConnectionManager cannot be tested (test client hangs after successful test)
"""
def test_initial_connection(self):
conn_mans = {}
for i in range(10):
conn_mans[("localhost", 1000 + i)] = swarm.ConnectionManager(port=1000 + i)
addresses = list(conn_mans.keys())
to_connect = []
for i in range(len(addresses)):
conn_mans[addresses[i]].connect(to_connect)
to_connect.append(addresses[i])
master = addresses[0]
for manager in conn_mans.values():
self.assertEqual(master, manager.get_current_master())
for manager in conn_mans.values():
manager.disconnect()
"""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment