From d479117eab098d6aece6231c38809336113ca47d Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.de>
Date: Thu, 22 Feb 2024 15:45:30 +0100
Subject: [PATCH] Can play multiple players from one gui in study. Can add bots
 to study.

---
 overcooked_simulator/__main__.py              |  24 ++--
 overcooked_simulator/example_study_server.py  | 116 ++++++++++++++----
 .../study/environment_config.yaml             |   2 +-
 .../study/environment_config_dark.yaml        |   2 +-
 .../game_content/study/study_config.yaml      |   3 +-
 overcooked_simulator/gui_2d_vis/drawing.py    |   2 +-
 .../gui_2d_vis/overcooked_gui.py              |  97 ++++++++-------
 .../overcooked_environment.py                 |   2 +-
 8 files changed, 161 insertions(+), 87 deletions(-)

diff --git a/overcooked_simulator/__main__.py b/overcooked_simulator/__main__.py
index 8dcd0fba..339b8234 100644
--- a/overcooked_simulator/__main__.py
+++ b/overcooked_simulator/__main__.py
@@ -8,7 +8,7 @@ from overcooked_simulator.utils import (
     add_list_of_manager_ids_arguments,
 )
 
-USE_STUDY_SERVER = False
+USE_STUDY_SERVER = True
 
 
 def start_game_server(cli_args):
@@ -86,18 +86,18 @@ def main(cli_args=None):
 
         if USE_STUDY_SERVER:
             print("Start PyGame GUI:")
-            # pygame_gui_2 = Process(target=start_pygame_gui, args=(cli_args,))
-            # pygame_gui_2.start()
+            pygame_gui_2 = Process(target=start_pygame_gui, args=(cli_args,))
+            pygame_gui_2.start()
             #
-            # # print("Start PyGame GUI:")
-            # # pygame_gui_3 = Process(target=start_pygame_gui, args=(cli_args,))
-            # # pygame_gui_3.start()
-            # while (
-            #     pygame_gui.is_alive()
-            #     and pygame_gui_2.is_alive()
-            #     # and pygame_gui_3.is_alive()
-            # ):
-            #     time.sleep(1)
+            # print("Start PyGame GUI:")
+            # pygame_gui_3 = Process(target=start_pygame_gui, args=(cli_args,))
+            # pygame_gui_3.start()
+            while (
+                pygame_gui.is_alive()
+                and pygame_gui_2.is_alive()
+                # and pygame_gui_3.is_alive()
+            ):
+                time.sleep(1)
 
         while pygame_gui.is_alive():
             time.sleep(1)
diff --git a/overcooked_simulator/example_study_server.py b/overcooked_simulator/example_study_server.py
index b1fac496..5d24526d 100644
--- a/overcooked_simulator/example_study_server.py
+++ b/overcooked_simulator/example_study_server.py
@@ -14,7 +14,11 @@ The environment starts when all players connected.
 import argparse
 import asyncio
 import logging
+import os
+import signal
+import subprocess
 from pathlib import Path
+from subprocess import Popen
 from typing import Tuple, TypedDict
 
 import requests
@@ -43,8 +47,8 @@ server_manager_id = None
 HARDCODED_MANAGER_ID = "1234"
 
 
-running_tutorials: dict[str, Tuple[int, dict[str, PlayerInfo], list[str]]] = {}
 
+USE_AAAMBOS_AGENT = False
 
 class LevelConfig(TypedDict):
     name: str
@@ -56,6 +60,7 @@ class LevelConfig(TypedDict):
 class StudyConfig(TypedDict):
     levels: list[LevelConfig]
     num_players: int
+    num_bots: int
 
 
 class StudyState:
@@ -77,6 +82,11 @@ class StudyState:
         self.next_level_env = None
         self.players_done = {}
 
+        url = "localhost"
+        port = "8000"
+        self.websocket_url = f"ws://{url}:{port}/ws/player/"
+        self.sub_processes = []
+
     @property
     def study_done(self):
         return self.current_level_idx >= len(self.levels)
