Skip to content
Snippets Groups Projects
Commit ecd0f26c authored by Joris Wachsmuth's avatar Joris Wachsmuth
Browse files

Added Example and refinded master selection

parent bcfdaa48
Branches primo-legacy
No related merge requests found
Pipeline #20039 passed
#!/usr/bin/python3
from swarm import ConnectionManager
import sys
from time import sleep
import yaml
def parse_ip(ip_str):
split_str = ip_str.split(":")
port = split_str[-1]
ip = ":".join(split_str[0:-1])
port = int(port)
return ip, port
if __name__ == "__main__":
if len(sys.argv) != 2:
exit(-1)
with open("ip_list.yaml", "r") as stream:
try:
data = yaml.safe_load(stream)
ips = data['ips']
index = int(sys.argv[1])
if index < 0 or index >= len(ips):
exit(-2)
ip_source, port_source = parse_ip(ips[index])
connMan = ConnectionManager(addr=ip_source, port=port_source)
connMan.connect(list(ips))
try:
while True:
print("\033[H\033[J", end="")
print(f"current master: {connMan.get_current_master()}")
print(f"Is Client Master? {connMan.get_current_master() == ips[index]}")
sleep(1)
except KeyboardInterrupt:
connMan.disconnect()
except yaml.YAMLError as exc:
print(exc)
ips:
- localhost:1000
- localhost:1001
- localhost:1002
- localhost:1003
#!/usr/bin/env bash
python -m venv ./venv
source ./venv/bin/activate
pip install --upgrade setuptools wheel pyyaml
pip install ../
......@@ -6,9 +6,19 @@ from enum import Enum
import re
from functools import partial
try:
from time import time_ns
except ImportError:
from datetime import datetime
def time_ns():
now = datetime.now()
return int(now.timestamp() * 1e9)
class StandardMessages(Enum):
ANNOUNCE = "announce"
UPDATE_LAUNCH = "update"
ACKNOWLEDGED = "acknowledged"
HEARTBEAT = "heartbeat"
GET_ADDRESSES = "addresses"
......@@ -33,8 +43,21 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
address_parsed = _string_to_ip_and_port(addr)
self.connection_manager.sockets[address_parsed] = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.connection_manager.sockets[address_parsed].connect(address_parsed)
announced_launch_time = int(launch_time)
updated_launch_time = False
while announced_launch_time in self.connection_manager.connectedIPs.values():
updated_launch_time = True
announced_launch_time = announced_launch_time + 1
self.connection_manager.connectedIPs[address_parsed] = announced_launch_time
updated_time = ""
if updated_launch_time:
updated_time = "," + str(announced_launch_time)
self.send_message(str(self.connection_manager.creation_time) + updated_time)
def update_launch_time(self, launch_time, addr):
address_parsed = _string_to_ip_and_port(addr)
self.connection_manager.connectedIPs[address_parsed] = int(launch_time)
self.send_message(str(self.connection_manager.creation_time))
self.send_message(StandardMessages.ACKNOWLEDGED.value)
def heartbeat(self):
self.send_message(StandardMessages.ACKNOWLEDGED.value)
......@@ -57,7 +80,7 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
def handle(self):
while not self.connection_manager.stop_socketserver:
msg_recvd = str(self.request.recv(self.connection_manager.buffer_size), "utf-8").lower()
msg_recvd = str(self.request.recv(self.connection_manager._buffer_size), "utf-8").lower()
if not msg_recvd:
break
msg_split = msg_recvd.split(":")
......@@ -67,6 +90,8 @@ class ConnectionManagerTCPHandler(socketserver.BaseRequestHandler):
if cmd == StandardMessages.ANNOUNCE.value:
self.announce(*msg_args)
elif cmd == StandardMessages.UPDATE_LAUNCH.value:
self.update_launch_time(*msg_args)
elif cmd == StandardMessages.HEARTBEAT.value:
self.heartbeat()
elif cmd == StandardMessages.GET_MASTER.value:
......@@ -100,12 +125,12 @@ class ConnectionManager:
port=6969,
heartbeat=1,
buffer_size=1024):
self.addr = addr
self.port = port
self.heartbeat = heartbeat
self.buffer_size = buffer_size
self._addr = addr
self._port = port
self._heartbeat = heartbeat
self._buffer_size = buffer_size
self.creation_time = time.time_ns()
self.creation_time = time_ns()
self.sockets = {}
self.connectedIPs = {}
......@@ -115,8 +140,10 @@ class ConnectionManager:
self.stop_heartbeat = False
self.stop_socketserver = False
self.socketServer = None
self.listeners = []
def __del__(self):
self.disconnect()
for sock in self.sockets.values():
sock.close()
......@@ -131,7 +158,7 @@ class ConnectionManager:
for address in self.connectedIPs.keys():
try:
self.sockets[address].sendall(bytes(StandardMessages.HEARTBEAT.value, "utf-8"))
received = str(self.sockets[address].recv(self.buffer_size), "utf-8")
received = str(self.sockets[address].recv(self._buffer_size), "utf-8")
if received.lower() != StandardMessages.ACKNOWLEDGED.value:
to_be_deleted.append(address)
except ConnectionAbortedError:
......@@ -140,14 +167,15 @@ class ConnectionManager:
self.connectedIPs.pop(address)
self.sockets[address].close()
self.sockets.pop(address)
if self.master_addr not in self.connectedIPs.keys() and self.master_addr != (self.addr, self.port):
if self.master_addr not in self.connectedIPs.keys() and self.master_addr != (self._addr, self._port):
if len(self.connectedIPs) > 0:
self.master_addr = (self.addr, self.port)
self.master_addr = (self._addr, self._port)
else:
self.master_addr = max(self.connectedIPs, key=self.connectedIPs.get)
time.sleep(self.heartbeat)
self.master_addr = min(self.connectedIPs, key=self.connectedIPs.get)
time.sleep(self._heartbeat)
def _connect_to_clients(self, ip_list):
changed_start_time = False
for (ip, port) in ip_list:
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
try:
......@@ -155,24 +183,38 @@ class ConnectionManager:
self.sockets[(ip, port)] = sock
sock.sendall(bytes(
StandardMessages.ANNOUNCE.value + ":" +
str(self.creation_time) + "," + self.addr + ":" + str(self.port),
str(self.creation_time) + "," + self._addr + ":" + str(self._port),
"utf-8"
))
self.connectedIPs[(ip, port)] = str(sock.recv(self.buffer_size), "utf-8")
recv_msg = str(sock.recv(self._buffer_size), "utf-8")
recv_msgs = recv_msg.split(",")
self.connectedIPs[(ip, port)] = recv_msgs[0]
if len(recv_msgs) > 1:
changed_start_time = True
self.creation_time = recv_msgs[1]
except ConnectionRefusedError:
pass
if changed_start_time:
for sock in self.sockets.values():
sock.sendall(bytes(
StandardMessages.UPDATE_LAUNCH.value + ":" +
str(self.creation_time) + "," + self._addr + ":" + str(self._port),
"utf-8"
))
def connect(self, ip_list):
if len(self.sockets) > 0:
return
request_handler = partial(ConnectionManagerTCPHandler, self)
self.socketServerThread =\
threading.Thread(target=self._launch_socket_server, args=((self.addr, self.port), request_handler))
threading.Thread(target=self._launch_socket_server, args=((self._addr, self._port), request_handler))
self.socketServerThread.daemon = True
self.socketServerThread.start()
self.heartbeatThread = threading.Thread(target=self._launch_heartbeat)
self.heartbeatThread.daemon = True
self.heartbeatThread.start()
ip_list = [_string_to_ip_and_port(addr) for addr in ip_list]
self._connect_to_clients(ip_list)
if len(self.connectedIPs) > 0:
received_msg = self.send_message(StandardMessages.GET_MASTER.value, next(iter(self.connectedIPs.keys())))
......@@ -180,27 +222,26 @@ class ConnectionManager:
received_msg = self.send_message(StandardMessages.GET_ADDRESSES.value, self.master_addr)
network_ips = [_string_to_ip_and_port(addr) for addr in received_msg.split(",") if addr != ""]
network_ips = [addr for addr in network_ips if addr not in self.connectedIPs]
network_ips = [addr for addr in network_ips if addr != (self.addr, self.port)]
network_ips = [addr for addr in network_ips if addr != (self._addr, self._port)]
self._connect_to_clients(network_ips)
def disconnect(self):
if len(self.sockets) == 0:
return
if self.socketServerThread.is_alive():
# if len(self.sockets) == 0:
# return
if self.socketServerThread is not None:
self.stop_socketserver = True
self.socketServer.shutdown()
if self.heartbeatThread.is_alive():
if self.heartbeatThread is not None:
self.stop_heartbeat = True
def send_message(self, message, address):
print(f"{message} is sent to {address}")
if self.sockets.get(address) is None:
return
data = bytes(message, "utf-8")
if len(data) > self.buffer_size:
if len(data) > self._buffer_size:
raise MessageToBig
self.sockets[address].sendall(data)
return str(self.sockets[address].recv(self.buffer_size), "utf-8")
return str(self.sockets[address].recv(self._buffer_size), "utf-8")
def get_current_addresses(self):
return self.connectedIPs.keys()
......@@ -209,4 +250,10 @@ class ConnectionManager:
return self.master_addr
def get_ip(self):
return self.addr, self.port
return self._addr, self._port
def add_listener(self, function):
self.listeners.append(function)
def remove_listener(self, function):
self.listeners.remove(function)
......@@ -2,7 +2,6 @@ from unittest import TestCase
import swarm
import random
import time
from ipaddress import IPv4Address, IPv6Address
......@@ -28,3 +27,25 @@ class TestStatics(TestCase):
except swarm.InvalidIPString:
return
self.assertTrue(False)
class TestConnections(TestCase):
# ConnectionManager cannot be tested (test client hangs after successful test)
"""
def test_initial_connection(self):
conn_mans = {}
for i in range(10):
conn_mans[("localhost", 1000 + i)] = swarm.ConnectionManager(port=1000 + i)
addresses = list(conn_mans.keys())
to_connect = []
for i in range(len(addresses)):
conn_mans[addresses[i]].connect(to_connect)
to_connect.append(addresses[i])
master = addresses[0]
for manager in conn_mans.values():
self.assertEqual(master, manager.get_current_master())
for manager in conn_mans.values():
manager.disconnect()
"""
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment