diff --git a/cooperative_cuisine/action.py b/cooperative_cuisine/action.py index a080070b23a716556ca33841cf0226928c52334a..bb94a9369ddb0293464715fbb19f7edcbbdebb1f 100644 --- a/cooperative_cuisine/action.py +++ b/cooperative_cuisine/action.py @@ -36,7 +36,7 @@ class Action: """Id of the player.""" action_type: ActionType """Type of the action to perform. Defines what action data is valid.""" - action_data: npt.NDArray[float] | InterActionData | Literal["pickup"] + action_data: npt.NDArray[float] | list[float] | InterActionData | Literal["pickup"] """Data for the action, e.g., movement vector or start and stop interaction.""" duration: float | int = 0 """Duration of the action (relevant for movement)""" diff --git a/cooperative_cuisine/configs/environment_config.yaml b/cooperative_cuisine/configs/environment_config.yaml index 84b3318b9f6b72f9650f4ee506fc05109ae3f080..9a71c3766af6744364c63721649fbf2c6c00639e 100644 --- a/cooperative_cuisine/configs/environment_config.yaml +++ b/cooperative_cuisine/configs/environment_config.yaml @@ -7,6 +7,7 @@ plates: game: time_limit_seconds: 300 undo_dispenser_pickup: true + validate_recipes: false meals: all: true diff --git a/cooperative_cuisine/configs/study/level1/level1_config.yaml b/cooperative_cuisine/configs/study/level1/level1_config.yaml index 10b787920a1421152003690f2080a272774ee58d..788921d229b341f8d6b77d5ad9eacb4daccbf66f 100644 --- a/cooperative_cuisine/configs/study/level1/level1_config.yaml +++ b/cooperative_cuisine/configs/study/level1/level1_config.yaml @@ -8,6 +8,7 @@ plates: game: time_limit_seconds: 300 undo_dispenser_pickup: true + validate_recipes: false meals: all: false diff --git a/cooperative_cuisine/configs/study/level2/level2_config.yaml b/cooperative_cuisine/configs/study/level2/level2_config.yaml index 6d5c16fd044ed82cf2ea387f2102a2fdf4f9e940..951dc6a4120b39d77d9fcc981349c68c87e79b35 100644 --- a/cooperative_cuisine/configs/study/level2/level2_config.yaml +++ b/cooperative_cuisine/configs/study/level2/level2_config.yaml @@ -7,6 +7,7 @@ plates: game: time_limit_seconds: 300 undo_dispenser_pickup: true + validate_recipes: false meals: diff --git a/cooperative_cuisine/environment.py b/cooperative_cuisine/environment.py index c62821d34c1038e25f11037c95af131d0aa92e10..8f29213a50abf5aaf43ae06b0c7ef5b8c5fa49c3 100644 --- a/cooperative_cuisine/environment.py +++ b/cooperative_cuisine/environment.py @@ -75,7 +75,9 @@ log = logging.getLogger(__name__) class EnvironmentConfig(TypedDict): plates: PlateConfig game: dict[ - Literal["time_limit_seconds"] | Literal["undo_dispenser_pickup"], int | bool + Literal["time_limit_seconds"] | Literal["undo_dispenser_pickup"], + int | bool, + bool, ] meals: dict[Literal["all"] | Literal["list"], bool | list[str]] orders: OrderConfig @@ -228,6 +230,12 @@ class Environment: self.overwrite_counters(self.counters) + do_validation = ( + self.environment_config["game"]["validate_recipes"] + if "validate_recipes" in self.environment_config["game"].keys() + else True + ) + self.recipe_validation = Validation( meals=[m for m in self.item_info.values() if m.type == ItemType.Meal] if self.environment_config["meals"]["all"] @@ -238,10 +246,11 @@ class Environment: ], item_info=self.item_info, order_manager=self.order_manager, + do_validation=do_validation, ) meals_to_be_ordered = self.recipe_validation.validate_environment(self.counters) - # assert meals_to_be_ordered, "Need possible meals for order generation." + assert meals_to_be_ordered, "Need possible meals for order generation." available_meals = {meal: self.item_info[meal] for meal in meals_to_be_ordered} self.order_manager.set_available_meals(available_meals) diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py index c8ee856b20355daab0537105c5e9f58322dede6c..57bba2b7ca5552bc1993b8e887df84088e27ccfe 100644 --- a/cooperative_cuisine/game_server.py +++ b/cooperative_cuisine/game_server.py @@ -125,8 +125,6 @@ class EnvironmentHandler: return 1 env_id = uuid.uuid4().hex - print("GAME SERVER ALLOWED IDS:", self.allowed_manager) - env = Environment( env_config=environment_config.environment_config, layout_config=environment_config.layout_config, @@ -597,6 +595,15 @@ class PlayerRequestType(Enum): """Indicates a request to pass an action of a player to the environment.""" +class WebsocketMessage(BaseModel): + type: str + action: None | Action + player_hash: str + + class Config: + arbitrary_types_allowed = True + + def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResult | str: """Manage WebSocket Message by validating the message and passing it to the environment. @@ -610,18 +617,14 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul message_dict = json.loads(message) request_type = None try: - assert "type" in message_dict, "message needs a type" - - request_type = PlayerRequestType(message_dict["type"]) - assert ( - "player_hash" in message_dict - ), "'player_hash' key not in message dictionary'" + ws_message = WebsocketMessage(**message_dict) + request_type = PlayerRequestType(ws_message.type) match request_type: case PlayerRequestType.GET_STATE: - state = environment_handler.get_state(message_dict["player_hash"]) + state = environment_handler.get_state(ws_message.player_hash) if isinstance(state, int): return { - "request_type": message_dict["type"], + "request_type": ws_message.type, "status": 400, "msg": "env id of player not in running envs" if state == 2 @@ -630,40 +633,33 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul } return state case PlayerRequestType.READY: - accepted = environment_handler.set_player_ready( - message_dict["player_hash"] - ) + accepted = environment_handler.set_player_ready(ws_message.player_hash) return { "request_type": request_type.value, "msg": f"ready{' ' if accepted else ' not '}accepted", "status": 200 if accepted else 400, - "player_hash": message_dict["player_hash"], + "player_hash": ws_message.player_hash, } case PlayerRequestType.ACTION: - assert ( - "action" in message_dict - ), "'action' key not in message dictionary of 'action' request" - assert ( - "action_data" in message_dict["action"] - ), "'action_data' key not in message dictionary['action'] of 'action' request" - if isinstance(message_dict["action"]["action_data"], list): - message_dict["action"]["action_data"] = np.array( - message_dict["action"]["action_data"], dtype=float + assert ws_message.action is not None + if isinstance(ws_message.action.action_data, list): + ws_message.action.action_data = np.array( + ws_message.action.action_data, dtype=float ) accepted = environment_handler.player_action( - message_dict["player_hash"], Action(**message_dict["action"]) + ws_message.player_hash, ws_message.action ) return { "request_type": request_type.value, "status": 200 if accepted else 400, "msg": f"action{' ' if accepted else ' not '}accepted", - "player_hash": message_dict["player_hash"], + "player_hash": ws_message.player_hash, } return { "request_type": request_type.value, "status": 400, "msg": "request not handled", - "player_hash": message_dict["player_hash"], + "player_hash": ws_message.player_hash, } except ValueError as e: return { diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index 42753e3fc14b312243a1e88ba97c62f15f5812d2..15856bb9b153f9d5f852e34e2a98f214bbad8703 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -1,5 +1,4 @@ import argparse -import dataclasses import json import logging import os @@ -21,12 +20,15 @@ from websockets.sync.client import connect from cooperative_cuisine import ROOT_DIR from cooperative_cuisine.action import ActionType, InterActionData, Action -from cooperative_cuisine.game_server import CreateEnvironmentConfig +from cooperative_cuisine.game_server import ( + CreateEnvironmentConfig, + WebsocketMessage, + PlayerRequestType, +) from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer from cooperative_cuisine.pygame_2d_vis.game_colors import colors from cooperative_cuisine.state_representation import StateRepresentation from cooperative_cuisine.utils import ( - custom_asdict_factory, url_and_port_arguments, disable_websocket_logging_arguments, add_list_of_manager_ids_arguments, @@ -48,7 +50,9 @@ log = logging.getLogger(__name__) class PlayerKeySet: """Set of keyboard keys for controlling a player. - First four keys are for movement. Order: Down, Up, Left, Right. 5th key is for interacting with counters. 6th key ist for picking up things or dropping them. + First four keys are for movement. Order: Down, Up, Left, Right. + 5th key is for interacting with counters. + 6th key ist for picking up things or dropping them. """ def __init__( @@ -124,7 +128,7 @@ class PyGameGUI: self.participant_id = uuid.uuid4().hex - self.game_screen: pygame.Surface = None + self.game_screen: pygame.Surface | None = None self.running = True self.key_sets: list[PlayerKeySet] = [] @@ -186,6 +190,8 @@ class PyGameGUI: self.last_state: StateRepresentation + self.player_info = {"0": {"name": "0"}} + self.level_info = {"name": "Level", "recipe_graphs": []} self.last_level = False self.beeped_once = False self.all_completed_meals = [] @@ -248,7 +254,7 @@ class PyGameGUI: action = Action( current_player_name, - ActionType.MOVEMENT, + ActionType.MOVEMENT.value, move_vec, duration=self.time_delta, ) @@ -289,7 +295,7 @@ class PyGameGUI: action = Action( current_player_name, - ActionType.MOVEMENT, + ActionType.MOVEMENT.value, move_vec, duration=self.time_delta, ) @@ -1019,14 +1025,10 @@ class PyGameGUI: if self.CONNECT_WITH_STUDY_SERVER: self.send_level_done() - else: - self.stop_game_on_server("finished_button_pressed") self.disconnect_websockets() self.update_postgame_screen(self.last_state) self.update_screen_elements() - for el in self.last_completed_meals: - el.show() self.beeped_once = False def draw_game_screen_frame(self): @@ -1142,8 +1144,7 @@ class PyGameGUI: environment_config = file.read() num_players = 1 if tutorial else self.number_players - seed = 1234 - print("GUI MANAGER ID", self.manager_id) + seed = int(random.random() * 100000) creation_json = CreateEnvironmentConfig( manager_id=self.manager_id, number_players=num_players, @@ -1161,6 +1162,8 @@ class PyGameGUI: ) if env_info.status_code == 403: raise ValueError(f"Forbidden Request: {env_info.json()['detail']}") + elif env_info.status_code == 409: + print("CONFLICT") env_info = env_info.json() assert isinstance(env_info, dict), "Env info must be a dictionary" self.current_env_id = env_info["env_id"] @@ -1279,18 +1282,27 @@ class PyGameGUI: def get_game_connection(self, tutorial): if self.menu_state == MenuStates.ControllerTutorial: - self.player_info = requests.post( + answer = requests.post( f"{self.request_url}/connect_to_tutorial/{self.participant_id}" - ).json() - self.player_info = {self.player_info["player_id"]: self.player_info} - + ) + if answer.status_code == 200: + self.player_info = answer.json() + self.player_info = {self.player_info["player_id"]: self.player_info} + else: + self.menu_state = MenuStates.Start + log.warning("Could not connect to tutorial.") else: answer = requests.post( f"{self.request_url}/get_game_connection/{self.participant_id}" - ).json() - self.player_info = answer["player_info"] - self.level_info = answer["level_info"] - self.last_level = self.level_info["last_level"] + ) + if answer.status_code == 200: + answer_json = answer.json() + self.player_info = answer_json["player_info"] + self.level_info = answer_json["level_info"] + self.last_level = self.level_info["last_level"] + else: + log.warning("COULD NOT GET GAME CONNECTION") + self.menu_state = MenuStates.Start if tutorial: self.key_sets = self.setup_player_keys(["0"], 1, False) @@ -1354,11 +1366,13 @@ class PyGameGUI: if p < self.number_humans_to_be_added: # add player websockets websocket = connect(self.websocket_url + player_info["client_id"]) - websocket.send( - json.dumps( - {"type": "ready", "player_hash": player_info["player_hash"]} - ) - ) + message_dict = { + "type": PlayerRequestType.READY.value, + "action": None, + "player_hash": player_info["player_hash"], + } + ws_message = WebsocketMessage(**message_dict) + websocket.send(ws_message.json()) assert ( json.loads(websocket.recv())["status"] == 200 ), "not accepted player" @@ -1472,30 +1486,23 @@ class PyGameGUI: float(action.action_data[1]), ] - self.websockets[action.player].send( - json.dumps( - { - "type": "action", - "action": dataclasses.asdict( - action, dict_factory=custom_asdict_factory - ), - "player_hash": self.player_info[action.player]["player_hash"], - } - ) - ) + message_dict = { + "type": PlayerRequestType.ACTION.value, + "action": action, + "player_hash": self.player_info[action.player]["player_hash"], + } + ws_message = WebsocketMessage(**message_dict) + self.websockets[action.player].send(ws_message.json()) self.websockets[action.player].recv() def request_state(self): - self.websockets[self.state_player_id].send( - json.dumps( - { - "type": "get_state", - "player_hash": self.player_info[self.state_player_id][ - "player_hash" - ], - } - ) - ) + message_dict = { + "type": PlayerRequestType.GET_STATE.value, + "action": None, + "player_hash": self.player_info[self.state_player_id]["player_hash"], + } + ws_message = WebsocketMessage(**message_dict) + self.websockets[self.state_player_id].send(ws_message.json()) state = json.loads(self.websockets[self.state_player_id].recv()) return state @@ -1527,27 +1534,36 @@ class PyGameGUI: log.log(logging.INFO, "Started game, played bell sound") def start_study(self): - self.player_info = requests.post( + answer = requests.post( f"{self.request_url}/start_study/{self.participant_id}/{self.number_humans_to_be_added}" - ).json() - self.last_level = False + ) + print("START STUDY ANSWER", answer) + if answer.status_code == 200: + self.last_level = False + self.get_game_connection(tutorial=False) + else: + self.menu_state = MenuStates.Start + print( + "COULD NOT START STUDY; Response:", + answer.status_code, + answer.json()["detail"], + ) def send_level_done(self): _ = requests.post(f"{self.request_url}/level_done/{self.participant_id}").json() def button_continue_postgame_pressed(self): - if not self.CONNECT_WITH_STUDY_SERVER: + if self.CONNECT_WITH_STUDY_SERVER: + if not self.last_level: + self.get_game_connection(tutorial=False) + else: self.current_layout_idx += 1 + self.create_env_on_game_server(tutorial=False) if self.current_layout_idx == len(self.layout_file_paths) - 1: self.last_level = True else: log.debug(f"LEVEL: {self.layout_file_paths[self.current_layout_idx]}") - else: - if not self.last_level: - if self.CONNECT_WITH_STUDY_SERVER: - self.get_game_connection(tutorial=False) - else: - self.create_env_on_game_server(tutorial=False) + self.menu_state = MenuStates.PreGame def manage_button_event(self, event): @@ -1782,7 +1798,6 @@ class PyGameGUI: if self.CONNECT_WITH_STUDY_SERVER: self.send_tutorial_finished() self.start_study() - self.get_game_connection(tutorial=False) else: self.stop_game_on_server("tutorial_finished") diff --git a/cooperative_cuisine/study_server.py b/cooperative_cuisine/study_server.py index 7db833228357f9bff8cecf4e089ca3ed1a15770a..887c686c22e9282163748b4f03ee27fdfe503ec0 100644 --- a/cooperative_cuisine/study_server.py +++ b/cooperative_cuisine/study_server.py @@ -20,17 +20,19 @@ import signal import subprocess from pathlib import Path from subprocess import Popen -from typing import Tuple, TypedDict +from typing import Tuple import requests import uvicorn import yaml -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel from cooperative_cuisine import ROOT_DIR from cooperative_cuisine.environment import EnvironmentConfig -from cooperative_cuisine.game_server import CreateEnvironmentConfig -from cooperative_cuisine.server_results import PlayerInfo +from cooperative_cuisine.game_server import CreateEnvironmentConfig, EnvironmentData +from cooperative_cuisine.server_results import PlayerInfo, CreateEnvResult from cooperative_cuisine.utils import ( url_and_port_arguments, add_list_of_manager_ids_arguments, @@ -51,85 +53,104 @@ USE_AAAMBOS_AGENT = False """Use the aaambos random agents instead of the simpler python script agents.""" -class LevelConfig(TypedDict): +class LevelConfig(BaseModel): name: str config_path: str layout_path: str item_info_path: str -class LevelInfo(TypedDict): +class LevelInfo(BaseModel): name: str last_level: bool - recipes: list[str] recipe_graphs: list[dict] -class StudyConfig(TypedDict): +class StudyConfig(BaseModel): levels: list[LevelConfig] num_players: int num_bots: int -class StudyState: - def __init__(self, study_config_path: str | Path, game_url, game_port): +class Study: + def __init__(self, study_config_path: str | Path, game_url: str, game_port: int): with open(study_config_path, "r") as file: env_config_f = file.read() self.study_config: StudyConfig = yaml.load( str(env_config_f), Loader=yaml.SafeLoader ) + """Configuration for the study which layouts, env_configs and item infos are used for the study levels.""" self.levels: list[LevelConfig] = self.study_config["levels"] + """List of level configs for each of the levels which the study runs through.""" self.current_level_idx: int = 0 - - self.participant_id_to_player_info = {} - self.player_ids = {} + """Counter of which level is currently run in the config.""" + self.participant_id_to_player_info: dict[str, PlayerInfo] = {} + """A dictionary which maps participants to player infos.""" self.num_connected_players: int = 0 + """Number of currently connected players.""" - self.current_running_env = None - self.next_level_env = None - self.players_done = {} - - self.use_aaambos_agent = USE_AAAMBOS_AGENT + self.current_running_env: CreateEnvResult | None = None + """Information about the current running environment.""" + self.participants_done: dict[str, bool] = {} + """A dictionary which saves which player has sent ready.""" + self.current_config: dict | None = None + """Save current environment config""" - self.websocket_url = f"ws://{game_url}:{game_port}/ws/player/" - print("WS:", self.websocket_url) - self.sub_processes = [] + self.use_aaambos_agent: bool = USE_AAAMBOS_AGENT + """Use aaambos-agents or simple python scripts.""" - self.current_item_info = None - self.current_config = None + """Use aaambos-agents or simple python scripts.""" + self.bot_websocket_url: str = f"ws://{game_url}:{game_port}/ws/player/" + """The websocket url for the bots to use.""" + self.sub_processes: list[Popen] = [] + """Save subprocesses of the bots to be able to kill them afterwards.""" @property - def study_done(self): + def study_done(self) -> bool: return self.current_level_idx >= len(self.levels) @property - def last_level(self): + def last_level(self) -> bool: return self.current_level_idx >= len(self.levels) - 1 @property - def is_full(self): + def is_full(self) -> bool: return ( len(self.participant_id_to_player_info) == self.study_config["num_players"] ) - def can_add_participant(self, num_participants: int) -> bool: + def can_add_participants(self, num_participants: int) -> bool: + """Checks whether the number of participants fit in this study. + + Args: + num_participants: Number of participants wished to be added. + + Returns: True of the participants fit in this study, False if not. + """ filled = ( self.num_connected_players + num_participants <= self.study_config["num_players"] ) return filled and not self.is_full - def create_env(self, level): + def create_env(self, level: LevelConfig) -> EnvironmentData: + """Creates/starts an environment on the game server, + given the configuration file paths specified in the level. + + Args: + level: LevelConfig which contains the paths to the env config, layout and item info files. + + Returns: EnvironmentData which contains information about the newly created environment. + Raises: ValueError if the gameserver returned a conflict, HTTPError with 500 if the game server crashes. + """ + item_info_path = expand_path(level["item_info_path"]) layout_path = expand_path(level["layout_path"]) config_path = expand_path(level["config_path"]) with open(item_info_path, "r") as file: item_info = file.read() - self.current_item_info: EnvironmentConfig = yaml.load( - item_info, Loader=yaml.Loader - ) with open(layout_path, "r") as file: layout = file.read() with open(config_path, "r") as file: @@ -138,7 +159,6 @@ class StudyState: environment_config, Loader=yaml.Loader ) seed = int(random.random() * 1000000) - print(seed) creation_json = CreateEnvironmentConfig( manager_id=study_manager.server_manager_id, number_players=self.study_config["num_players"] @@ -156,6 +176,11 @@ class StudyState: if env_info.status_code == 403: raise ValueError(f"Forbidden Request: {env_info.json()['detail']}") + elif env_info.status_code == 500: + raise HTTPException( + status_code=500, + detail=f"Game server crashed.", + ) env_info = env_info.json() player_info = env_info["player_info"] @@ -165,10 +190,13 @@ class StudyState: return env_info def start_level(self): - level = self.levels[self.current_level_idx] - self.current_running_env = self.create_env(level) + """Starts an environment based on the current level index.""" + self.current_running_env = self.create_env(self.levels[self.current_level_idx]) def next_level(self): + """Stops the last environment, starts the next one and + remaps the participants to the new player infos. + """ requests.post( f"{study_manager.game_server_url}/manage/stop_env/", json={ @@ -191,10 +219,16 @@ class StudyState: } self.participant_id_to_player_info[participant_id] = new_player_info - for key in self.players_done: - self.players_done[key] = False + for key in self.participants_done: + self.participants_done[key] = False def add_participant(self, participant_id: str, number_players: int): + """Adds a participant to the study, one participant can control multiple players. + + Args: + participant_id: The participant id for which to register the participant. + number_players: The number of players which the participant controls. + """ player_names = [ str(self.num_connected_players + i) for i in range(number_players) ] @@ -204,32 +238,52 @@ class StudyState: } self.participant_id_to_player_info[participant_id] = player_info self.num_connected_players += number_players - return player_info - def player_finished_level(self, participant_id): - self.players_done[participant_id] = True - if all(self.players_done.values()): + def participant_finished_level(self, participant_id: str): + """Signals the server if a player has finished a level. + If all participants finished the level, the next level is started.""" + self.participants_done[participant_id] = True + if all(self.participants_done.values()): self.next_level() - def get_connection(self, participant_id: str): - player_info = self.participant_id_to_player_info[participant_id] - current_level = self.levels[self.current_level_idx] - if self.current_config["meals"]["all"]: - recipes = ["all"] + def get_connection( + self, participant_id: str + ) -> Tuple[PlayerInfo | None, LevelInfo | None]: + """Get the assigned connections to the game server for a participant. + + Args: + participant_id: The participant id which requests the connections. + + Returns: The player info for the game server connections, level name and + information if the level is the last one and which recipes are possible in the level. + Raises: HTTPException(409) if the player is not found in the dictionary keys which saves the connections. + """ + if participant_id in self.participant_id_to_player_info.keys(): + player_info = self.participant_id_to_player_info[participant_id] + current_level = self.levels[self.current_level_idx] + level_info = LevelInfo( + name=current_level["name"], + last_level=self.last_level, + recipe_graphs=self.current_running_env["recipe_graphs"], + ) + return player_info, level_info else: - recipes = self.current_config["meals"]["list"] - level_info = LevelInfo( - name=current_level["name"], - last_level=self.last_level, - recipes=recipes, - recipe_graphs=self.current_running_env["recipe_graphs"], - ) - return player_info, level_info + raise HTTPException( + status_code=409, + detail=f"Participant not registered in this study.", + ) + + def create_and_connect_bot(self, player_id: str, player_info: PlayerInfo): + """Creates and connects a bot to the current environment. - def create_and_connect_bot(self, player_id, player_info): + Args: + player_id: player id of the player the bot controls. + player_info: Connection info for the bot. + """ player_hash = player_info["player_hash"] + ws_address = self.bot_websocket_url + player_info["client_id"] print( - f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"' + f'--general_plus="agent_websocket:{ws_address};player_hash:{player_hash};agent_id:{player_id}"' ) if self.use_aaambos_agent: sub = Popen( @@ -242,7 +296,7 @@ class StudyState: str(ROOT_DIR / "configs" / "agents" / "arch_config.yml"), "--run_config", str(ROOT_DIR / "configs" / "agents" / "run_config.yml"), - f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"', + f'--general_plus="agent_websocket:{ws_address};player_hash:{player_hash};agent_id:{player_id}"', f"--instance={player_hash}", ] ), @@ -254,7 +308,7 @@ class StudyState: [ "python", str(ROOT_DIR / "configs" / "agents" / "random_agent.py"), - f'--uri {self.websocket_url + player_info["client_id"]}', + f'--uri {self.bot_websocket_url + player_info["client_id"]}', f"--player_hash {player_hash}", f"--player_id {player_id}", ] @@ -264,6 +318,7 @@ class StudyState: self.sub_processes.append(sub) def kill_bots(self): + """Terminates the subprocesses of the bots.""" for sub in self.sub_processes: try: if self.use_aaambos_agent: @@ -285,26 +340,34 @@ class StudyState: class StudyManager: - def __init__(self): - self.game_host: str | None = None - self.game_port: str | None = None - self.game_server_url: str | None = None - self.server_manager_id: str | None = None - - self.running_studies: list[StudyState] = [] + """Class which manages different studies, their creation and connecting participants to them.""" - self.participant_id_to_study_map: dict[str, StudyState] = {} - self.running_envs: dict[str, Tuple[int, dict[str, PlayerInfo], list[str]]] = {} - self.current_free_envs = [] + def __init__(self): + self.game_host: str + """Host address of the game server where the studies are running their environments.""" + self.game_port: int + """Port of the game server where the studies are running their environments.""" + self.game_server_url: str + """Combined URL of the game server where the studies are running their environments.""" + self.server_manager_id: str + """Manager id of this manager which will be registered in the game server.""" + self.running_studies: list[Study] = [] + """List of currently running studies.""" + + self.participant_id_to_study_map: dict[str, Study] = {} + """Dict which maps participants to studies.""" self.running_tutorials: dict[ str, Tuple[int, dict[str, PlayerInfo], list[str]] ] = {} + """Dict which saves currently running tutorial envs, as these do not need advanced player management.""" self.study_config_path = ROOT_DIR / "configs" / "study" / "study_config.yml" + """Path to the configuration file for the studies.""" def create_study(self): - study = StudyState( + """Creates a study with the path of the config files and the connection to the game server.""" + study = Study( self.study_config_path, self.game_host, self.game_port, @@ -313,36 +376,89 @@ class StudyManager: self.running_studies.append(study) def add_participant(self, participant_id: str, number_players: int): - player_info = None + """Adds participants to a study. Creates a new study if all other + studies have not enough free player slots + Args: + participant_id: ID of the participant which wants to connect to a study. + number_players: The number of player the participant wants to connect. + + Raises: HTTPException(409) if the participants requests more players than can fit in a study. + """ + if not self.running_studies or all( - [not s.can_add_participant(number_players) for s in self.running_studies] + [not s.can_add_participants(number_players) for s in self.running_studies] ): self.create_study() for study in self.running_studies: - if study.can_add_participant(number_players): - player_info = study.add_participant(participant_id, number_players) + if study.can_add_participants(number_players): + study.add_participant(participant_id, number_players) self.participant_id_to_study_map[participant_id] = study - return player_info + return + raise HTTPException(status_code=409, detail="Too many participants to add.") def player_finished_level(self, participant_id: str): - assigned_study = self.participant_id_to_study_map[participant_id] - assigned_study.player_finished_level(participant_id) + """A participant signals the study manager that they finished a level. - def get_participant_game_connection(self, participant_id: str): - assigned_study = self.participant_id_to_study_map[participant_id] - player_info, level_info = assigned_study.get_connection(participant_id) - return player_info, level_info + Args: + participant_id: ID of the participant. - def set_game_server_url(self, game_host, game_port): + Raises: HTTPException(409) if this participant is not registered in any study. + + """ + if participant_id in self.participant_id_to_study_map.keys(): + assigned_study = self.participant_id_to_study_map[participant_id] + assigned_study.participant_finished_level(participant_id) + else: + raise HTTPException(status_code=409, detail="Participant not in any study.") + + def get_participant_game_connection( + self, participant_id: str + ) -> Tuple[PlayerInfo, LevelInfo]: + """Get the assigned connections to the game server for a participant. + + Args: + participant_id: ID of the participant. + + Returns: The player info for the game server connections, level name and + information if the level is the last one and which recipes are possible in the level. + Raises: HTTPException(409) if the player not registered in any study. + """ + if participant_id in self.participant_id_to_study_map.keys(): + assigned_study = self.participant_id_to_study_map[participant_id] + player_info, level_info = assigned_study.get_connection(participant_id) + return player_info, level_info + else: + raise HTTPException(status_code=409, detail="Participant not in any study.") + + def set_game_server_url(self, game_host: str, game_port: int): + """Set the game server host address, port and combined url. These values are set this way because + the fastapi requests act on top level of the python script. + + Args: + game_host: The game server host address. + game_port: The game server port. + """ self.game_host = game_host self.game_port = game_port self.game_server_url = f"http://{self.game_host}:{self.game_port}" def set_manager_id(self, manager_id: str): + """Set the manager id of the study server. This value is set this way because + the fastapi requests act on top level of the python script. + + Args: + manager_id: Manager ID for this study manager so that it matches in the game server. + """ self.server_manager_id = manager_id def set_study_config(self, study_config_path: str): + """Set the study config path of the study server. This value is set this way because + the fastapi requests act on top level of the python script. + + Args: + study_config_path: Path to the study config file for the studies. + """ # TODO validate study_config? self.study_config_path = study_config_path @@ -352,17 +468,40 @@ study_manager = StudyManager() @app.post("/start_study/{participant_id}/{number_players}") async def start_study(participant_id: str, number_players: int): - player_info = study_manager.add_participant(participant_id, number_players) - return player_info + """Request to start a study. + + Args: + participant_id: ID of the requesting participant. + number_players: Number of player the participant wants to add to a study. + + """ + log.debug(f"ADDING PLAYERS: {number_players}") + study_manager.add_participant(participant_id, number_players) @app.post("/level_done/{participant_id}") async def level_done(participant_id: str): - last_level = study_manager.player_finished_level(participant_id) + """Request to signal that a participant has finished a level. + For synchronizing level endings and starting a new level. + + Args: + participant_id: ID of the requesting participant. + """ + study_manager.player_finished_level(participant_id) @app.post("/get_game_connection/{participant_id}") -async def get_game_connection(participant_id: str): +async def get_game_connection( + participant_id: str, +) -> dict[str, dict[str, PlayerInfo] | LevelInfo]: + """Request to get the connection to the game server of a participant. + + Args: + participant_id: ID of the requesting participant. + + Returns: A dict containing the game server connection information and information about the current level. + + """ player_info, level_info = study_manager.get_participant_game_connection( participant_id ) @@ -370,7 +509,16 @@ async def get_game_connection(participant_id: str): @app.post("/connect_to_tutorial/{participant_id}") -async def want_to_play_tutorial(participant_id: str): +async def connect_to_tutorial(participant_id: str) -> JSONResponse: + """Request of a participant to start a tutorial env and connect to it. + + Args: + participant_id: ID of the requesting participant. + Returns: Player info which contains game server connection information. + Raises: + HTTPException(403) if the game server returns 403 + HTTPException(500) if the game server returns 500 + """ environment_config_path = ROOT_DIR / "configs" / "tutorial_env_config.yaml" layout_path = ROOT_DIR / "configs" / "layouts" / "tutorial.layout" item_info_path = ROOT_DIR / "configs" / "item_info.yaml" @@ -382,7 +530,6 @@ async def want_to_play_tutorial(participant_id: str): with open(environment_config_path, "r") as file: environment_config = file.read() - print("STUDY MANAGER ID", study_manager.server_manager_id) creation_json = CreateEnvironmentConfig( manager_id=study_manager.server_manager_id, number_players=1, @@ -396,17 +543,33 @@ async def want_to_play_tutorial(participant_id: str): env_info = requests.post( study_manager.game_server_url + "/manage/create_env/", json=creation_json ) - - if env_info.status_code == 403: - raise ValueError(f"Forbidden Request: {env_info.json()['detail']}") - env_info = env_info.json() - study_manager.running_tutorials[participant_id] = env_info - return env_info["player_info"]["0"] + match env_info.status_code: + case 200: + env_info = env_info.json() + study_manager.running_tutorials[participant_id] = env_info + return JSONResponse(content=env_info["player_info"]["0"]) + case 403: + raise HTTPException( + status_code=403, + detail=f"Forbidden Request: {env_info.json()['detail']}", + ) + case 500: + raise HTTPException( + status_code=500, + detail=f"Game server crashed: {env_info.json()['detail']}", + ) @app.post("/disconnect_from_tutorial/{participant_id}") -async def want_to_play_tutorial(participant_id: str): - requests.post( +async def disconnect_from_tutorial(participant_id: str): + """A participant disconnects from a tutorial environment, which is then stopped on the game server. + + Args: + participant_id: The participant which disconnects from the tutorial. + + Raises: HTTPException(503) if the game server returns some error. + """ + answer = requests.post( f"{study_manager.game_server_url}/manage/stop_env/", json={ "manager_id": study_manager.server_manager_id, @@ -414,6 +577,10 @@ async def want_to_play_tutorial(participant_id: str): "reason": "Finished tutorial", }, ) + if answer.status_code != 200: + raise HTTPException( + status_code=503, detail="Could not disconnect from tutorial" + ) def main(study_host, study_port, game_host, game_port, manager_ids, study_config_path): @@ -434,7 +601,8 @@ if __name__ == "__main__": parser = argparse.ArgumentParser( prog="Cooperative Cuisine Study Server", description="Study Server: Match Making, client pre and post managing.", - epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html", + epilog="For further information, " + "see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html", ) url_and_port_arguments( parser=parser, diff --git a/cooperative_cuisine/validation.py b/cooperative_cuisine/validation.py index 07c4f6617b45f87af38de69bd7f1423a8db9e4a2..ff14a88f70e2bf32d1108e2ba9025515c2e33dad 100644 --- a/cooperative_cuisine/validation.py +++ b/cooperative_cuisine/validation.py @@ -1,6 +1,5 @@ import os import warnings -from concurrent.futures import ThreadPoolExecutor from typing import TypedDict, Tuple, Iterator import networkx as nx @@ -25,10 +24,17 @@ class MealGraphDict(TypedDict): class Validation: - def __init__(self, meals, item_info, order_manager): + def __init__( + self, + meals: list[ItemInfo], + item_info: dict[str, ItemInfo], + order_manager: OrderManager, + do_validation: bool, + ): self.meals: list[ItemInfo] = meals self.item_info: dict[str, ItemInfo] = item_info self.order_manager: OrderManager = order_manager + self.do_validation: bool = do_validation @staticmethod def infer_recipe_graph(item_info) -> DiGraph: @@ -243,40 +249,46 @@ class Validation: return layout_requirements def validate_environment(self, counters: list[Counter]): - graph = self.infer_recipe_graph(self.item_info) - os.makedirs(ROOT_DIR / "generated", exist_ok=True) - nx.nx_agraph.to_agraph(graph).draw( - ROOT_DIR / "generated" / "recipe_graph.png", format="png", prog="dot" - ) + if self.do_validation: + graph = self.infer_recipe_graph(self.item_info) + os.makedirs(ROOT_DIR / "generated", exist_ok=True) + nx.nx_agraph.to_agraph(graph).draw( + ROOT_DIR / "generated" / "recipe_graph.png", format="png", prog="dot" + ) - expected = self.get_item_info_requirements() - present = self.get_layout_requirements(counters) - possible_meals = set(meal for meal in expected if expected[meal] <= present) - defined_meals = set(map(lambda i: i.name, self.meals)) + expected = self.get_item_info_requirements() + present = self.get_layout_requirements(counters) + possible_meals = set(meal for meal in expected if expected[meal] <= present) + defined_meals = set(map(lambda i: i.name, self.meals)) - # print(f"Ordered meals: {defined_meals}, Possible meals: {possible_meals}") - if len(defined_meals - possible_meals) > 0: - warnings.warn( - f"Ordered meals are not possible: {defined_meals - possible_meals}" - ) + # print(f"Ordered meals: {defined_meals}, Possible meals: {possible_meals}") + if len(defined_meals - possible_meals) > 0: + warnings.warn( + f"Ordered meals are not possible: {defined_meals - possible_meals}" + ) - meals_to_be_ordered = possible_meals.intersection(defined_meals) - return meals_to_be_ordered - # print("FINAL MEALS:", meals_to_be_ordered) + meals_to_be_ordered = possible_meals.intersection(defined_meals) + return meals_to_be_ordered + else: + return {m.name for m in self.meals} def get_recipe_graphs(self) -> list[MealGraphDict]: if not self.order_manager.available_meals: return [] os.makedirs(ROOT_DIR / "generated", exist_ok=True) - # time_start = time.time() - with ThreadPoolExecutor( - max_workers=len(self.order_manager.available_meals) - ) as executor: - graph_dicts = list( - executor.map( - self.get_meal_graph, self.order_manager.available_meals.values() - ) - ) - # print("DURATION", time.time() - time_start) - return graph_dicts + return [ + self.get_meal_graph(m) for m in self.order_manager.available_meals.values() + ] + + # # time_start = time.time() + # with ThreadPoolExecutor( + # max_workers=len(self.order_manager.available_meals) + # ) as executor: + # graph_dicts = list( + # executor.map( + # self.get_meal_graph, self.order_manager.available_meals.values() + # ) + # ) + # # print("DURATION", time.time() - time_start) + # return graph_dicts diff --git a/setup.py b/setup.py index bafd4edde0be2f57fef2f6ab094a643f2731b421..bd15680e42e22ed732f2ba8f9056aef547fb163c 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ requirements = [ "pytest>=3", "pyyaml>=6.0.1", "pygame-gui>=0.6.9", - "pydantic>=2.5.3", + "pydantic>=2.6.3", "fastapi>=0.109.2", "uvicorn>=0.27.0", "websockets>=12.0",