diff --git a/cooperative_cuisine/configs/study/study_config.yaml b/cooperative_cuisine/configs/study/study_config.yaml index 0c2ef412c98b6497e7b70d3c355157557b928ddf..d09fb9929d1a256a4399cb7b459d454cb5abf807 100644 --- a/cooperative_cuisine/configs/study/study_config.yaml +++ b/cooperative_cuisine/configs/study/study_config.yaml @@ -52,3 +52,5 @@ levels: num_players: 1 num_bots: 0 + +study_log_path: USER_LOG_DIR/ENV_NAME/ \ No newline at end of file diff --git a/cooperative_cuisine/study_server.py b/cooperative_cuisine/study_server.py index 27896e71b25d73424f49e90a8fe6aa6ddd5afd23..d20a46c9825f60050e1a4515ba0a8be8c801dbdc 100644 --- a/cooperative_cuisine/study_server.py +++ b/cooperative_cuisine/study_server.py @@ -13,6 +13,7 @@ The environment starts when all players connected. import argparse import asyncio +import json import logging import os import random @@ -26,7 +27,7 @@ from typing import Tuple, Any import requests import uvicorn import yaml -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request from pydantic import BaseModel from cooperative_cuisine import ROOT_DIR @@ -294,7 +295,7 @@ class Study: self.next_level() def get_connection( - self, participant_id: str + self, participant_id: str, participant_host: str ) -> Tuple[dict[str, PlayerInfo] | None, LevelInfo | None]: """Get the assigned connections to the game server for a participant. @@ -318,6 +319,22 @@ class Study: number_players=len(self.current_running_env["player_info"]), kitchen_size=self.current_running_env["kitchen_size"], ) + log_path = expand_path( + self.study_config["study_log_path"], + env_name=self.current_running_env["env_id"], + ) + os.makedirs(log_path, exist_ok=True) + with open(Path(log_path) / "study_log", "a") as log_file: + log_file.write( + json.dumps( + { + "env_id": self.current_running_env["env_id"], + "participant_ip": participant_host, + "level_info": level_info.dict(), + "player_info": player_info, + } + ) + ) return player_info, level_info else: raise HTTPException( @@ -466,7 +483,7 @@ class StudyManager: raise HTTPException(status_code=409, detail="Participant not in any study.") def get_participant_game_connection( - self, participant_id: str + self, participant_id: str, participant_host: str ) -> Tuple[dict[str, PlayerInfo], LevelInfo]: """Get the assigned connections to the game server for a participant. @@ -494,7 +511,10 @@ class StudyManager: if participant_id in self.participant_id_to_study_map.keys(): assigned_study = self.participant_id_to_study_map[participant_id] - player_info, level_info = assigned_study.get_connection(participant_id) + player_info, level_info = assigned_study.get_connection( + participant_id, participant_host + ) + return player_info, level_info else: raise HTTPException(status_code=409, detail="Participant not in any study.") @@ -622,6 +642,7 @@ async def level_done(participant_id: str): @app.post("/get_game_connection/{participant_id}") async def get_game_connection( participant_id: str, + request: Request, ) -> dict[str, dict[str, PlayerInfo] | LevelInfo]: """Request to get the connection to the game server of a participant. @@ -633,7 +654,7 @@ async def get_game_connection( """ player_info, level_info = study_manager.get_participant_game_connection( - participant_id + participant_id, request.client.host ) return {"player_info": player_info, "level_info": level_info}