From 8d82c0860e465de822c6159751189c4b58ecb29e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Sun, 28 Jan 2024 12:37:00 +0100
Subject: [PATCH] Add manager ID argument to Overcooked simulator

An optional argument for the list of manager IDs was added across the Overcooked simulator. These IDs are now used for authorization purposes when creating environments. Unauthorized attempts now return HTTP error 403. This helps enhance security and control over environment management in the simulation.
---
 overcooked_simulator/__main__.py              |  6 ++--
 overcooked_simulator/game_server.py           | 32 ++++++++++++++-----
 .../gui_2d_vis/overcooked_gui.py              | 26 +++++++++------
 overcooked_simulator/utils.py                 | 12 +++++++
 4 files changed, 56 insertions(+), 20 deletions(-)

diff --git a/overcooked_simulator/__main__.py b/overcooked_simulator/__main__.py
index f81398ec..f68c506d 100644
--- a/overcooked_simulator/__main__.py
+++ b/overcooked_simulator/__main__.py
@@ -5,19 +5,20 @@ from multiprocessing import Process
 from overcooked_simulator.utils import (
     url_and_port_arguments,
     disable_websocket_logging_arguments,
+    add_list_of_manager_ids_arguments,
 )
 
 
 def start_game_server(cli_args):
     from overcooked_simulator.game_server import main
 
-    main(cli_args.url, cli_args.port)
+    main(cli_args.url, cli_args.port, cli_args.manager_ids)
 
 
 def start_pygame_gui(cli_args):
     from overcooked_simulator.gui_2d_vis.overcooked_gui import main
 
-    main(cli_args.url, cli_args.port)
+    main(cli_args.url, cli_args.port, cli_args.manager_ids)
 
 
 def main(cli_args=None):
@@ -56,6 +57,7 @@ if __name__ == "__main__":
 
     url_and_port_arguments(parser)
     disable_websocket_logging_arguments(parser)
+    add_list_of_manager_ids_arguments(parser)
 
     args = parser.parse_args()
     print(args)
diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
index 35db3885..95e8e18f 100644
--- a/overcooked_simulator/game_server.py
+++ b/overcooked_simulator/game_server.py
@@ -35,7 +35,12 @@ from overcooked_simulator.server_results import (
     PlayerInfo,
     PlayerRequestResult,
 )
-from overcooked_simulator.utils import setup_logging, url_and_port_arguments
+from overcooked_simulator.utils import (
+    setup_logging,
+    url_and_port_arguments,
+    add_list_of_manager_ids_arguments,
+    disable_websocket_logging_arguments,
+)
 
 log = logging.getLogger(__name__)
 
@@ -97,10 +102,11 @@ class EnvironmentHandler:
         """The preferred sleep time between environment steps in nanoseconds based on the `env_step_frequency`."""
         self.client_ids_to_player_hashes = {}
         """A dictionary mapping client IDs to player hashes."""
+        self.allowed_manager: list[str] = []
 
     def create_env(
         self, environment_config: CreateEnvironmentConfig
-    ) -> CreateEnvResult:
+    ) -> CreateEnvResult | int:
         """Create a new environment.
 
         Args:
@@ -110,6 +116,8 @@ class EnvironmentHandler:
             A dictionary containing the created environment ID and player information.
 
         """
