From 816ca407ace2c4f713636d601d4dcb252a17057b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Thu, 1 Feb 2024 13:20:46 +0100
Subject: [PATCH] Add agent configurations and improve error handling

This commit introduces agent configuration files for the Overcooked game simulator and enhances error handling in the game server. Notably, it includes safeguards to prevent errors when player data or environment IDs do not exist. Also, it modifies the game GUI to handle player keys and websockets according to the number of humans and bots in the game.
---
 .../game_content/agents/arch_config.yml       |  23 +++
 .../game_content/agents/run_config.yml        |  15 ++
 overcooked_simulator/game_server.py           |  20 ++-
 .../gui_2d_vis/overcooked_gui.py              | 139 ++++++++++++++----
 4 files changed, 165 insertions(+), 32 deletions(-)
 create mode 100644 overcooked_simulator/game_content/agents/arch_config.yml
 create mode 100644 overcooked_simulator/game_content/agents/run_config.yml

diff --git a/overcooked_simulator/game_content/agents/arch_config.yml b/overcooked_simulator/game_content/agents/arch_config.yml
new file mode 100644
index 00000000..60d20c0d
--- /dev/null
+++ b/overcooked_simulator/game_content/agents/arch_config.yml
@@ -0,0 +1,23 @@
+concurrency: MultiProcessing
+
+communication:
+  communication_prefs:
+   - !name:ipaacar_com_service.communications.ipaacar_com.IPAACARInfo
+
+modules:
+  connection:
+    module_info: !name:cocosy_agent.modules.connection_module.ConnectionModule
+    mean_frequency_step: 2  # 2: every 0.5 seconds
+  working_memory:
+    module_info: !name:cocosy_agent.modules.working_memory_module.WorkingMemoryModule
+  subtask_selection:
+    module_info: !name:cocosy_agent.modules.random_subtask_module.RandomSubtaskModule
+  action_execution:
+    module_info: !name:cocosy_agent.modules.action_execution_module.ActionExecutionModule
+    mean_frequency_step: 10  # 2: every 0.5 seconds
+  gui:
+    module_info: !name:aaambos.std.guis.pysimplegui.pysimplegui_window.PySimpleGUIWindowModule
+    window_title: Counting GUI
+    topics_to_show: [["SubtaskDecision", "cocosy_agent.conventions.communication.SubtaskDecision", ["task_type"]], ["ActionControl", "cocosy_agent.conventions.communication.ActionControl", ["action_type"]]]
+  status_manager:
+    module_info: !name:aaambos.std.modules.module_status_manager.ModuleStatusManager
\ No newline at end of file
diff --git a/overcooked_simulator/game_content/agents/run_config.yml b/overcooked_simulator/game_content/agents/run_config.yml
new file mode 100644
index 00000000..f9c6cdf6
--- /dev/null
+++ b/overcooked_simulator/game_content/agents/run_config.yml
@@ -0,0 +1,15 @@
+general:
+  agent_name: cocosy_agent
+  instance: _dev
+  local_agent_directories: ~/aaambos_agents
+  plus:
+    agent_websocket: ws://localhost:8000:/ws/player/MY_CLIENT_ID
+    player_hash: abcdefghijklmnopqrstuvwxyz
+    agent_id: 1
+
+logging:
+    log_level_command_line: INFO
+
+supervisor:
+  run_time_manager_class: !name:aaambos.std.supervision.instruction_run_time_manager.instruction_run_time_manager.InstructionRunTimeManager
+
diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
index 95e8e18f..03679cc4 100644
--- a/overcooked_simulator/game_server.py
+++ b/overcooked_simulator/game_server.py
@@ -219,7 +219,9 @@ class EnvironmentHandler:
             self.envs[env_id].last_step_time = time.time_ns()
             self.envs[env_id].environment.reset_env_time()
 
-    def get_state(self, player_hash: str) -> str:  # -> StateRepresentation as json
+    def get_state(
+        self, player_hash: str
+    ) -> str | int:  # -> StateRepresentation as json
         """Get the current state representation of the environment for a player.
 
         Args:
@@ -236,6 +238,10 @@ class EnvironmentHandler:
             return self.envs[
                 self.player_data[player_hash].env_id
             ].environment.get_json_state()
+        if player_hash not in self.player_data:
+            return 1
+        if self.player_data[player_hash].env_id not in self.envs:
+            return 2
 
     def pause_env(self, manager_id: str, env_id: str, reason: str):
         """Pause the specified environment.
