diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py index cf6db571881e17173a23f042a6c57ce4814f76fd..7d5d3b5bb152cb57b20036cba8c4416a129c6b87 100644 --- a/cooperative_cuisine/pygame_2d_vis/gui.py +++ b/cooperative_cuisine/pygame_2d_vis/gui.py @@ -1536,6 +1536,7 @@ class PyGameGUI: answer = requests.post( f"{self.request_url}/start_study/{self.participant_id}/{self.number_humans_to_be_added}" ) + print("START STUDY ANSWER", answer) if answer.status_code == 200: self.last_level = False self.get_game_connection(tutorial=False) diff --git a/cooperative_cuisine/study_server.py b/cooperative_cuisine/study_server.py index 31b9d3a0589daa7db7cb4e8a7816d678834e19c3..e139aab5210929f5d79ff60576bb301e78f87738 100644 --- a/cooperative_cuisine/study_server.py +++ b/cooperative_cuisine/study_server.py @@ -27,7 +27,7 @@ import uvicorn import yaml from fastapi import FastAPI, HTTPException from fastapi.responses import JSONResponse -from typing_extensions import TypedDict +from pydantic import BaseModel from cooperative_cuisine import ROOT_DIR from cooperative_cuisine.environment import EnvironmentConfig @@ -52,21 +52,21 @@ app = FastAPI() USE_AAAMBOS_AGENT = False -class LevelConfig(TypedDict): +class LevelConfig(BaseModel): name: str config_path: str layout_path: str item_info_path: str -class LevelInfo(TypedDict): +class LevelInfo(BaseModel): name: str last_level: bool recipes: list[str] recipe_graphs: list[dict] -class StudyConfig(TypedDict): +class StudyConfig(BaseModel): levels: list[LevelConfig] num_players: int num_bots: int @@ -212,7 +212,6 @@ class StudyState: } self.participant_id_to_player_info[participant_id] = player_info self.num_connected_players += number_players - return player_info def player_finished_level(self, participant_id): self.players_done[participant_id] = True @@ -333,9 +332,9 @@ class StudyManager: for study in self.running_studies: if study.can_add_participants(number_players): - player_info = study.add_participant(participant_id, number_players) + study.add_participant(participant_id, number_players) self.participant_id_to_study_map[participant_id] = study - return player_info + return raise HTTPException(status_code=409, detail="Could not add participant(s).") def player_finished_level(self, participant_id: str): @@ -374,24 +373,24 @@ study_manager = StudyManager() @app.post("/start_study/{participant_id}/{number_players}") -async def start_study(participant_id: str, number_players: int) -> JSONResponse: +async def start_study(participant_id: str, number_players: int): log.debug(f"ADDING PLAYERS: {number_players}") - player_info = study_manager.add_participant(participant_id, number_players) - return JSONResponse(content=player_info) + study_manager.add_participant(participant_id, number_players) @app.post("/level_done/{participant_id}") -async def level_done(participant_id: str) -> JSONResponse: +async def level_done(participant_id: str): study_manager.player_finished_level(participant_id) - return JSONResponse(content="Ok") @app.post("/get_game_connection/{participant_id}") -async def get_game_connection(participant_id: str) -> JSONResponse: +async def get_game_connection( + participant_id: str, +) -> dict[str, dict[str, PlayerInfo] | LevelInfo]: player_info, level_info = study_manager.get_participant_game_connection( participant_id ) - return JSONResponse(content={"player_info": player_info, "level_info": level_info}) + return {"player_info": player_info, "level_info": level_info} @app.post("/connect_to_tutorial/{participant_id}") @@ -438,7 +437,7 @@ async def connect_to_tutorial(participant_id: str) -> JSONResponse: @app.post("/disconnect_from_tutorial/{participant_id}") -async def disconnect_from_tutorial(participant_id: str) -> JSONResponse: +async def disconnect_from_tutorial(participant_id: str): answer = requests.post( f"{study_manager.game_server_url}/manage/stop_env/", json={ @@ -447,9 +446,7 @@ async def disconnect_from_tutorial(participant_id: str) -> JSONResponse: "reason": "Finished tutorial", }, ) - if answer.status_code == 200: - return JSONResponse(content="Ok") - else: + if answer.status_code != 200: raise HTTPException( status_code=403, detail="Could not disconnect from tutorial" )