diff --git a/cooperative_cuisine/action.py b/cooperative_cuisine/action.py index a080070b23a716556ca33841cf0226928c52334a..b35819efdc2abf2c04f02413368c63ecc8176e3f 100644 --- a/cooperative_cuisine/action.py +++ b/cooperative_cuisine/action.py @@ -2,7 +2,7 @@ from __future__ import annotations import dataclasses from enum import Enum -from typing import Literal +from typing import Literal, TypedDict from numpy import typing as npt @@ -49,3 +49,11 @@ class Action: self.action_type = ActionType(self.action_type) if isinstance(self.action_data, str) and self.action_data != "pickup": self.action_data = InterActionData(self.action_data) + + +class ActionDict(TypedDict): + """Typed dict corresponding to the Action dataclass.""" + player: str + action_type: ActionType + action_data: list[float] | InterActionData | Literal["pickup"] + duration: float diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index b8df259a10c1bc5dc51d5fd38383faf704cb2fe2..66922d2ec7ad5aed546efaa33af17dc61dffd02f 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -29,7 +29,7 @@ from pydantic import BaseModel from starlette.websockets import WebSocketDisconnect from typing_extensions import TypedDict -from cooperative_cuisine.action import Action +from cooperative_cuisine.action import Action, ActionDict from cooperative_cuisine.environment import Environment from cooperative_cuisine.server_results import ( CreateEnvResult, @@ -591,6 +591,12 @@ class PlayerRequestType(Enum): """Indicates a request to pass an action of a player to the environment.""" +class WebsocketMessage(TypedDict): + type: str + action: None | ActionDict + player_hash: str + + def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResult | str: """Manage WebSocket Message by validating the message and passing it to the environment. diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index 804315bbd2ca63896d3e5d7fc483de9c5b808e12..dee2e03776fb50e21f6c2d6149e7a0914d5d83e2 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -21,7 +21,11 @@ from websockets.sync.client import connect from cooperative_cuisine import ROOT_DIR from cooperative_cuisine.action import ActionType, InterActionData, Action -from cooperative_cuisine.game_server import CreateEnvironmentConfig +from cooperative_cuisine.game_server import ( + CreateEnvironmentConfig, + WebsocketMessage, + PlayerRequestType, +) from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer from cooperative_cuisine.pygame_2d_vis.game_colors import colors from cooperative_cuisine.state_representation import StateRepresentation @@ -1366,12 +1370,11 @@ class PyGameGUI: if p < self.number_humans_to_be_added: # add player websockets websocket = connect(self.websocket_url + player_info["client_id"]) - - websocket.send( - json.dumps( - {"type": "ready", "player_hash": player_info["player_hash"]} - ) + ws_message = WebsocketMessage( + type=PlayerRequestType.READY.value, + player_hash=player_info["player_hash"], ) + websocket.send(json.dumps(ws_message)) assert ( json.loads(websocket.recv())["status"] == 200 ), "not accepted player" @@ -1486,30 +1489,20 @@ class PyGameGUI: float(action.action_data[1]), ] - self.websockets[action.player].send( - json.dumps( - { - "type": "action", - "action": dataclasses.asdict( - action, dict_factory=custom_asdict_factory - ), - "player_hash": self.player_info[action.player]["player_hash"], - } - ) + ws_message = WebsocketMessage( + type=PlayerRequestType.ACTION.value, + action=dataclasses.asdict(action, dict_factory=custom_asdict_factory), + player_hash=self.player_info[action.player]["player_hash"], ) + self.websockets[action.player].send(json.dumps(ws_message)) self.websockets[action.player].recv() def request_state(self): - self.websockets[self.state_player_id].send( - json.dumps( - { - "type": "get_state", - "player_hash": self.player_info[self.state_player_id][ - "player_hash" - ], - } - ) + ws_message = WebsocketMessage( + type=PlayerRequestType.GET_STATE.value, + player_hash=self.player_info[self.state_player_id]["player_hash"], ) + self.websockets[self.state_player_id].send(json.dumps(ws_message)) state = json.loads(self.websockets[self.state_player_id].recv()) return state diff --git a/cooperative_cuisine/study_server.py b/cooperative_cuisine/study_server.py index 793776a9efb08be8696e7b08b1b89f54d7b14e8c..06b7b2ccf4e2f20a7216bfcbdbed4e9e804accb7 100644 --- a/cooperative_cuisine/study_server.py +++ b/cooperative_cuisine/study_server.py @@ -31,7 +31,7 @@ from typing_extensions import TypedDict from cooperative_cuisine import ROOT_DIR from cooperative_cuisine.environment import EnvironmentConfig -from cooperative_cuisine.game_server import CreateEnvironmentConfig +from cooperative_cuisine.game_server import CreateEnvironmentConfig, EnvironmentData from cooperative_cuisine.server_results import PlayerInfo from cooperative_cuisine.utils import ( url_and_port_arguments, @@ -86,16 +86,20 @@ class StudyState: self.study_config: StudyConfig = yaml.load( str(env_config_f), Loader=yaml.SafeLoader ) + """Configuration for the study which layouts, env_configs and item infos are used for the study levels.""" self.levels: list[LevelConfig] = self.study_config["levels"] + """List of level configs for each of the levels which the study runs through.""" self.current_level_idx: int = 0 - + """Counter of which level is currently run in the config.""" self.participant_id_to_player_info = {} - self.player_ids = {} + """A dictionary which maps participants to player infos.""" self.num_connected_players: int = 0 + """Number of currently connected players.""" - self.current_running_env = None - self.next_level_env = None + self.current_running_env: EnvironmentData | None = None + """Information about the current running environment.""" self.players_done = {} + """A dictionary which saves which player has sent ready.""" self.use_aaambos_agent = False @@ -119,7 +123,7 @@ class StudyState: len(self.participant_id_to_player_info) == self.study_config["num_players"] ) - def can_add_participant(self, num_participants: int) -> bool: + def can_add_participants(self, num_participants: int) -> bool: filled = ( self.num_connected_players + num_participants <= self.study_config["num_players"] @@ -324,12 +328,12 @@ class StudyManager: def add_participant(self, participant_id: str, number_players: int): if not self.running_studies or all( - [not s.can_add_participant(number_players) for s in self.running_studies] + [not s.can_add_participants(number_players) for s in self.running_studies] ): self.create_study() for study in self.running_studies: - if study.can_add_participant(number_players): + if study.can_add_participants(number_players): player_info = study.add_participant(participant_id, number_players) self.participant_id_to_study_map[participant_id] = study return True, player_info @@ -345,13 +349,15 @@ class StudyManager: def get_participant_game_connection( self, participant_id: str - ) -> Tuple[PlayerInfo | None, LevelInfo | None]: + ) -> Tuple[bool, None | Tuple[PlayerInfo, LevelInfo]]: + can_connect = False if participant_id in self.participant_id_to_study_map.keys(): assigned_study = self.participant_id_to_study_map[participant_id] player_info, level_info = assigned_study.get_connection(participant_id) - return player_info, level_info + can_connect = True + return can_connect, (player_info, level_info) else: - return None, None + return can_connect, None def set_game_server_url(self, game_host: str, game_port: str): self.game_host = game_host @@ -393,15 +399,14 @@ async def level_done(participant_id: str) -> JSONResponse: @app.post("/get_game_connection/{participant_id}") async def get_game_connection(participant_id: str) -> JSONResponse: - player_info, level_info = study_manager.get_participant_game_connection( - participant_id - ) - if player_info and level_info: + can_connect, data = study_manager.get_participant_game_connection(participant_id) + if can_connect: + player_info, level_info = data return JSONResponse( content={"player_info": player_info, "level_info": level_info} ) else: - raise HTTPException(status_code=409, detail="Not valid game connection.") + raise HTTPException(status_code=409, detail="No valid game connection.") @app.post("/connect_to_tutorial/{participant_id}")