@@ -598,7 +604,17 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul
                 }
 
             case PlayerRequestType.GET_STATE:
-                return environment_handler.get_state(message_dict["player_hash"])
+                state = environment_handler.get_state(message_dict["player_hash"])
+                if isinstance(state, int):
+                    return {
+                        "request_type": message_dict["type"],
+                        "status": 400,
+                        "msg": "env id of player not in running envs"
+                        if state == 2
+                        else "player hash unknown",
+                        "player_hash": None,
+                    }
+                return state
 
             case PlayerRequestType.ACTION:
                 assert (
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index 3260d1ef..b5db493c 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -2,9 +2,12 @@ import argparse
 import dataclasses
 import json
 import logging
+import os
 import random
+import signal
 import sys
 from enum import Enum
+from subprocess import Popen
 
 import numpy as np
 import pygame
@@ -97,7 +100,7 @@ class PyGameGUI:
         self.running = True
 
         self.reset_gui_values()
-        self.key_sets: list[PlayerKeySet] = self.setup_player_keys(1)
+        self.key_sets: list[PlayerKeySet] = []
 
         self.websocket_url = f"ws://{url}:{port}/ws/player/"
         self.websockets = {}
@@ -135,6 +138,8 @@ class PyGameGUI:
 
         self.vis = Visualizer(self.visualization_config)
 
+        self.sub_processes = []
+
     def get_window_sizes(self, state: dict):
         kitchen_width = state["kitchen"]["width"]
         kitchen_height = state["kitchen"]["height"]
@@ -177,29 +182,32 @@ class PyGameGUI:
         )
 
     def setup_player_keys(self, n=1, disjunct=False):
-        key_set1 = PlayerKeySet(
-            move_keys=[pygame.K_a, pygame.K_d, pygame.K_w, pygame.K_s],
-            interact_key=pygame.K_f,
-            pickup_key=pygame.K_e,
-            switch_key=pygame.K_SPACE,
-            players=list(range(self.number_humans_to_be_added)),
-        )
-        key_set2 = PlayerKeySet(
-            move_keys=[pygame.K_LEFT, pygame.K_RIGHT, pygame.K_UP, pygame.K_DOWN],
-            interact_key=pygame.K_i,
-            pickup_key=pygame.K_o,
-            switch_key=pygame.K_p,
-            players=list(range(self.number_humans_to_be_added)),
-        )
-        key_sets = [key_set1, key_set2]
-
-        if disjunct:
-            split_idx = int(np.ceil(self.number_humans_to_be_added / 2))
-            key_set1.set_controlled_players(list(range(0, split_idx)))
-            key_set2.set_controlled_players(
-                list(range(split_idx, self.number_humans_to_be_added))
+        if n:
+            key_set1 = PlayerKeySet(
+                move_keys=[pygame.K_a, pygame.K_d, pygame.K_w, pygame.K_s],
+                interact_key=pygame.K_f,
+                pickup_key=pygame.K_e,
+                switch_key=pygame.K_SPACE,
+                players=list(range(self.number_humans_to_be_added)),
+            )
+            key_set2 = PlayerKeySet(
+                move_keys=[pygame.K_LEFT, pygame.K_RIGHT, pygame.K_UP, pygame.K_DOWN],
+                interact_key=pygame.K_i,
+                pickup_key=pygame.K_o,
+                switch_key=pygame.K_p,
+                players=list(range(self.number_humans_to_be_added)),
             )
-        return key_sets[:n]
+            key_sets = [key_set1, key_set2]
+
+            if disjunct:
+                split_idx = int(np.ceil(self.number_humans_to_be_added / 2))
+                key_set1.set_controlled_players(list(range(0, split_idx)))
+                key_set2.set_controlled_players(
+                    list(range(split_idx, self.number_humans_to_be_added))
+                )
+            return key_sets[:n]
+        else:
+            return []
 
     def handle_keys(self):
         """Handles keyboard inputs. Sends action for the respective players. When a key is held down, every frame
@@ -697,18 +705,72 @@ class PyGameGUI:
         assert isinstance(env_info, dict), "Env info must be a dictionary"
         self.current_env_id = env_info["env_id"]
         self.player_info = env_info["player_info"]
-        for player_id, player_info in env_info["player_info"].items():
+
+        state = None
+        for p, (player_id, player_info) in enumerate(env_info["player_info"].items()):
+            if p < self.number_humans_to_be_added:
+                websocket = connect(self.websocket_url + player_info["client_id"])
+                websocket.send(
+                    json.dumps(
+                        {"type": "ready", "player_hash": player_info["player_hash"]}
+                    )
+                )
+                assert (
+                    json.loads(websocket.recv())["status"] == 200
+                ), "not accepted player"
+                self.websockets[player_id] = websocket
+            else:
+                player_hash = player_info["player_hash"]
+                print(
+                    f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"'
+                )
+                sub = Popen(
+                    " ".join(
+                        [
+                            "exec",
+                            "aaambos",
+                            "run",
+                            "--arch_config",
+                            str(
+                                ROOT_DIR / "game_content" / "agents" / "arch_config.yml"
+                            ),
+                            "--run_config",
+                            str(
+                                ROOT_DIR / "game_content" / "agents" / "run_config.yml"
+                            ),
+                            f'--general_plus="agent_websocket:{self.websocket_url + player_info["client_id"]};player_hash:{player_hash};agent_id:{player_id}"',
+                            f"--instance={player_hash}",
+                        ]
+                    ),
+                    shell=True,
+                )
+                self.sub_processes.append(sub)
+
+            if p + 1 == self.number_humans_to_be_added:
+                self.state_player_id = player_id
+                websocket.send(
+                    json.dumps(
+                        {"type": "get_state", "player_hash": player_info["player_hash"]}
+                    )
+                )
+                state = json.loads(websocket.recv())
+
+        if not self.number_humans_to_be_added:
+            player_id = "0"
+            player_info = env_info["player_info"][player_id]
             websocket = connect(self.websocket_url + player_info["client_id"])
             websocket.send(
                 json.dumps({"type": "ready", "player_hash": player_info["player_hash"]})
             )
             assert json.loads(websocket.recv())["status"] == 200, "not accepted player"
             self.websockets[player_id] = websocket
-        self.state_player_id = player_id
-        websocket.send(
-            json.dumps({"type": "get_state", "player_hash": player_info["player_hash"]})
-        )
-        state = json.loads(websocket.recv())
+            self.state_player_id = player_id
+            websocket.send(
+                json.dumps(
+                    {"type": "get_state", "player_hash": player_info["player_hash"]}
+                )
+            )
+            state = json.loads(websocket.recv())
 
         (
             self.window_width,
@@ -732,7 +794,7 @@ class PyGameGUI:
             ), "Not enough players for key configuration."
         num_key_set = 2 if self.multiple_keysets else 1
         self.key_sets = self.setup_player_keys(
-            max(self.number_players, num_key_set), self.split_players
+            min(self.number_humans_to_be_added, num_key_set), self.split_players
         )
 
         self.setup_environment()
@@ -849,6 +911,18 @@ class PyGameGUI:
         return state
 
     def disconnect_websockets(self):
+        for sub in self.sub_processes:
+            try:
+                sub.kill()
+                pgrp = os.getpgid(sub.pid)
+                os.killpg(pgrp, signal.SIGINT)
+                # subprocess.run(
+                #     "kill $(ps aux | grep 'aaambos' | awk '{print $2}')", shell=True
+                # )
+            except ProcessLookupError:
+                pass
+
+        self.sub_processes = []
         for websocket in self.websockets.values():
             websocket.close()
 
@@ -882,6 +956,11 @@ class PyGameGUI:
                     if event.type == pygame_gui.UI_BUTTON_PRESSED:
                         match event.ui_element:
                             case self.start_button:
+                                if not (
+                                    self.number_humans_to_be_added
+                                    + self.number_bots_to_be_added
+                                ):
+                                    continue
                                 self.start_button_press()
                             case self.back_button:
                                 self.back_button_press()
-- 
GitLab