@@ -91,6 +101,9 @@ class StudyState:
             len(self.participant_id_to_player_info) == self.study_config["num_players"]
         )
 
+    def can_add_participant(self, participant_id: int) -> bool:
+        return len(self.participant_id_to_player_info) < self.study_config["num_players"]
+
     def create_env(self, level):
         with open(ROOT_DIR / "game_content" / level["item_info_path"], "r") as file:
             item_info = file.read()
@@ -103,7 +116,7 @@ class StudyState:
 
         creation_json = CreateEnvironmentConfig(
             manager_id=server_manager_id,
-            number_players=self.study_config["num_players"],
+            number_players=self.study_config["num_players"] + self.study_config["num_bots"],
             environment_settings={"all_player_can_pause_game": False},
             item_info_config=item_info,
             environment_config=environment_config,
@@ -119,6 +132,11 @@ class StudyState:
             raise ValueError(f"Forbidden Request: {env_info.json()['detail']}")
         env_info = env_info.json()
 
+        player_info = env_info["player_info"]
+        for idx, (player_id, player_info) in enumerate(player_info.items()):
+            if idx >= self.study_config["num_players"]:
+                self.create_and_connect_bot(player_id, player_info)
+
         return env_info
 
     def start(self):
@@ -139,21 +157,20 @@ class StudyState:
         if not self.study_done:
             level = self.levels[self.current_level_idx]
             self.current_running_env = self.create_env(level)
-            for participant_id, player_id in self.player_ids.items():
-                player_id = self.player_ids[participant_id]
+            for participant_id, player_info in self.participant_id_to_player_info.items():
+                new_player_info = {player_name: self.current_running_env["player_info"][player_name] for player_name in player_info.keys()}
                 self.participant_id_to_player_info[
                     participant_id
-                ] = self.current_running_env["player_info"][player_id]
+                ] = new_player_info
 
             for key in self.players_done:
                 self.players_done[key] = False
 
-    def add_participant(self, participant_id: str):
-        player_name = str(self.num_connected_players)
-        player_info = self.current_running_env["player_info"][player_name]
+    def add_participant(self, participant_id: str, number_players: int):
+        player_names = [str(self.num_connected_players+i) for i in range(number_players)]
+        player_info = {player_name: self.current_running_env["player_info"][player_name] for player_name in player_names}
         self.participant_id_to_player_info[participant_id] = player_info
-        self.player_ids[participant_id] = player_info["player_id"]
-        self.num_connected_players += 1
+        self.num_connected_players += number_players
         return player_info
 
     def player_finished_level(self, participant_id):
@@ -166,6 +183,62 @@ class StudyState:
         player_info = self.participant_id_to_player_info[participant_id]
         return player_info, self.last_level
 
+    def create_and_connect_bot(self, player_id, player_info):
+        player_hash = player_info["player_hash"]
+        print(
+            f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"'
+        )
+        if USE_AAAMBOS_AGENT:
+            sub = Popen(
+                " ".join(
+                    [
+                        "exec",
+                        "aaambos",
+                        "run",
+                        "--arch_config",
+                        str(ROOT_DIR / "game_content" / "agents" / "arch_config.yml"),
+                        "--run_config",
+                        str(ROOT_DIR / "game_content" / "agents" / "run_config.yml"),
+                        f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"',
+                        f"--instance={player_hash}",
+                    ]
+                ),
+                shell=True,
+            )
+        else:
+            sub = Popen(
+                " ".join(
+                    [
+                        "python",
+                        str(ROOT_DIR / "game_content" / "agents" / "random_agent.py"),
+                        f'--uri {self.websocket_url + player_info["client_id"]}',
+                        f"--player_hash {player_hash}",
+                        f"--player_id {player_id}",
+                    ]
+                ),
+                shell=True,
+            )
+        self.sub_processes.append(sub)
+
+    def kill_bots(self):
+        for sub in self.sub_processes:
+            try:
+                if USE_AAAMBOS_AGENT:
+                    pgrp = os.getpgid(sub.pid)
+                    os.killpg(pgrp, signal.SIGINT)
+                    subprocess.run(
+                        "kill $(ps aux | grep 'aaambos' | awk '{print $2}')", shell=True
+                    )
+                else:
+                    sub.kill()
+
+            except ProcessLookupError:
+                pass
+
+        self.sub_processes = []
+        for websocket in self.websockets.values():
+            websocket.close()
+
 
 class StudyManager:
     def __init__(self):
