From 19a9fc302f5ba235e22adf3a1680e45ee98cee76 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Thu, 29 Feb 2024 17:29:16 +0100
Subject: [PATCH] Update configurations, path handling, and study server logic

Made updates across multiple files targeting the usage and setting of configuration files. Several path settings have been altered, and the logic in the study server has been revised. Changes were also made in the file imports and other operations related to configurations and paths. Additionally, a validation check is added to warn if a player with the same name already exists in the environment.
---
 cooperative_cuisine/__init__.py               |  1 +
 cooperative_cuisine/__main__.py               |  3 ++
 .../configs/study/level1/level1_config.yaml   |  2 +-
 .../configs/study/study_config.yaml           | 22 ++++-----
 cooperative_cuisine/environment.py            | 16 +++----
 cooperative_cuisine/recording.py              | 19 ++------
 cooperative_cuisine/study_server.py           | 46 ++++++++++++-------
 cooperative_cuisine/utils.py                  | 26 ++++++++++-
 8 files changed, 79 insertions(+), 56 deletions(-)

diff --git a/cooperative_cuisine/__init__.py b/cooperative_cuisine/__init__.py
index b733be00..23cdcc74 100644
--- a/cooperative_cuisine/__init__.py
+++ b/cooperative_cuisine/__init__.py
@@ -371,6 +371,7 @@ websockets,
 - the **orders**, how to sample incoming orders and their attributes,
 - the **player**/agent, that interacts in the environment,
 - the **pygame 2d visualization**, GUI, drawing, and video generation,
+- the **recipe** validation and graph generation,
 - the **recording**, via hooks, actions, environment configs, states, etc. can be recorded in files,
 - the **scores**, via hooks, events can affect the scores,
 - type hints are defined in **state representation** for the json state and **server results** for the data returned by
diff --git a/cooperative_cuisine/__main__.py b/cooperative_cuisine/__main__.py
index 6c067da2..257ff656 100644
--- a/cooperative_cuisine/__main__.py
+++ b/cooperative_cuisine/__main__.py
@@ -6,6 +6,7 @@ from cooperative_cuisine.utils import (
     url_and_port_arguments,
     disable_websocket_logging_arguments,
     add_list_of_manager_ids_arguments,
+    add_study_arguments,
 )
 
 USE_STUDY_SERVER = True
@@ -26,6 +27,7 @@ def start_study_server(cli_args):
         game_host=cli_args.game_url,
         game_port=cli_args.game_port,
         manager_ids=cli_args.manager_ids,
+        study_config_path=cli_args.study_config,
     )
 
 
@@ -53,6 +55,7 @@ def main(cli_args=None):
     url_and_port_arguments(parser)
     disable_websocket_logging_arguments(parser)
     add_list_of_manager_ids_arguments(parser)
+    add_study_arguments(parser)
 
     cli_args = parser.parse_args()
 
diff --git a/cooperative_cuisine/configs/study/level1/level1_config.yaml b/cooperative_cuisine/configs/study/level1/level1_config.yaml
index 1751bfa0..6732fa29 100644
--- a/cooperative_cuisine/configs/study/level1/level1_config.yaml
+++ b/cooperative_cuisine/configs/study/level1/level1_config.yaml
@@ -9,7 +9,7 @@ game:
   undo_dispenser_pickup: true
 
 meals:
-  all: true
+  all: false
   # if all: false -> only orders for these meals are generated
   # TODO: what if this list is empty?
   list:
diff --git a/cooperative_cuisine/configs/study/study_config.yaml b/cooperative_cuisine/configs/study/study_config.yaml
index 0b6a0bb6..eb241e93 100644
--- a/cooperative_cuisine/configs/study/study_config.yaml
+++ b/cooperative_cuisine/configs/study/study_config.yaml
@@ -1,21 +1,17 @@
-# Config paths are relative to configs folder.
-# Layout files are relative to layouts folder.
-
-
 levels:
-  - config_path: study/level1/level1_config.yaml
-    layout_path: basic.layout
-    item_info_path: study/level1/level1_item_info.yaml
+  - config_path: STUDY_DIR/level1/level1_config.yaml
+    layout_path: LAYOUTS_DIR/overcooked-1/1-1-far-apart.layout
+    item_info_path: STUDY_DIR/level1/level1_item_info.yaml
     name: "Level 1-1: Far Apart"
 
