diff --git a/README.md b/README.md index 135e869b284165af0d9eebe765a90478cc7f1f31..4e7e315bc87854f056c643808796056629f27e75 100644 --- a/README.md +++ b/README.md @@ -35,13 +35,17 @@ _The arguments are the defaults. Therefore, they are optional._ You can also start the **Game Server** and the **PyGame GUI** individually in different terminals. ```bash -python3 overcooked_simulator/game_server.py --url "localhost" --port 8000 +python3 overcooked_simulator/game_server.py --url "localhost" --port 8000 --manager_ids SECRETKEY1 SECRETKEY2 -python3 overcooked_simulator/gui_2d_vis/overcooked_simulator.py --url "localhost" --port 8000 +python3 overcooked_simulator/gui_2d_vis/overcooked_gui.py --url "localhost" --port 8000 --manager_ids SECRETKEY1 ``` You can start also several GUIs. +```bash +python3 overcooked_simulator/gui_2d_vis/overcooked_gui.py --url "localhost" --port 8000 --manager_ids SECRETKEY2 +``` + You can replace the GUI with your own GUI (+ study server/matchmaking server). ### Library Installation diff --git a/overcooked_simulator/__init__.py b/overcooked_simulator/__init__.py index fb941f2fadd719d7f44a3c1977aac8ee5f416538..bf5ca5ca27fdd3b55833bca96e56d62859e8a5c3 100644 --- a/overcooked_simulator/__init__.py +++ b/overcooked_simulator/__init__.py @@ -35,9 +35,10 @@ _The arguments are the defaults. Therefore, they are optional._ You can also start the **Game Server** and the **PyGame GUI** individually in different terminals. ```bash -python3 overcooked_simulator/game_server.py --url "localhost" --port 8000 +python3 overcooked_simulator/game_server.py --url "localhost" --port 8000 --manager_ids SECRETKEY1 SECRETKEY2 -python3 overcooked_simulator/gui_2d_vis/overcooked_simulator.py --url "localhost" --port 8000 +python3 overcooked_simulator/gui_2d_vis/overcooked_gui.py --url "localhost" --port 8000 --manager_ids SECRETKEY1 +``` ## Connect with agent and receive game state ... diff --git a/overcooked_simulator/__main__.py b/overcooked_simulator/__main__.py index f81398ece39babd670716b79b045ee3226176e97..f68c506d8598d9cd380b869a19ae9f283bd1b28b 100644 --- a/overcooked_simulator/__main__.py +++ b/overcooked_simulator/__main__.py @@ -5,19 +5,20 @@ from multiprocessing import Process from overcooked_simulator.utils import ( url_and_port_arguments, disable_websocket_logging_arguments, + add_list_of_manager_ids_arguments, ) def start_game_server(cli_args): from overcooked_simulator.game_server import main - main(cli_args.url, cli_args.port) + main(cli_args.url, cli_args.port, cli_args.manager_ids) def start_pygame_gui(cli_args): from overcooked_simulator.gui_2d_vis.overcooked_gui import main - main(cli_args.url, cli_args.port) + main(cli_args.url, cli_args.port, cli_args.manager_ids) def main(cli_args=None): @@ -56,6 +57,7 @@ if __name__ == "__main__": url_and_port_arguments(parser) disable_websocket_logging_arguments(parser) + add_list_of_manager_ids_arguments(parser) args = parser.parse_args() print(args) diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py index 35db3885b24d70c21ee3c37db3a95ddad423613c..95e8e18fb6733f319d85fc58dfcef18471e66521 100644 --- a/overcooked_simulator/game_server.py +++ b/overcooked_simulator/game_server.py @@ -35,7 +35,12 @@ from overcooked_simulator.server_results import ( PlayerInfo, PlayerRequestResult, ) -from overcooked_simulator.utils import setup_logging, url_and_port_arguments +from overcooked_simulator.utils import ( + setup_logging, + url_and_port_arguments, + add_list_of_manager_ids_arguments, + disable_websocket_logging_arguments, +) log = logging.getLogger(__name__) @@ -97,10 +102,11 @@ class EnvironmentHandler: """The preferred sleep time between environment steps in nanoseconds based on the `env_step_frequency`.""" self.client_ids_to_player_hashes = {} """A dictionary mapping client IDs to player hashes.""" + self.allowed_manager: list[str] = [] def create_env( self, environment_config: CreateEnvironmentConfig - ) -> CreateEnvResult: + ) -> CreateEnvResult | int: """Create a new environment. Args: @@ -110,6 +116,8 @@ class EnvironmentHandler: A dictionary containing the created environment ID and player information. """ + if environment_config.manager_id not in self.allowed_manager: + return 1 env_id = uuid.uuid4().hex env = Environment( @@ -482,6 +490,9 @@ class EnvironmentHandler: return True return False + def extend_allowed_manager(self, manager: list[str]): + self.allowed_manager.extend(manager) + class PlayerConnectionManager: """ @@ -540,7 +551,7 @@ class PlayerConnectionManager: await connection.send_text(message) -manager = PlayerConnectionManager() +connection_manager = PlayerConnectionManager() environment_handler: EnvironmentHandler = EnvironmentHandler() @@ -656,6 +667,8 @@ class AdditionalPlayer(BaseModel): @app.post("/manage/create_env/") async def create_env(creation: CreateEnvironmentConfig) -> CreateEnvResult: result = environment_handler.create_env(creation) + if result == 1: + raise HTTPException(status_code=403, detail="Manager ID not known/registered.") return result @@ -694,7 +707,7 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str): if not environment_handler.is_known_client_id(client_id): log.warning(f"wrong websocket connection with {client_id=}") return - await manager.connect_player(websocket, client_id) + await connection_manager.connect_player(websocket, client_id) log.debug(f"Client #{client_id} connected") environment_handler.set_player_connected(client_id) try: @@ -703,17 +716,18 @@ async def websocket_player_endpoint(websocket: WebSocket, client_id: str): answer = manage_websocket_message(message, client_id) if isinstance(answer, dict): answer = json.dumps(answer) - await manager.send_personal_message(answer, websocket) + await connection_manager.send_personal_message(answer, websocket) except WebSocketDisconnect: - manager.disconnect(client_id) + connection_manager.disconnect(client_id) environment_handler.set_player_disconnected(client_id) log.debug(f"Client #{client_id} disconnected") -def main(host: str, port: int): +def main(host: str, port: int, manager_ids: list[str]): loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) + environment_handler.extend_allowed_manager(manager_ids) loop.create_task(environment_handler.environment_steps()) config = uvicorn.Config(app, host=host, port=port, loop=loop) server = uvicorn.Server(config) @@ -728,9 +742,11 @@ if __name__ == "__main__": ) url_and_port_arguments(parser) + disable_websocket_logging_arguments(parser) + add_list_of_manager_ids_arguments(parser) args = parser.parse_args() setup_logging(args.enable_websocket_logging) - main(args.url, args.port) + main(args.url, args.port, args.manager_ids) """ Or in console: uvicorn overcooked_simulator.fastapi_game_server:app --reload diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py index 3a07ad1bc7856f0b3597f8d48f411a894f7db5d5..b6843fc61b72f315fd2e573d149e973c84037576 100644 --- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py +++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py @@ -2,6 +2,7 @@ import argparse import dataclasses import json import logging +import random import sys from enum import Enum @@ -26,6 +27,7 @@ from overcooked_simulator.utils import ( setup_logging, url_and_port_arguments, disable_websocket_logging_arguments, + add_list_of_manager_ids_arguments, ) @@ -35,9 +37,6 @@ class MenuStates(Enum): End = "End" -MANAGER_ID = "1233245425" - - log = logging.getLogger(__name__) @@ -73,7 +72,8 @@ class PyGameGUI: player_names: list[str | int], player_keys: list[pygame.key], url: str, - port: str, + port: int, + manager_ids: list[str], ): self.game_screen = None self.FPS = 60 @@ -93,6 +93,7 @@ class PyGameGUI: self.websockets = {} self.request_url = f"http://{url}:{port}" + self.manager_id = random.choice(manager_ids) # TODO cache loaded images? with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file: @@ -447,7 +448,7 @@ class PyGameGUI: with open(environment_config_path, "r") as file: environment_config = file.read() creation_json = CreateEnvironmentConfig( - manager_id=MANAGER_ID, + manager_id=self.manager_id, number_players=2, environment_settings={"all_player_can_pause_game": False}, item_info_config=item_info, @@ -459,6 +460,8 @@ class PyGameGUI: f"{self.request_url}/manage/create_env/", json=creation_json, ) + if env_info.status_code == 403: + raise ValueError(f"Forbidden Request: {env_info.json()['detail']}") env_info = env_info.json() assert isinstance(env_info, dict), "Env info must be a dictionary" self.current_env_id = env_info["env_id"] @@ -510,7 +513,7 @@ class PyGameGUI: requests.post( f"{self.request_url}/manage/stop_env", json={ - "manager_id": MANAGER_ID, + "manager_id": self.manager_id, "env_id": self.current_env_id, "reason": "reset button pressed", }, @@ -523,7 +526,7 @@ class PyGameGUI: requests.post( f"{self.request_url}/manage/stop_env", json={ - "manager_id": MANAGER_ID, + "manager_id": self.manager_id, "env_id": self.current_env_id, "reason": "finish button pressed", }, @@ -673,7 +676,7 @@ class PyGameGUI: sys.exit() -def main(url, port): +def main(url: str, port: int, manager_ids: list[str]): # TODO maybe read the player names and keyboard keys from config file? keys1 = [ pygame.K_LEFT, @@ -687,7 +690,11 @@ def main(url, port): number_players = 2 gui = PyGameGUI( - list(map(str, range(number_players))), [keys1, keys2], url=url, port=port + list(map(str, range(number_players))), + [keys1, keys2], + url=url, + port=port, + manager_ids=manager_ids, ) gui.start_pygame() @@ -701,6 +708,7 @@ if __name__ == "__main__": url_and_port_arguments(parser) disable_websocket_logging_arguments(parser) + add_list_of_manager_ids_arguments(parser) args = parser.parse_args() setup_logging(enable_websocket_logging=args.enable_websocket_logging) - main(args.url, args.port) + main(args.url, args.port, args.manager_ids) diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index ecfb5958a982c084ce46db849fa9568483fea3e5..2d77d32145e43db8bb5c1960067453a970a39080 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -1,6 +1,7 @@ import logging import os import sys +import uuid from datetime import datetime from enum import Enum @@ -65,4 +66,17 @@ def url_and_port_arguments(parser): def disable_websocket_logging_arguments(parser): - parser.add_argument("--enable-websocket-logging", action="store_true", default=True) + parser.add_argument( + "--enable-websocket-logging" "", action="store_true", default=True + ) + + +def add_list_of_manager_ids_arguments(parser): + parser.add_argument( + "-m", + "--manager_ids", + nargs="+", + type=str, + default=[uuid.uuid4().hex], + help="List of manager IDs that can create environments.", + )