@@ -175,19 +248,21 @@ class StudyManager:
         self.running_envs: dict[str, Tuple[int, dict[str, PlayerInfo], list[str]]] = {}
         self.current_free_envs = []
 
+        self.running_tutorials: dict[str, Tuple[int, dict[str, PlayerInfo], list[str]]] = {}
+
     def create_study(self):
         study = StudyState(ROOT_DIR / "game_content" / "study" / "study_config.yaml")
         study.start()
         self.running_studies.append(study)
 
-    def add_participant(self, participant_id):
+    def add_participant(self, participant_id: str, number_players: int):
         player_info = None
         if all([s.is_full for s in self.running_studies]):
             self.create_study()
 
         for study in self.running_studies:
             if not study.is_full:
-                player_info = study.add_participant(participant_id)
+                player_info = study.add_participant(participant_id, number_players)
                 self.participant_id_to_study_map[participant_id] = study
         return player_info
 
@@ -204,10 +279,10 @@ class StudyManager:
 study_manager = StudyManager()
 
 
-@app.post("/start_study/{participant_id}")
-async def start_study(participant_id: str):
-    player_info = study_manager.add_participant(participant_id)
-    print()
+@app.post("/start_study/{participant_id}/{number_players}")
+async def start_study(participant_id: str, number_players: int):
+    print("NUMBER PLAYERS TO BE ADDED", number_players)
+    player_info = study_manager.add_participant(participant_id, number_players)
     return player_info
 
 
@@ -224,11 +299,6 @@ async def get_game_connection(participant_id: str):
     return {"player_info": player_info, "last_level": last_level}
 
 
-# @app.post("/finished_game/{participant_id}")
-# async def finished_game(participant_id: str):
-#     print(f"{participant_id} finished game.")
-#     return None
-
 
 @app.post("/connect_to_tutorial/{participant_id}")
 async def want_to_play_tutorial(participant_id: str):
@@ -261,7 +331,7 @@ async def want_to_play_tutorial(participant_id: str):
     if env_info.status_code == 403:
         raise ValueError(f"Forbidden Request: {env_info.json()['detail']}")
     env_info = env_info.json()
-    running_tutorials[participant_id] = env_info
+    study_manager.running_tutorials[participant_id] = env_info
     return env_info["player_info"]["0"]
 
 
@@ -271,7 +341,7 @@ async def want_to_play_tutorial(participant_id: str):
         f"{game_server_url}/manage/stop_env/",
         json={
             "manager_id": HARDCODED_MANAGER_ID,
-            "env_id": running_tutorials[participant_id]["env_id"],
+            "env_id": study_manager.running_tutorials[participant_id]["env_id"],
             "reason": "Finished tutorial",
         },
     )
diff --git a/overcooked_simulator/game_content/study/environment_config.yaml b/overcooked_simulator/game_content/study/environment_config.yaml
index 357c14af..00716059 100644
--- a/overcooked_simulator/game_content/study/environment_config.yaml
+++ b/overcooked_simulator/game_content/study/environment_config.yaml
@@ -5,7 +5,7 @@ plates:
   # range of seconds until the dirty plate arrives.
 
 game:
-  time_limit_seconds: 3
+  time_limit_seconds: 300
 
 meals:
   all: true
diff --git a/overcooked_simulator/game_content/study/environment_config_dark.yaml b/overcooked_simulator/game_content/study/environment_config_dark.yaml
index 7890ad12..2520eff2 100644
--- a/overcooked_simulator/game_content/study/environment_config_dark.yaml
+++ b/overcooked_simulator/game_content/study/environment_config_dark.yaml
@@ -5,7 +5,7 @@ plates:
   # range of seconds until the dirty plate arrives.
 
 game:
