From 8d82c0860e465de822c6159751189c4b58ecb29e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Schr=C3=B6der?= <fschroeder@techfak.uni-bielefeld.de> Date: Sun, 28 Jan 2024 12:37:00 +0100 Subject: [PATCH] Add manager ID argument to Overcooked simulator An optional argument for the list of manager IDs was added across the Overcooked simulator. These IDs are now used for authorization purposes when creating environments. Unauthorized attempts now return HTTP error 403. This helps enhance security and control over environment management in the simulation. --- overcooked_simulator/__main__.py | 6 ++-- overcooked_simulator/game_server.py | 32 ++++++++++++++----- .../gui_2d_vis/overcooked_gui.py | 26 +++++++++------ overcooked_simulator/utils.py | 12 +++++++ 4 files changed, 56 insertions(+), 20 deletions(-) diff --git a/overcooked_simulator/__main__.py b/overcooked_simulator/__main__.py index f81398ec..f68c506d 100644 --- a/overcooked_simulator/__main__.py +++ b/overcooked_simulator/__main__.py @@ -5,19 +5,20 @@ from multiprocessing import Process from overcooked_simulator.utils import ( url_and_port_arguments, disable_websocket_logging_arguments, + add_list_of_manager_ids_arguments, ) def start_game_server(cli_args): from overcooked_simulator.game_server import main - main(cli_args.url, cli_args.port) + main(cli_args.url, cli_args.port, cli_args.manager_ids) def start_pygame_gui(cli_args): from overcooked_simulator.gui_2d_vis.overcooked_gui import main - main(cli_args.url, cli_args.port) + main(cli_args.url, cli_args.port, cli_args.manager_ids) def main(cli_args=None): @@ -56,6 +57,7 @@ if __name__ == "__main__": url_and_port_arguments(parser) disable_websocket_logging_arguments(parser) + add_list_of_manager_ids_arguments(parser) args = parser.parse_args() print(args) diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py index 35db3885..95e8e18f 100644 --- a/overcooked_simulator/game_server.py +++ b/overcooked_simulator/game_server.py @@ -35,7 +35,12 @@ from overcooked_simulator.server_results import ( PlayerInfo, PlayerRequestResult, ) -from overcooked_simulator.utils import setup_logging, url_and_port_arguments +from overcooked_simulator.utils import ( + setup_logging, + url_and_port_arguments, + add_list_of_manager_ids_arguments, + disable_websocket_logging_arguments, +) log = logging.getLogger(__name__) @@ -97,10 +102,11 @@ class EnvironmentHandler: """The preferred sleep time between environment steps in nanoseconds based on the `env_step_frequency`.""" self.client_ids_to_player_hashes = {} """A dictionary mapping client IDs to player hashes.""" + self.allowed_manager: list[str] = [] def create_env( self, environment_config: CreateEnvironmentConfig - ) -> CreateEnvResult: + ) -> CreateEnvResult | int: """Create a new environment. Args: @@ -110,6 +116,8 @@ class EnvironmentHandler: A dictionary containing the created environment ID and player information. """ + if environment_config.manager_id not in self.allowed_manager: + return 1 env_id = uuid.uuid4().hex env = Environment( @@ -482,6 +490,9 @@ class EnvironmentHandler: return True return False + def extend_allowed_manager(self, manager: list[str]): + self.allowed_manager.extend(manager) + class PlayerConnectionManager: """ @@ -540,7 +551,7 @@ class PlayerConnectionManager: await connection.send_text(message) -manager = PlayerConnectionManager() +connection_manager = PlayerConnectionManager() environment_handler: EnvironmentHandler = EnvironmentHandler() @@ -656,6 +667,8 @@ class AdditionalPlayer(BaseModel): @app.post("/manage/create_env/") async def create_env(creation: CreateEnvironmentConfig) -> CreateEnvResult: result = environment_handler.create_env(creation) + if result == 1: + raise HTTPException(status_code=403, detail="Manager ID not known/registered.") return result @@ -694,7 +707,7 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str): if not environment_handler.is_known_client_id(client_id): log.warning(f"wrong websocket connection with {client_id=}") return - await manager.connect_player(websocket, client_id) + await connection_manager.connect_player(websocket, client_id) log.debug(f"Client #{client_id} connected") environment_handler.set_player_connected(client_id) try: @@ -703,17 +716,18 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str): answer = manage_websocket_message(message, client_id) if isinstance(answer, dict): answer = json.dumps(answer) - await manager.send_personal_message(answer, websocket) + await connection_manager.send_personal_message(answer, websocket) except WebSocketDisconnect: - manager.disconnect(client_id) + connection_manager.disconnect(client_id) environment_handler.set_player_disconnected(client_id) log.debug(f"Client #{client_id} disconnected") -def main(host: str, port: int): +def main(host: str, port: int, manager_ids: list[str]): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + environment_handler.extend_allowed_manager(manager_ids) loop.create_task(environment_handler.environment_steps()) config = uvicorn.Config(app, host=host, port=port, loop=loop) server = uvicorn.Server(config) @@ -728,9 +742,11 @@ if __name__ == "__main__": ) url_and_port_arguments(parser) + disable_websocket_logging_arguments(parser) + add_list_of_manager_ids_arguments(parser) args = parser.parse_args() setup_logging(args.enable_websocket_logging) - main(args.url, args.port) + main(args.url, args.port, args.manager_ids) """ Or in console: uvicorn overcooked_simulator.fastapi_game_server:app --reload diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py index 1370d847..dd96280b 100644 --- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py +++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py @@ -2,6 +2,7 @@ import argparse import dataclasses import json import logging +import random import sys from enum import Enum @@ -26,6 +27,7 @@ from overcooked_simulator.utils import ( setup_logging, url_and_port_arguments, disable_websocket_logging_arguments, + add_list_of_manager_ids_arguments, ) @@ -35,9 +37,6 @@ class MenuStates(Enum): End = "End" -MANAGER_ID = "1233245425" - - log = logging.getLogger(__name__) @@ -73,7 +72,8 @@ class PyGameGUI: player_names: list[str | int], player_keys: list[pygame.key], url: str, - port: str, + port: int, + manager_ids: list[str], ): self.game_screen = None self.FPS = 60 @@ -93,6 +93,7 @@ class PyGameGUI: self.websockets = {} self.request_url = f"http://{url}:{port}" + self.manager_id = random.choice(manager_ids) # TODO cache loaded images? with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file: @@ -444,7 +445,7 @@ class PyGameGUI: with open(environment_config_path, "r") as file: environment_config = file.read() creation_json = CreateEnvironmentConfig( - manager_id=MANAGER_ID, + manager_id=self.manager_id, number_players=2, environment_settings={"all_player_can_pause_game": False}, item_info_config=item_info, @@ -507,7 +508,7 @@ class PyGameGUI: requests.post( f"{self.request_url}/manage/stop_env", json={ - "manager_id": MANAGER_ID, + "manager_id": self.manager_id, "env_id": self.current_env_id, "reason": "reset button pressed", }, @@ -520,7 +521,7 @@ class PyGameGUI: requests.post( f"{self.request_url}/manage/stop_env", json={ - "manager_id": MANAGER_ID, + "manager_id": self.manager_id, "env_id": self.current_env_id, "reason": "finish button pressed", }, @@ -670,7 +671,7 @@ class PyGameGUI: sys.exit() -def main(url, port): +def main(url: str, port: int, manager_ids: list[str]): # TODO maybe read the player names and keyboard keys from config file? keys1 = [ pygame.K_LEFT, @@ -684,7 +685,11 @@ def main(url, port): number_players = 2 gui = PyGameGUI( - list(map(str, range(number_players))), [keys1, keys2], url=url, port=port + list(map(str, range(number_players))), + [keys1, keys2], + url=url, + port=port, + manager_ids=manager_ids, ) gui.start_pygame() @@ -698,6 +703,7 @@ if __name__ == "__main__": url_and_port_arguments(parser) disable_websocket_logging_arguments(parser) + add_list_of_manager_ids_arguments(parser) args = parser.parse_args() setup_logging(enable_websocket_logging=args.enable_websocket_logging) - main(args.url, args.port) + main(args.url, args.port, args.manager_ids) diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index ecfb5958..d99d6430 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -1,6 +1,7 @@ import logging import os import sys +import uuid from datetime import datetime from enum import Enum @@ -66,3 +67,14 @@ def url_and_port_arguments(parser): def disable_websocket_logging_arguments(parser): parser.add_argument("--enable-websocket-logging", action="store_true", default=True) + + +def add_list_of_manager_ids_arguments(parser): + parser.add_argument( + "-m", + "--manager_ids", + nargs="+", + type=list[str], + default=[uuid.uuid4().hex], + help="List of manager IDs that can create environments.", + ) -- GitLab