-  - config_path: environment_config.yaml
-    layout_path: basic.layout
-    item_info_path: item_info.yaml
+  - config_path: CONFIGS_DIR/environment_config.yaml
+    layout_path: LAYOUTS_DIR/basic.layout
+    item_info_path: CONFIGS_DIR/item_info.yaml
     name: "Basic"
 
-  - config_path: study/level2/level2_config.yaml
-    layout_path: overcooked-1/1-4-bottleneck.layout
-    item_info_path: study/level2/level2_item_info.yaml
+  - config_path: STUDY_DIR/level2/level2_config.yaml
+    layout_path: LAYOUTS_DIR/overcooked-1/1-4-bottleneck.layout
+    item_info_path: STUDY_DIR/level2/level2_item_info.yaml
     name: "Level 1-4: Bottleneck"
 
 
diff --git a/cooperative_cuisine/environment.py b/cooperative_cuisine/environment.py
index 909db81e..169565e0 100644
--- a/cooperative_cuisine/environment.py
+++ b/cooperative_cuisine/environment.py
@@ -167,6 +167,8 @@ class Environment:
                 env_config = file.read()
             with open(layout_config, "r") as layout_file:
                 layout_config = layout_file.read()
+            with open(item_info, "r") as file:
+                item_info = file.read()
 
         self.environment_config: EnvironmentConfig = yaml.load(
             env_config, Loader=yaml.Loader
@@ -192,7 +194,7 @@ class Environment:
 
         self.item_info: dict[str, ItemInfo] = self.load_item_info(item_info)
         """The loaded item info dict. Keys are the item names."""
-        self.hook(ITEM_INFO_LOADED, item_info=item_info, as_files=as_files)
+        self.hook(ITEM_INFO_LOADED, item_info=item_info)
 
         # self.validate_item_info()
         if self.environment_config["meals"]["all"]:
@@ -350,13 +352,10 @@ class Environment:
         Utility method to pass a reference to the serving window."""
         return self.env_time
 
-    def load_item_info(self, data) -> dict[str, ItemInfo]:
+    def load_item_info(self, item_info) -> dict[str, ItemInfo]:
         """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos."""
-        if self.as_files:
-            with open(data, "r") as file:
-                data = file.read()
-        self.hook(ITEM_INFO_CONFIG, item_info_config=data)
-        item_lookup = yaml.safe_load(data)
+        self.hook(ITEM_INFO_CONFIG, item_info_config=item_info)
+        item_lookup = yaml.safe_load(item_info)
         for item_name in item_lookup:
             item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name])
 
@@ -408,7 +407,8 @@ class Environment:
             player_name: The id/name of the player to reference actions and in the state.
             pos: The optional init position of the player.
         """
-        # TODO check if the player name already exists in the environment and do not overwrite player.
+        if player_name in self.players:
+            raise ValueError(f"Player {player_name} already exists.")
         log.debug(f"Add player {player_name} to the game")
         player = Player(
             player_name,
diff --git a/cooperative_cuisine/recording.py b/cooperative_cuisine/recording.py
index 79f98fb9..01f0de86 100644
--- a/cooperative_cuisine/recording.py
+++ b/cooperative_cuisine/recording.py
@@ -46,12 +46,9 @@ import os
 import traceback
 from pathlib import Path
 
-import platformdirs
-
-from cooperative_cuisine import ROOT_DIR
 from cooperative_cuisine.environment import Environment
 from cooperative_cuisine.hooks import HookCallbackClass
-from cooperative_cuisine.utils import NumpyAndDataclassEncoder
+from cooperative_cuisine.utils import NumpyAndDataclassEncoder, expand_path
 
 log = logging.getLogger(__name__)
 
@@ -80,18 +77,8 @@ class FileRecorder(HookCallbackClass):
     ):
         super().__init__(name, env, **kwargs)
         self.add_hook_ref = add_hook_ref
-        log_path = log_path.replace("ENV_NAME", env.env_name).replace(
-            "LOG_RECORD_NAME", name
-        )
-        if log_path.startswith("USER_LOG_DIR/"):
-            log_path = (
-                Path(platformdirs.user_log_dir("cooperative_cuisine"))
-                / log_path[len("USER_LOG_DIR/") :]
-            )
-        elif log_path.startswith("ROOT_DIR/"):
-            log_path = ROOT_DIR / log_path[len("ROOT_DIR/") :]
-        else:
-            log_path = Path(log_path)
+        log_path = log_path.replace("LOG_RECORD_NAME", name)
+        log_path = Path(expand_path(log_path, env_name=env.env_name))
         self.log_path = log_path
         log.info(f"Recorder record for {name} in file://{log_path}")
         os.makedirs(log_path.parent, exist_ok=True)
diff --git a/cooperative_cuisine/study_server.py b/cooperative_cuisine/study_server.py
index 3319cdde..68a3fee4 100644
--- a/cooperative_cuisine/study_server.py
+++ b/cooperative_cuisine/study_server.py
@@ -34,6 +34,8 @@ from cooperative_cuisine.server_results import PlayerInfo
 from cooperative_cuisine.utils import (
     url_and_port_arguments,
     add_list_of_manager_ids_arguments,
+    expand_path,
+    add_study_arguments,
 )
 
 NUMBER_PLAYER_PER_ENV = 2
@@ -87,12 +89,15 @@ class StudyState:
         self.next_level_env = None
         self.players_done = {}
 
-        self.USE_AAAMBOS_AGENT = False
+        self.use_aaambos_agent = False
 
         self.websocket_url = f"ws://{game_url}:{game_port}/ws/player/"
         print("WS:", self.websocket_url)
         self.sub_processes = []
 
+        self.current_item_info = None
+        self.current_config = None
+
     @property
     def study_done(self):
         return self.current_level_idx >= len(self.levels)
@@ -115,14 +120,18 @@ class StudyState:
         return filled and not self.is_full
 
     def create_env(self, level):
-        with open(ROOT_DIR / "configs" / level["item_info_path"], "r") as file:
+        item_info_path = expand_path(level["item_info_path"])
+        layout_path = expand_path(level["layout_path"])
+        config_path = expand_path(level["config_path"])
+
+        with open(item_info_path, "r") as file:
             item_info = file.read()
             self.current_item_info: EnvironmentConfig = yaml.load(
                 item_info, Loader=yaml.Loader
             )
-        with open(ROOT_DIR / "configs" / "layouts" / level["layout_path"], "r") as file:
+        with open(layout_path, "r") as file:
             layout = file.read()
-        with open(ROOT_DIR / "configs" / level["config_path"], "r") as file:
+        with open(config_path, "r") as file:
             environment_config = file.read()
             self.current_config: EnvironmentConfig = yaml.load(
                 environment_config, Loader=yaml.Loader
@@ -154,7 +163,7 @@ class StudyState:
                 self.create_and_connect_bot(player_id, player_info)
         return env_info
 
-    def start(self):
+    def start_level(self):
         level = self.levels[self.current_level_idx]
         self.current_running_env = self.create_env(level)
 
@@ -170,8 +179,7 @@ class StudyState:
 
         self.current_level_idx += 1
         if not self.study_done:
-            level = self.levels[self.current_level_idx]
-            self.current_running_env = self.create_env(level)
+            self.start_level()
             for (
                 participant_id,
                 player_info,
@@ -199,8 +207,7 @@ class StudyState:
 
     def player_finished_level(self, participant_id):
         self.players_done[participant_id] = True
-        level_done = all(self.players_done.values())
-        if level_done:
+        if all(self.players_done.values()):
             self.next_level()
 
     def get_connection(self, participant_id: str):
@@ -223,7 +230,7 @@ class StudyState:
         print(
             f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"'
         )
-        if self.USE_AAAMBOS_AGENT:
+        if self.use_aaambos_agent:
             sub = Popen(
                 " ".join(
                     [
@@ -258,7 +265,7 @@ class StudyState:
     def kill_bots(self):
         for sub in self.sub_processes:
             try:
-                if self.USE_AAAMBOS_AGENT:
+                if self.use_aaambos_agent:
                     pgrp = os.getpgid(sub.pid)
                     os.killpg(pgrp, signal.SIGINT)
                     subprocess.run(
@@ -271,8 +278,6 @@ class StudyState:
                 pass
 
         self.sub_processes = []
-        for websocket in self.websockets.values():
-            websocket.close()
 
     def __repr__(self):
         return f"Study({self.current_running_env['env_id']})"
@@ -295,13 +300,15 @@ class StudyManager:
             str, Tuple[int, dict[str, PlayerInfo], list[str]]
         ] = {}
 
+        self.study_config_path = ROOT_DIR / "configs" / "study" / "study_config.yml"
+
     def create_study(self):
         study = StudyState(
-            ROOT_DIR / "configs" / "study" / "study_config.yaml",
+            self.study_config_path,
             self.game_host,
             self.game_port,
         )
-        study.start()
+        study.start_level()
         self.running_studies.append(study)
 
     def add_participant(self, participant_id: str, number_players: int):
@@ -334,6 +341,10 @@ class StudyManager:
     def set_manager_id(self, manager_id: str):
         self.server_manager_id = manager_id
 
+    def set_study_config(self, study_config_path: str):
+        # TODO validate study_config?
+        self.study_config_path = study_config_path
+
 
 study_manager = StudyManager()
 
@@ -404,9 +415,10 @@ async def want_to_play_tutorial(participant_id: str):
     )
 
 
-def main(study_host, study_port, game_host, game_port, manager_ids):
+def main(study_host, study_port, game_host, game_port, manager_ids, study_config_path):
     study_manager.set_game_server_url(game_host=game_host, game_port=game_port)
     study_manager.set_manager_id(manager_id=manager_ids[0])
+    study_manager.set_study_config(study_config_path=study_config_path)
 
     print(
         f"Use {study_manager.server_manager_id=} for game_server_url=http://{game_host}:{game_port}"
@@ -430,6 +442,7 @@ if __name__ == "__main__":
         default_game_port=8000,
     )
     add_list_of_manager_ids_arguments(parser=parser)
+    add_study_arguments(parser=parser)
     args = parser.parse_args()
 
     game_server_url = f"https://{args.game_url}:{args.game_port}"
@@ -439,4 +452,5 @@ if __name__ == "__main__":
         game_host=args.game_url,
         game_port=args.game_port,
         manager_ids=args.manager_ids,
+        study_config_path=args.study_config,
     )
diff --git a/cooperative_cuisine/utils.py b/cooperative_cuisine/utils.py
index 2d0a7dc3..7d57d5c6 100644
--- a/cooperative_cuisine/utils.py
+++ b/cooperative_cuisine/utils.py
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING
 
 import numpy as np
 import numpy.typing as npt
+import platformdirs
 from scipy.spatial import distance_matrix
 
 from cooperative_cuisine import ROOT_DIR
@@ -27,6 +28,17 @@ from cooperative_cuisine.player import Player
 DEFAULT_SERVER_URL = "localhost"
 
 
+def expand_path(path: str, env_name: str = "") -> str:
+    return os.path.expanduser(
+        path.replace("ROOT_DIR", str(ROOT_DIR))
+        .replace("ENV_NAME", env_name)
+        .replace("USER_LOG_DIR", platformdirs.user_log_dir("cooperative_cuisine"))
+        .replace("LAYOUTS_DIR", str(ROOT_DIR / "configs" / "layouts"))
+        .replace("STUDY_DIR", str(ROOT_DIR / "configs" / "study"))
+        .replace("CONFIGS_DIR", str(ROOT_DIR / "configs"))
+    )
+
+
 @dataclasses.dataclass
 class VectorStateGenerationData:
     grid_base_array: npt.NDArray[npt.NDArray[float]]
@@ -179,8 +191,9 @@ def setup_logging(enable_websocket_logging=False):
 def url_and_port_arguments(
     parser, server_name="game server", default_study_port=8080, default_game_port=8000
 ):
+    # TODO follow cli args standards: https://askubuntu.com/questions/711702/when-are-command-options-prefixed-with-two-hyphens
     parser.add_argument(
-        "-study-url",
+        "-study",
         "--study-url",
         "--study-host",
         type=str,
@@ -195,7 +208,7 @@ def url_and_port_arguments(
         help=f"Port number for the {server_name}",
     )
     parser.add_argument(
-        "-game-url",
+        "-game",
         "--game-url",
         "--game-host",
         type=str,
@@ -228,6 +241,15 @@ def add_list_of_manager_ids_arguments(parser):
     )
 
 
+def add_study_arguments(parser):
+    parser.add_argument(
+        "--study-config",
+        type=str,
+        default=ROOT_DIR / "configs" / "study" / "study_config.yaml",
+        help="The config of the study.",
+    )
+
+
 class NumpyAndDataclassEncoder(json.JSONEncoder):
     """Special json encoder for numpy types"""
 
-- 
GitLab