-  time_limit_seconds: 3
+  time_limit_seconds: 300
 
 meals:
   all: true
diff --git a/overcooked_simulator/game_content/study/study_config.yaml b/overcooked_simulator/game_content/study/study_config.yaml
index 966dc3ec..0f6b6e76 100644
--- a/overcooked_simulator/game_content/study/study_config.yaml
+++ b/overcooked_simulator/game_content/study/study_config.yaml
@@ -20,4 +20,5 @@ levels:
     name: "Level 4-2: Dark"
 
 
-num_players: 1
+num_players: 6
+num_bots: 2
diff --git a/overcooked_simulator/gui_2d_vis/drawing.py b/overcooked_simulator/gui_2d_vis/drawing.py
index fce70412..f382a9b9 100644
--- a/overcooked_simulator/gui_2d_vis/drawing.py
+++ b/overcooked_simulator/gui_2d_vis/drawing.py
@@ -137,7 +137,7 @@ class Visualizer:
             pygame.draw.circle(
                 screen,
                 col,
-                np.array(state["players"][idx]["pos"]) * grid_size + (grid_size // 2),
+                np.array(state["players"][int(idx)]["pos"]) * grid_size + (grid_size // 2),
                 (grid_size / 2),
             )
 
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index 6d43a3bf..417b6176 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -59,7 +59,7 @@ class PlayerKeySet:
         interact_key: pygame.key,
         pickup_key: pygame.key,
         switch_key: pygame.key,
-        players: list[int],
+        players: list[str],
         joystick: int,
     ):
         """Creates a player key set which contains information about which keyboard keys control the player.
@@ -82,13 +82,13 @@ class PlayerKeySet:
         self.interact_key: pygame.key = interact_key
         self.pickup_key: pygame.key = pickup_key
         self.switch_key: pygame.key = switch_key
-        self.controlled_players: list[int] = players
-        self.current_player: int = players[0] if players else 0
+        self.controlled_players: list[str] = players
+        self.current_player: str = players[0] if players else "0"
         self.current_idx = 0
         self.other_keyset: list[PlayerKeySet] = []
         self.joystick = joystick
 
-    def set_controlled_players(self, controlled_players: list[int]) -> None:
+    def set_controlled_players(self, controlled_players: list[str]) -> None:
         self.controlled_players = controlled_players
         self.current_player = self.controlled_players[0]
         self.current_idx = 0
@@ -180,12 +180,11 @@ class PyGameGUI:
 
         self.beeped_once = False
 
-    def setup_player_keys(self, number_players, number_key_sets=1, disjunct=False):
+    def setup_player_keys(self, players: list[str], number_key_sets=1, disjunct=False):
         # First four keys are for movement. Order: Down, Up, Left, Right.
         # 5th key is for interacting with counters.
         # 6th key ist for picking up things or dropping them.
         if number_key_sets:
-            players = list(range(number_players))
             key_set1 = PlayerKeySet(
                 move_keys=[pygame.K_a, pygame.K_d, pygame.K_w, pygame.K_s],
                 interact_key=pygame.K_f,
@@ -295,7 +294,7 @@ class PyGameGUI:
         for key_set in self.key_sets:
             current_player_name = str(key_set.current_player)
             if event.key == key_set.pickup_key and event.type == pygame.KEYDOWN:
-                action = Action(self.player_id, ActionType.PUT, "pickup")
+                action = Action(current_player_name, ActionType.PUT, "pickup")
                 self.send_action(action)
 
             if event.key == key_set.interact_key:
@@ -306,10 +305,10 @@ class PyGameGUI:
                     self.send_action(action)
                 elif event.type == pygame.KEYUP:
                     action = Action(
-                        self.player_id, ActionType.INTERACT, InterActionData.STOP
+                        current_player_name, ActionType.INTERACT, InterActionData.STOP
                     )
                     self.send_action(action)
-            if event.key == key_set.switch_key and not self.CONNECT_WITH_STUDY_SERVER:
+            if event.key == key_set.switch_key:
                 if event.type == pygame.KEYDOWN:
                     key_set.next_player()
 
@@ -796,6 +795,7 @@ class PyGameGUI:
             self.quit_button,
             self.fullscreen_button,
             self.player_selection_container,
+            self.bot_number_container,
             self.press_a_image,
         ]
 
@@ -825,6 +825,8 @@ class PyGameGUI:
             self.conclusion_label,
             self.quit_button,
             self.next_game_button,
+            self.finish_study_button
+
         ]
 
         self.end_screen_elements = [
@@ -837,7 +839,6 @@ class PyGameGUI:
             self.fullscreen_button,
             self.quit_button,
             self.retry_button,
-            self.finisseth_study_button,
             self.finished_button,
         ]
 
@@ -861,7 +862,8 @@ class PyGameGUI:
                 self.show_screen_elements(self.start_screen_elements)
 
                 if self.CONNECT_WITH_STUDY_SERVER:
-                    self.player_selection_container.hide()
+                    self.bot_number_container.hide()
+
 
                 self.update_selection_elements()
             case MenuStates.ControllerTutorial:
@@ -871,7 +873,6 @@ class PyGameGUI:
                     max_height=self.window_height * 0.3,
                     max_width=self.window_width * 0.3,
                 )
-                # self.set_window_size()
                 self.game_center = (
                     self.window_width - self.game_width / 2 - 20,
                     self.window_height - self.game_height / 2 - 20,
@@ -1078,17 +1079,16 @@ class PyGameGUI:
         self.current_env_id = env_info["env_id"]
         self.player_info = env_info["player_info"]
         if tutorial:
-            self.player_id = str(list(self.player_info.keys())[0])
+            self.player_ids = [str(list(self.player_info.keys())[0])]
+
+    def get_game_connection(self, tutorial):
 
-    def get_game_connection(self):
         if self.menu_state == MenuStates.ControllerTutorial:
             self.player_info = requests.post(
                 f"http://localhost:8080/connect_to_tutorial/{self.participant_id}"
             ).json()
-
-            self.key_sets[0].current_player = int(self.player_info["player_id"])
-            self.player_id = self.player_info["player_id"]
             self.player_info = {self.player_info["player_id"]: self.player_info}
+
         else:
             answer = requests.post(
                 f"http://localhost:8080/get_game_connection/{self.participant_id}"
@@ -1096,11 +1096,32 @@ class PyGameGUI:
             self.player_info = answer["player_info"]
             self.last_level = answer["last_level"]
 
-            print("LAST LEVEL", self.last_level)
+        if tutorial:
+            self.key_sets = self.setup_player_keys(["0"], 1, False)
+            self.vis.create_player_colors(1)
+        else:
+            self.number_players = (
+                self.number_humans_to_be_added + self.number_bots_to_be_added
+            )
 
-            self.key_sets[0].current_player = int(self.player_info["player_id"])
-            self.player_id = self.player_info["player_id"]
-            self.player_info = {self.player_info["player_id"]: self.player_info}
+            if self.split_players:
+                assert (
+                    self.number_humans_to_be_added > 1
+                ), "Not enough players for key configuration."
+            num_key_set = 2 if self.multiple_keysets else 1
+            self.key_sets = self.setup_player_keys(
+                list(self.player_info.keys()),
+                min(self.number_humans_to_be_added, num_key_set),
+                self.split_players,
+            )
+
+        # self.key_sets[0].current_player = int(self.player_info["player_id"])
+        # self.player_id = self.player_info["player_id"]
+        # self.player_info = {self.player_info["player_id"]: self.player_info}
+
+        # for i, k in enumerate(self.key_sets):
+        #     k.current_player = list(self.player_info.keys())[i]
+        self.player_ids = list(self.player_info.keys())
 
     def create_and_connect_bot(self, player_id, player_info):
         player_hash = player_info["player_hash"]
@@ -1143,6 +1164,7 @@ class PyGameGUI:
         for p, (player_id, player_info) in enumerate(self.player_info.items()):
             if p < self.number_humans_to_be_added:
                 # add player websockets
+                print(player_info)
                 websocket = connect(self.websocket_url + player_info["client_id"])
                 websocket.send(
                     json.dumps(
@@ -1162,27 +1184,10 @@ class PyGameGUI:
                 self.state_player_id = player_id
 
     def setup_game(self, tutorial=False):
-        if tutorial:
-            self.key_sets = self.setup_player_keys(1, 1, False)
-            self.vis.create_player_colors(1)
-        else:
-            self.number_players = (
-                self.number_humans_to_be_added + self.number_bots_to_be_added
-            )
 
-            if self.split_players:
-                assert (
-                    self.number_humans_to_be_added > 1
-                ), "Not enough players for key configuration."
-            num_key_set = 2 if self.multiple_keysets else 1
-            self.key_sets = self.setup_player_keys(
-                self.number_humans_to_be_added,
-                min(self.number_humans_to_be_added, num_key_set),
-                self.split_players,
-            )
 
         if self.CONNECT_WITH_STUDY_SERVER:
-            self.get_game_connection()
+            self.get_game_connection(tutorial)
         else:
             self.create_env_on_game_server(tutorial)
 
@@ -1216,7 +1221,6 @@ class PyGameGUI:
         if not self.CONNECT_WITH_STUDY_SERVER:
             self.stop_game("finished_button_pressed")
         self.menu_state = MenuStates.PostGame
-        self.reset_window_size()
         log.debug("Pressed finished button")
         self.update_screen_elements()
 
@@ -1284,7 +1288,8 @@ class PyGameGUI:
                 float(action.action_data[0]),
                 float(action.action_data[1]),
             ]
-        self.websockets[self.player_id].send(
+
+        self.websockets[action.player].send(
             json.dumps(
                 {
                     "type": "action",
@@ -1295,7 +1300,7 @@ class PyGameGUI:
                 }
             )
         )
-        self.websockets[self.player_id].recv()
+        self.websockets[action.player].recv()
 
     def request_state(self):
         self.websockets[self.state_player_id].send(
@@ -1340,16 +1345,14 @@ class PyGameGUI:
 
     def start_study(self):
         self.player_info = requests.post(
-            f"http://localhost:8080/start_study/{self.participant_id}"
+            f"http://localhost:8080/start_study/{self.participant_id}/{self.number_humans_to_be_added}"
         ).json()
         self.last_level = False
 
     def send_level_done(self):
-        answer = requests.post(
+        _ = requests.post(
             f"http://localhost:8080/level_done/{self.participant_id}"
         ).json()
-        # self.last_level = answer["last_level"]
-        # print("\nAT LAST LEVEL:", self.last_level, "\n")
 
     def button_continue_postgame_pressed(self):
         if not self.CONNECT_WITH_STUDY_SERVER:
@@ -1519,7 +1522,7 @@ class PyGameGUI:
 
                     # Press key instead of mouse button press
                     if (
-                        event.type == pygame.KEYDOWN and event.key == pygame.K_SPACE or (pygame.JOYBUTTONDOWN and (self.joysticks and self.joysticks[0].get_button(0)))
+                        pygame.JOYBUTTONDOWN and any([self.joysticks and self.joysticks[i].get_button(0) for i in range(len(self.joysticks))])
                     ):
                         match self.menu_state:
                             case MenuStates.Start:
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index 6a9c69a2..5fb8c791 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -57,7 +57,7 @@ from overcooked_simulator.utils import create_init_env_time, get_closest
 log = logging.getLogger(__name__)
 
 
-PREVENT_SQUEEZING_INTO_OTHER_PLAYERS = False
+PREVENT_SQUEEZING_INTO_OTHER_PLAYERS = True
 
 
 class ActionType(Enum):
-- 
GitLab