diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index 28a4d6f62c63613eb04f6efbd5115a4d4728cf23..2e4dbe9f0dae8a0354fcd9b676b12a6bca7dac46 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -142,6 +142,11 @@ class EnvironmentHandler: self.client_ids_to_player_hashes = {} """A dictionary mapping client IDs to player hashes.""" self.allowed_manager: list[str] = [] + """List of manager ids that are allowed to manage/create environments.""" + self.host: str = "" + """The host string (e.g., localhost) of the game server.""" + self.port: int = 8000 + """The port of the game server.""" def create_env( self, environment_config: CreateEnvironmentConfig @@ -225,6 +230,7 @@ class EnvironmentHandler: "client_id": client_id, "player_hash": player_hash, "player_id": player_id, + "websocket_url": f"ws://{self.host}:{self.port}/ws/player/{client_id}", } def add_player(self, config: AdditionalPlayer) -> dict[str, PlayerInfo]: @@ -560,6 +566,10 @@ class EnvironmentHandler: def extend_allowed_manager(self, manager: list[str]): self.allowed_manager.extend(manager) + def set_host_and_port(self, host, port): + self.host = host + self.port = port + class PlayerConnectionManager: """ @@ -819,6 +829,7 @@ def main( loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) environment_handler.extend_allowed_manager(manager_ids) + environment_handler.set_host_and_port(host=host, port=port) loop.create_task(environment_handler.environment_steps()) config = uvicorn.Config(app, host=host, port=port, loop=loop) server = uvicorn.Server(config) diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index 42a539c247a4ebe0886f6ac2dea5649bf6d7c0a5..ccc4dd105fac8fa47bd192d926e447089ba95893 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -141,7 +141,6 @@ class PyGameGUI: self.key_sets: list[PlayerKeySet] = [] - self.websocket_url = f"ws://{game_host}:{game_port}/ws/player/" self.websockets = {} if CONNECT_WITH_STUDY_SERVER: @@ -1503,7 +1502,7 @@ class PyGameGUI: def create_and_connect_bot(self, player_id, player_info): player_hash = player_info["player_hash"] print( - f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"' + f'--general_plus="agent_websocket:{player_info["websocket_url"]};player_hash:{player_hash};agent_id:{player_id}"' ) if self.USE_AAAMBOS_AGENT: sub = Popen( @@ -1516,7 +1515,7 @@ class PyGameGUI: str(ROOT_DIR / "configs" / "agents" / "arch_config.yml"), "--run_config", str(ROOT_DIR / "configs" / "agents" / "run_config.yml"), - f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"', + f'--general_plus="agent_websocket:{player_info["websocket_url"]};player_hash:{player_hash};agent_id:{player_id}"', f"--instance={player_hash}", ] ), @@ -1528,7 +1527,7 @@ class PyGameGUI: [ "python", str(ROOT_DIR / "configs" / "agents" / "random_agent.py"), - f'--uri {self.websocket_url + player_info["client_id"]}', + f'--uri {player_info["websocket_url"]}', f"--player_hash {player_hash}", f"--player_id {player_id}", ] @@ -1541,7 +1540,7 @@ class PyGameGUI: for p, (player_id, player_info) in enumerate(self.player_info.items()): if p < self.number_humans_to_be_added: # add player websockets - websocket = connect(self.websocket_url + player_info["client_id"]) + websocket = connect(player_info["websocket_url"]) message_dict = { "type": PlayerRequestType.READY.value, "player_hash": player_info["player_hash"], diff --git a/cooperative_cuisine/server_results.py b/cooperative_cuisine/server_results.py index 1134c8fc4c08a60426ac8857855385655be46ad6..a5b85aca0abb344583a8fade1674430703d06568 100644 --- a/cooperative_cuisine/server_results.py +++ b/cooperative_cuisine/server_results.py @@ -18,6 +18,8 @@ class PlayerInfo(TypedDict): """Hash of the player, for validation.""" player_id: str """ID of the player.""" + websocket_url: str + """The url for the websocket to connect to.""" class CreateEnvResult(TypedDict): diff --git a/tests/test_study_server.py b/tests/test_study_server.py index 20776e60b190ee93d4bceec0c89e38c35deb531e..a7996af049076542288800133eb5945b6c40a4c2 100644 --- a/tests/test_study_server.py +++ b/tests/test_study_server.py @@ -20,6 +20,7 @@ def test_valid_post_requests(): "player_id": "0", "client_id": "ksjdhfkjsdfn", "player_hash": "shdfbmsndfb", + "websocket_url": "jhvfshfvbsdnmf", } }, "env_id": "123456789", @@ -48,6 +49,7 @@ def test_valid_post_requests(): "player_id": "0", "client_id": "ksjdhfkjsdfn", "player_hash": "shdfbmsndfb", + "websocket_url": "jhvfshfvbsdnmf", } } @@ -103,6 +105,7 @@ def test_tutorial(): "player_id": "0", "client_id": "ksjdhfkjsdfn", "player_hash": "shdfbmsndfb", + "websocket_url": "jhvfshfvbsdnmf", } }, "env_id": "123456789",