+        if environment_config.manager_id not in self.allowed_manager:
+            return 1
         env_id = uuid.uuid4().hex
 
         env = Environment(
@@ -482,6 +490,9 @@ class EnvironmentHandler:
             return True
         return False
 
+    def extend_allowed_manager(self, manager: list[str]):
+        self.allowed_manager.extend(manager)
+
 
 class PlayerConnectionManager:
     """
@@ -540,7 +551,7 @@ class PlayerConnectionManager:
             await connection.send_text(message)
 
 
-manager = PlayerConnectionManager()
+connection_manager = PlayerConnectionManager()
 environment_handler: EnvironmentHandler = EnvironmentHandler()
 
 
@@ -656,6 +667,8 @@ class AdditionalPlayer(BaseModel):
 @app.post("/manage/create_env/")
 async def create_env(creation: CreateEnvironmentConfig) -> CreateEnvResult:
     result = environment_handler.create_env(creation)
+    if result == 1:
+        raise HTTPException(status_code=403, detail="Manager ID not known/registered.")
     return result
 
 
@@ -694,7 +707,7 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str):
     if not environment_handler.is_known_client_id(client_id):
         log.warning(f"wrong websocket connection with {client_id=}")
         return
-    await manager.connect_player(websocket, client_id)
+    await connection_manager.connect_player(websocket, client_id)
     log.debug(f"Client #{client_id} connected")
     environment_handler.set_player_connected(client_id)
     try:
@@ -703,17 +716,18 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str):
             answer = manage_websocket_message(message, client_id)
             if isinstance(answer, dict):
                 answer = json.dumps(answer)
-            await manager.send_personal_message(answer, websocket)
+            await connection_manager.send_personal_message(answer, websocket)
 
     except WebSocketDisconnect:
-        manager.disconnect(client_id)
+        connection_manager.disconnect(client_id)
         environment_handler.set_player_disconnected(client_id)
         log.debug(f"Client #{client_id} disconnected")
 
 
-def main(host: str, port: int):
+def main(host: str, port: int, manager_ids: list[str]):
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
+    environment_handler.extend_allowed_manager(manager_ids)
     loop.create_task(environment_handler.environment_steps())
     config = uvicorn.Config(app, host=host, port=port, loop=loop)
     server = uvicorn.Server(config)
@@ -728,9 +742,11 @@ if __name__ == "__main__":
     )
 
     url_and_port_arguments(parser)
+    disable_websocket_logging_arguments(parser)
+    add_list_of_manager_ids_arguments(parser)
     args = parser.parse_args()
     setup_logging(args.enable_websocket_logging)
-    main(args.url, args.port)
+    main(args.url, args.port, args.manager_ids)
     """
     Or in console: 
     uvicorn overcooked_simulator.fastapi_game_server:app --reload
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index 1370d847..dd96280b 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -2,6 +2,7 @@ import argparse
 import dataclasses
 import json
 import logging
+import random
 import sys
 from enum import Enum
 
@@ -26,6 +27,7 @@ from overcooked_simulator.utils import (
     setup_logging,
     url_and_port_arguments,
     disable_websocket_logging_arguments,
+    add_list_of_manager_ids_arguments,
 )
 
 
@@ -35,9 +37,6 @@ class MenuStates(Enum):
     End = "End"
 
 
-MANAGER_ID = "1233245425"
-
-
 log = logging.getLogger(__name__)
 
 
@@ -73,7 +72,8 @@ class PyGameGUI:
         player_names: list[str | int],
         player_keys: list[pygame.key],
         url: str,
-        port: str,
+        port: int,
+        manager_ids: list[str],
     ):
         self.game_screen = None
         self.FPS = 60
@@ -93,6 +93,7 @@ class PyGameGUI:
         self.websockets = {}
 
         self.request_url = f"http://{url}:{port}"
+        self.manager_id = random.choice(manager_ids)
 
         # TODO cache loaded images?
         with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file:
@@ -444,7 +445,7 @@ class PyGameGUI:
         with open(environment_config_path, "r") as file:
             environment_config = file.read()
         creation_json = CreateEnvironmentConfig(
-            manager_id=MANAGER_ID,
+            manager_id=self.manager_id,
             number_players=2,
             environment_settings={"all_player_can_pause_game": False},
             item_info_config=item_info,
@@ -507,7 +508,7 @@ class PyGameGUI:
         requests.post(
             f"{self.request_url}/manage/stop_env",
             json={
-                "manager_id": MANAGER_ID,
+                "manager_id": self.manager_id,
                 "env_id": self.current_env_id,
                 "reason": "reset button pressed",
             },
@@ -520,7 +521,7 @@ class PyGameGUI:
         requests.post(
             f"{self.request_url}/manage/stop_env",
             json={
-                "manager_id": MANAGER_ID,
+                "manager_id": self.manager_id,
                 "env_id": self.current_env_id,
                 "reason": "finish button pressed",
             },
@@ -670,7 +671,7 @@ class PyGameGUI:
         sys.exit()
 
 
-def main(url, port):
+def main(url: str, port: int, manager_ids: list[str]):
     # TODO maybe read the player names and keyboard keys from config file?
     keys1 = [
         pygame.K_LEFT,
@@ -684,7 +685,11 @@ def main(url, port):
 
     number_players = 2
     gui = PyGameGUI(
-        list(map(str, range(number_players))), [keys1, keys2], url=url, port=port
+        list(map(str, range(number_players))),
+        [keys1, keys2],
+        url=url,
+        port=port,
+        manager_ids=manager_ids,
     )
     gui.start_pygame()
 
@@ -698,6 +703,7 @@ if __name__ == "__main__":
 
     url_and_port_arguments(parser)
     disable_websocket_logging_arguments(parser)
+    add_list_of_manager_ids_arguments(parser)
     args = parser.parse_args()
     setup_logging(enable_websocket_logging=args.enable_websocket_logging)
-    main(args.url, args.port)
+    main(args.url, args.port, args.manager_ids)
diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py
index ecfb5958..d99d6430 100644
--- a/overcooked_simulator/utils.py
+++ b/overcooked_simulator/utils.py
@@ -1,6 +1,7 @@
 import logging
 import os
 import sys
+import uuid
 from datetime import datetime
 from enum import Enum
 
@@ -66,3 +67,14 @@ def url_and_port_arguments(parser):
 
 def disable_websocket_logging_arguments(parser):
     parser.add_argument("--enable-websocket-logging", action="store_true", default=True)
+
+
+def add_list_of_manager_ids_arguments(parser):
+    parser.add_argument(
+        "-m",
+        "--manager_ids",
+        nargs="+",
+        type=list[str],
+        default=[uuid.uuid4().hex],
+        help="List of manager IDs that can create environments.",
+    )
-- 
GitLab