Skip to content
Snippets Groups Projects
Commit 981bfcad authored by Florian Schröder's avatar Florian Schröder
Browse files

Merge branch '107-study-server-sends-complete-websocket-url' into 'dev'

Resolve "Study Server sends complete websocket url"

Closes #107

See merge request scs/cocosy/overcooked-simulator!82
parents d81ac175 c2c310b9
No related branches found
No related tags found
1 merge request!82Resolve "Study Server sends complete websocket url"
Pipeline #48419 passed
......@@ -142,6 +142,11 @@ class EnvironmentHandler:
self.client_ids_to_player_hashes = {}
"""A dictionary mapping client IDs to player hashes."""
self.allowed_manager: list[str] = []
"""List of manager ids that are allowed to manage/create environments."""
self.host: str = ""
"""The host string (e.g., localhost) of the game server."""
self.port: int = 8000
"""The port of the game server."""
def create_env(
self, environment_config: CreateEnvironmentConfig
......@@ -225,6 +230,7 @@ class EnvironmentHandler:
"client_id": client_id,
"player_hash": player_hash,
"player_id": player_id,
"websocket_url": f"ws://{self.host}:{self.port}/ws/player/{client_id}",
}
def add_player(self, config: AdditionalPlayer) -> dict[str, PlayerInfo]:
......@@ -560,6 +566,10 @@ class EnvironmentHandler:
def extend_allowed_manager(self, manager: list[str]):
self.allowed_manager.extend(manager)
def set_host_and_port(self, host, port):
self.host = host
self.port = port
class PlayerConnectionManager:
"""
......@@ -819,6 +829,7 @@ def main(
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
environment_handler.extend_allowed_manager(manager_ids)
environment_handler.set_host_and_port(host=host, port=port)
loop.create_task(environment_handler.environment_steps())
config = uvicorn.Config(app, host=host, port=port, loop=loop)
server = uvicorn.Server(config)
......
......@@ -141,7 +141,6 @@ class PyGameGUI:
self.key_sets: list[PlayerKeySet] = []
self.websocket_url = f"ws://{game_host}:{game_port}/ws/player/"
self.websockets = {}
if CONNECT_WITH_STUDY_SERVER:
......@@ -1503,7 +1502,7 @@ class PyGameGUI:
def create_and_connect_bot(self, player_id, player_info):
player_hash = player_info["player_hash"]
print(
f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"'
f'--general_plus="agent_websocket:{player_info["websocket_url"]};player_hash:{player_hash};agent_id:{player_id}"'
)
if self.USE_AAAMBOS_AGENT:
sub = Popen(
......@@ -1516,7 +1515,7 @@ class PyGameGUI:
str(ROOT_DIR / "configs" / "agents" / "arch_config.yml"),
"--run_config",
str(ROOT_DIR / "configs" / "agents" / "run_config.yml"),
f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"',
f'--general_plus="agent_websocket:{player_info["websocket_url"]};player_hash:{player_hash};agent_id:{player_id}"',
f"--instance={player_hash}",
]
),
......@@ -1528,7 +1527,7 @@ class PyGameGUI:
[
"python",
str(ROOT_DIR / "configs" / "agents" / "random_agent.py"),
f'--uri {self.websocket_url + player_info["client_id"]}',
f'--uri {player_info["websocket_url"]}',
f"--player_hash {player_hash}",
f"--player_id {player_id}",
]
......@@ -1541,7 +1540,7 @@ class PyGameGUI:
for p, (player_id, player_info) in enumerate(self.player_info.items()):
if p < self.number_humans_to_be_added:
# add player websockets
websocket = connect(self.websocket_url + player_info["client_id"])
websocket = connect(player_info["websocket_url"])
message_dict = {
"type": PlayerRequestType.READY.value,
"player_hash": player_info["player_hash"],
......
......@@ -18,6 +18,8 @@ class PlayerInfo(TypedDict):
"""Hash of the player, for validation."""
player_id: str
"""ID of the player."""
websocket_url: str
"""The url for the websocket to connect to."""
class CreateEnvResult(TypedDict):
......
......@@ -20,6 +20,7 @@ def test_valid_post_requests():
"player_id": "0",
"client_id": "ksjdhfkjsdfn",
"player_hash": "shdfbmsndfb",
"websocket_url": "jhvfshfvbsdnmf",
}
},
"env_id": "123456789",
......@@ -48,6 +49,7 @@ def test_valid_post_requests():
"player_id": "0",
"client_id": "ksjdhfkjsdfn",
"player_hash": "shdfbmsndfb",
"websocket_url": "jhvfshfvbsdnmf",
}
}
......@@ -103,6 +105,7 @@ def test_tutorial():
"player_id": "0",
"client_id": "ksjdhfkjsdfn",
"player_hash": "shdfbmsndfb",
"websocket_url": "jhvfshfvbsdnmf",
}
},
"env_id": "123456789",
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment