From 37c127e1d5419307a216fb691c38a65d6dca960f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Wed, 14 Feb 2024 15:56:13 +0100
Subject: [PATCH] Implement study server

Several GUI instances are managed by a study server for matchmaking.
Needs code adjustments to work (change static var)
---
 overcooked_simulator/example_study_server.py  | 136 ++++++++++++++++++
 overcooked_simulator/game_server.py           |  22 ++-
 overcooked_simulator/gui_2d_vis/drawing.py    |   4 +-
 .../gui_2d_vis/overcooked_gui.py              | 125 ++++++++--------
 overcooked_simulator/utils.py                 |   8 +-
 5 files changed, 222 insertions(+), 73 deletions(-)
 create mode 100644 overcooked_simulator/example_study_server.py

diff --git a/overcooked_simulator/example_study_server.py b/overcooked_simulator/example_study_server.py
new file mode 100644
index 00000000..778135f4
--- /dev/null
+++ b/overcooked_simulator/example_study_server.py
@@ -0,0 +1,136 @@
+"""
+# Usage
+- Set `CONNECT_WITH_STUDY_SERVER` in overcooked_gui.py to True.
+- Run this script. Copy the manager id that is printed
+- Run the game_server.py script with the manager id copied from the terminal
+```
+python game_server.py --manager_ids COPIED_UUID
+```
+- Run 2 overcooked_gui.py scripts in different terminals. For more players change `NUMBER_PLAYER_PER_ENV` and start more guis.
+
+The environment starts when all players connected.
+"""
+
+import argparse
+import asyncio
+import logging
+from typing import Tuple
+
+import requests
+import uvicorn
+from fastapi import FastAPI
+
+from overcooked_simulator import ROOT_DIR
+from overcooked_simulator.game_server import CreateEnvironmentConfig
+from overcooked_simulator.server_results import PlayerInfo
+from overcooked_simulator.utils import (
+    url_and_port_arguments,
+    add_list_of_manager_ids_arguments,
+)
+
+NUMBER_PLAYER_PER_ENV = 2
+
+log = logging.getLogger(__name__)
+
+
+app = FastAPI()
+
+game_server_url = "localhost:8000"
+server_manager_id = None
+
+
+# @app.get("/")
+# async def root(response_class=HTMLResponse):
+#     return """
+#     <html>
+#         <head>
+#             <title>Overcooked Game</title>
+#         </head>
+#         <body>
+#             <h1>Start Game!</h1>
+#             <button type="button">Click Me!</button>
+#         </body>
+#     </html>
+#     """
+
+running_envs: dict[str, Tuple[int, dict[str, PlayerInfo], list[str]]] = {}
+current_free_envs = []
+
+
+@app.post("/connect_to_game/{request_id}")
+async def want_to_play(request_id: str):
+    global current_free_envs
+    # TODO based on study desing / internal state of request id current state (which level to play)
+    if current_free_envs:
+        current_free_env = current_free_envs.pop()
+
+        running_envs[current_free_env][2].append(request_id)
+        new_running_env = (
+            running_envs[current_free_env][0] + 1,
+            running_envs[current_free_env][1],
+            running_envs[current_free_env][2],
+        )
+        player_info = running_envs[current_free_env][1][str(new_running_env[0])]
+        running_envs[current_free_env] = new_running_env
+        if new_running_env[0] < NUMBER_PLAYER_PER_ENV - 1:
+            current_free_env.append(current_free_env)
+        return player_info
+    else:
+        environment_config_path = ROOT_DIR / "game_content" / "environment_config.yaml"
+        layout_path = ROOT_DIR / "game_content" / "layouts" / "basic.layout"
+        item_info_path = ROOT_DIR / "game_content" / "item_info.yaml"
+        with open(item_info_path, "r") as file:
+            item_info = file.read()
+        with open(layout_path, "r") as file:
+            layout = file.read()
+        with open(environment_config_path, "r") as file:
+            environment_config = file.read()
+        creation_json = CreateEnvironmentConfig(
+            manager_id=server_manager_id,
+            number_players=NUMBER_PLAYER_PER_ENV,
+            environment_settings={"all_player_can_pause_game": False},
+            item_info_config=item_info,
+            environment_config=environment_config,
+            layout_config=layout,
+            seed=1234567890,
+        ).model_dump(mode="json")
+        # todo async
+        env_info = requests.post(
+            game_server_url + "/manage/create_env/", json=creation_json
+        )
+
+        if env_info.status_code == 403:
+            raise ValueError(f"Forbidden Request: {env_info.json()['detail']}")
+        env_info = env_info.json()
+        print(env_info)
+        running_envs[env_info["env_id"]] = (0, env_info["player_info"], [request_id])
+        current_free_envs.append(env_info["env_id"])
+        return env_info["player_info"]["0"]
+
+
+def main(host, port, game_server_url_, manager_id):
+    global game_server_url, server_manager_id
+    game_server_url = "http://" + game_server_url_
+    server_manager_id = manager_id[0]
+    print(f"Use {server_manager_id=} for {game_server_url=}")
+    loop = asyncio.new_event_loop()
+    config = uvicorn.Config(app, host=host, port=port, loop=loop)
+    server = uvicorn.Server(config)
+    loop.run_until_complete(server.serve())
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        prog="Overcooked Simulator Study Server",
+        description="Study Server: Match Making, client pre and post managing.",
+        epilog="For further information, see https://scs.pages.ub.uni-bielefeld.de/cocosy/overcooked-simulator/overcooked_simulator.html",
+    )
+    url_and_port_arguments(parser=parser, server_name="Study Server", default_port=8080)
+    add_list_of_manager_ids_arguments(parser=parser)
+    args = parser.parse_args()
+    main(
+        args.url,
+        args.port,
+        game_server_url_="localhost:8000",
+        manager_id=args.manager_ids,
+    )
diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
index d54341a5..84879915 100644
--- a/overcooked_simulator/game_server.py
+++ b/overcooked_simulator/game_server.py
@@ -593,17 +593,6 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul
             "player_hash" in message_dict
         ), "'player_hash' key not in message dictionary'"
         match request_type:
-            case PlayerRequestType.READY:
-                accepted = environment_handler.set_player_ready(
-                    message_dict["player_hash"]
-                )
-                return {
-                    "request_type": request_type.value,
-                    "msg": f"ready{' ' if accepted else ' not '}accepted",
-                    "status": 200 if accepted else 400,
-                    "player_hash": message_dict["player_hash"],
-                }
-
             case PlayerRequestType.GET_STATE:
                 state = environment_handler.get_state(message_dict["player_hash"])
                 if isinstance(state, int):
@@ -616,7 +605,16 @@ def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResul
                         "player_hash": None,
                     }
                 return state
-
+            case PlayerRequestType.READY:
+                accepted = environment_handler.set_player_ready(
+                    message_dict["player_hash"]
+                )
+                return {
+                    "request_type": request_type.value,
+                    "msg": f"ready{' ' if accepted else ' not '}accepted",
+                    "status": 200 if accepted else 400,
+                    "player_hash": message_dict["player_hash"],
+                }
             case PlayerRequestType.ACTION:
                 assert (
                     "action" in message_dict
diff --git a/overcooked_simulator/gui_2d_vis/drawing.py b/overcooked_simulator/gui_2d_vis/drawing.py
index c9763b3c..fdf3d590 100644
--- a/overcooked_simulator/gui_2d_vis/drawing.py
+++ b/overcooked_simulator/gui_2d_vis/drawing.py
@@ -539,7 +539,9 @@ class Visualizer:
                 burnt=item["type"].startswith("Burnt"),
             )
         elif "content_list" in item and item["content_list"]:
-            triangle_offsets = create_polygon(len(item["content_list"]), length=10)
+            triangle_offsets = create_polygon(
+                len(item["content_list"]), np.array([0.10])
+            )
             scale = 1 if len(item["content_list"]) == 1 else 0.6
             for idx, o in enumerate(item["content_list"]):
                 self.draw_item(
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index 73b14016..f91af37e 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -4,6 +4,7 @@ import json
 import logging
 import random
 import sys
+import uuid
 from enum import Enum
 from subprocess import Popen
 
@@ -30,6 +31,8 @@ from overcooked_simulator.utils import (
     add_list_of_manager_ids_arguments,
 )
 
+CONNECT_WITH_STUDY_SERVER = False
+
 
 class MenuStates(Enum):
     Start = "Start"
@@ -267,7 +270,7 @@ class PyGameGUI:
                         current_player_name, ActionType.INTERACT, InterActionData.STOP
                     )
                     self.send_action(action)
-            if event.key == key_set.switch_key:
+            if event.key == key_set.switch_key and not CONNECT_WITH_STUDY_SERVER:
                 if event.type == pygame.KEYDOWN:
                     key_set.next_player()
 
@@ -683,41 +686,50 @@ class PyGameGUI:
         self.timer_label.set_text(f"Time remaining: {display_time}")
 
     def setup_environment(self):
-        environment_config_path = ROOT_DIR / "game_content" / "environment_config.yaml"
-        layout_path = self.layout_file_paths[self.layout_selection.selected_option]
-        item_info_path = ROOT_DIR / "game_content" / "item_info.yaml"
-        with open(item_info_path, "r") as file:
-            item_info = file.read()
-        with open(layout_path, "r") as file:
-            layout = file.read()
-        with open(environment_config_path, "r") as file:
-            environment_config = file.read()
-
-        seed = 161616161616
-        creation_json = CreateEnvironmentConfig(
-            manager_id=self.manager_id,
-            number_players=self.number_players,
-            environment_settings={"all_player_can_pause_game": False},
-            item_info_config=item_info,
-            environment_config=environment_config,
-            layout_config=layout,
-            seed=seed,
-        ).model_dump(mode="json")
-
-        # print(CreateEnvironmentConfig.model_validate_json(json_data=creation_json))
-        env_info = requests.post(
-            f"{self.request_url}/manage/create_env/",
-            json=creation_json,
-        )
-        if env_info.status_code == 403:
-            raise ValueError(f"Forbidden Request: {env_info.json()['detail']}")
-        env_info = env_info.json()
-        assert isinstance(env_info, dict), "Env info must be a dictionary"
-        self.current_env_id = env_info["env_id"]
-        self.player_info = env_info["player_info"]
+        if CONNECT_WITH_STUDY_SERVER:
+            self.player_info = requests.post(
+                f"http://localhost:8080/connect_to_game/{uuid.uuid4().hex}"
+            ).json()
+            self.key_sets[0].current_player = int(self.player_info["player_id"])
+            self.player_info = {self.player_info["player_id"]: self.player_info}
+        else:
+            environment_config_path = (
+                ROOT_DIR / "game_content" / "environment_config.yaml"
+            )
+            layout_path = self.layout_file_paths[self.layout_selection.selected_option]
+            item_info_path = ROOT_DIR / "game_content" / "item_info.yaml"
+            with open(item_info_path, "r") as file:
+                item_info = file.read()
+            with open(layout_path, "r") as file:
+                layout = file.read()
+            with open(environment_config_path, "r") as file:
+                environment_config = file.read()
+
+            seed = 161616161616
+            creation_json = CreateEnvironmentConfig(
+                manager_id=self.manager_id,
+                number_players=self.number_players,
+                environment_settings={"all_player_can_pause_game": False},
+                item_info_config=item_info,
+                environment_config=environment_config,
+                layout_config=layout,
+                seed=seed,
+            ).model_dump(mode="json")
+
+            # print(CreateEnvironmentConfig.model_validate_json(json_data=creation_json))
+            env_info = requests.post(
+                f"{self.request_url}/manage/create_env/",
+                json=creation_json,
+            )
+            if env_info.status_code == 403:
+                raise ValueError(f"Forbidden Request: {env_info.json()['detail']}")
+            env_info = env_info.json()
+            assert isinstance(env_info, dict), "Env info must be a dictionary"
+            self.current_env_id = env_info["env_id"]
+            self.player_info = env_info["player_info"]
 
         state = None
-        for p, (player_id, player_info) in enumerate(env_info["player_info"].items()):
+        for p, (player_id, player_info) in enumerate(self.player_info.items()):
             if p < self.number_humans_to_be_added:
                 websocket = connect(self.websocket_url + player_info["client_id"])
                 websocket.send(
@@ -781,7 +793,7 @@ class PyGameGUI:
 
         if not self.number_humans_to_be_added:
             player_id = "0"
-            player_info = env_info["player_info"][player_id]
+            player_info = self.player_info[player_id]
             websocket = connect(self.websocket_url + player_info["client_id"])
             websocket.send(
                 json.dumps({"type": "ready", "player_hash": player_info["player_hash"]})
@@ -845,28 +857,29 @@ class PyGameGUI:
 
     def reset_button_press(self):
         # self.reset_gui_values()
-
-        requests.post(
-            f"{self.request_url}/manage/stop_env",
-            json={
-                "manager_id": self.manager_id,
-                "env_id": self.current_env_id,
-                "reason": "reset button pressed",
-            },
-        )
+        if not CONNECT_WITH_STUDY_SERVER:
+            requests.post(
+                f"{self.request_url}/manage/stop_env",
+                json={
+                    "manager_id": self.manager_id,
+                    "env_id": self.current_env_id,
+                    "reason": "reset button pressed",
+                },
+            )
 
         # self.websocket.send(json.dumps("reset_game"))
         # answer = self.websocket.recv()        log.debug("Pressed reset button")
 
     def finished_button_press(self):
-        requests.post(
-            f"{self.request_url}/manage/stop_env/",
-            json={
-                "manager_id": self.manager_id,
-                "env_id": self.current_env_id,
-                "reason": "finish button pressed",
-            },
-        )
+        if not CONNECT_WITH_STUDY_SERVER:
+            requests.post(
+                f"{self.request_url}/manage/stop_env/",
+                json={
+                    "manager_id": self.manager_id,
+                    "env_id": self.current_env_id,
+                    "reason": "finish button pressed",
+                },
+            )
         self.menu_state = MenuStates.End
         self.reset_window_size()
         log.debug("Pressed finished button")
@@ -935,9 +948,9 @@ class PyGameGUI:
             json.dumps(
                 {
                     "type": "get_state",
-                    "player_hash": self.player_info[str(self.key_sets[0].current_idx)][
-                        "player_hash"
-                    ],
+                    "player_hash": self.player_info[
+                        str(self.key_sets[0].current_player)
+                    ]["player_hash"],
                 }
             )
         )
@@ -1114,4 +1127,4 @@ if __name__ == "__main__":
     disable_websocket_logging_arguments(parser)
     add_list_of_manager_ids_arguments(parser)
     args = parser.parse_args()
-    main(args.url, args.port, args.manager_ids, args.enable_websocket_logging)
+    main(args.url, args.port, args.manager_ids)
diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py
index b78d44af..e4bcbc28 100644
--- a/overcooked_simulator/utils.py
+++ b/overcooked_simulator/utils.py
@@ -113,21 +113,21 @@ def setup_logging(enable_websocket_logging=False):
         logging.getLogger("websockets.client").setLevel(logging.ERROR)
 
 
-def url_and_port_arguments(parser):
+def url_and_port_arguments(parser, server_name="game server", default_port=8000):
     parser.add_argument(
         "-url",
         "--url",
         "--host",
         type=str,
         default="localhost",
-        help="Overcooked game server host url.",
+        help=f"Overcooked {server_name} host url.",
     )
     parser.add_argument(
         "-p",
         "--port",
         type=int,
-        default=8000,
-        help="Port number for the game engine server",
+        default=default_port,
+        help=f"Port number for the {server_name}",
     )
 
 
-- 
GitLab