From 7ba326c6fb0f7ab53ed8cafe86688020f82e57ab Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Wed, 24 Jan 2024 19:34:48 +0100
Subject: [PATCH] Update server and client communication in game simulator

The server (game_server.py) and client communication (overcooked_gui.py) now uses the WebSocket communication protocol. In setup.py, the 'requests' module was added as a new requirement. The game simulator now waits for a player to be ready before starting, stops a game environment if no step is taken within a minute, and pauses or unpauses a game environment. Player actions are now handled based on a new 'Action' type. The server also now handles several client types that can send messages.
---
 overcooked_simulator/game_server.py           | 203 ++++++++-------
 .../gui_2d_vis/overcooked_gui.py              | 241 +++++++++++-------
 setup.py                                      |   1 +
 3 files changed, 261 insertions(+), 184 deletions(-)

diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
index 9fe683d4..e5d50146 100644
--- a/overcooked_simulator/game_server.py
+++ b/overcooked_simulator/game_server.py
@@ -11,15 +11,14 @@ from datetime import datetime, timedelta
 from enum import Enum
 from typing import Set
 
-import numpy as np
 import uvicorn
 from fastapi import FastAPI
 from fastapi import WebSocket
-from overcooked_simulator.game_server_OLD import setup_logging
 from pydantic import BaseModel
 from starlette.websockets import WebSocketDisconnect
 from typing_extensions import TypedDict
 
+from overcooked_simulator.main import setup_logging
 from overcooked_simulator.overcooked_environment import Action, Environment
 
 log = logging.getLogger(__name__)
@@ -67,13 +66,14 @@ class EnvironmentData:
     last_step_time: int | None = None
 
 
-class GameServer:
+class EnvironmentHandler:
     def __init__(self, env_step_frequency: int = 200):
         self.envs: dict[str, EnvironmentData] = {}
         self.player_data: dict[str, PlayerData] = {}
         self.manager_envs: dict[str, Set[str]] = defaultdict(set)
         self.env_step_frequency = env_step_frequency
         self.preferred_sleep_time_ns = 1e9 / self.env_step_frequency
+        self.client_ids_to_player_hashes = {}
 
     def create_env(self, environment_config: CreateEnvironmentConfig):
         env_id = uuid.uuid4().hex
@@ -92,7 +92,7 @@ class GameServer:
 
         self.manager_envs[environment_config.manager_id].update([env_id])
 
-        return {"env_id": env_id}
+        return {"env_id": env_id, "player_info": player_info}
 
     def create_player(self, env, env_id, player_id):
         player_hash = uuid.uuid4().hex
@@ -103,6 +103,7 @@ class GameServer:
             websocket_id=client_id,
         )
         self.player_data[player_hash] = player_data
+        self.client_ids_to_player_hashes[client_id] = player_hash
         env.add_player(player_id)
 
         return {
@@ -135,8 +136,15 @@ class GameServer:
             self.envs[env_id].last_step_time = time.time_ns()
             self.envs[env_id].environment.reset_env_time()
 
-    def get_state(self):
-        ...
+    def get_state(self, player_hash: str):
+        if (
+            player_hash in self.player_data
+            and self.player_data[player_hash].env_id in self.envs
+        ):
+            # TODO normal json state
+            return self.envs[
+                self.player_data[player_hash].env_id
+            ].environment.get_state_simple_json()
 
     def pause_env(self, manager_id: str, env_id: str, reason: str):
         if (
@@ -166,35 +174,25 @@ class GameServer:
             self.envs[env_id].status = EnvironmentStatus.STOPPED
             self.envs[env_id].stop_reason = reason
 
-    def set_player_ready(self, env_id: str, player_hash, player_id: int):
-        if (
-            player_hash in self.player_data
-            and self.player_data[player_hash].player_id == player_id
-            and self.player_data[player_hash].env_id == env_id
-        ):
+    def set_player_ready(self, player_hash):
+        if player_hash in self.player_data:
             self.player_data[player_hash].ready = True
             return True
         return False
 
-    def set_player_connected(self, env_id: str, player_hash, player_id: int) -> bool:
-        if (
-            player_hash in self.player_data
-            and self.player_data[player_hash].player_id == player_id
-            and self.player_data[player_hash].env_id == env_id
-        ):
-            self.player_data[player_hash].connected = True
+    def set_player_connected(self, client_id: str) -> bool:
+        if client_id in self.client_ids_to_player_hashes:
+            self.player_data[
+                self.client_ids_to_player_hashes[client_id]
+            ].connected = True
             return True
         return False
 
-    def set_player_disconnected(
-        self, env_id: str, player_hash: str, player_id: int
-    ) -> bool:
-        if (
-            player_hash in self.player_data
-            and self.player_data[player_hash].player_id == player_id
-            and self.player_data[player_hash].env_id == env_id
-        ):
-            self.player_data[player_hash].connected = False
+    def set_player_disconnected(self, client_id: str) -> bool:
+        if client_id in self.client_ids_to_player_hashes:
+            self.player_data[
+                self.client_ids_to_player_hashes[client_id]
+            ].connected = False
             return True
         return False
 
@@ -228,9 +226,11 @@ class GameServer:
             ]
 
     async def environment_steps(self):
+        # TODO environment dependent steps.
         overslept_in_ns = 0
         while True:
             pre_step_start = time.time_ns()
+            to_remove = []
             for env_id, env_data in self.envs.items():
                 if env_data.status == EnvironmentStatus.RUNNING:
                     step_start = time.time_ns()
@@ -241,6 +241,15 @@ class GameServer:
                         )
                     )
                     env_data.last_step_time = step_start
+                elif (
+                    env_data.status == EnvironmentStatus.STOPPED
+                    and env_data.last_step_time + (60 * 1e9) < pre_step_start
+                ):
+                    to_remove.append(env_id)
+
+            if to_remove:
+                for env_id in to_remove:
+                    del self.envs[env_id]
             step_duration = time.time_ns() - pre_step_start
 
             time_to_sleep_ns = self.preferred_sleep_time_ns - (
@@ -252,21 +261,36 @@ class GameServer:
             sleep_function_duration = time.time_ns() - sleep_start
             overslept_in_ns = sleep_function_duration - time_to_sleep_ns
 
+    def is_known_client_id(self, client_id: str) -> bool:
+        return client_id in self.client_ids_to_player_hashes
+
+    def player_action(self, player_hash: str, action: Action):
+        if (
+            player_hash in self.player_data
+            and action.player == self.player_data[player_hash].player_id
+            and self.player_data[player_hash].env_id in self.envs
+            and player_hash
+            in self.envs[self.player_data[player_hash].env_id].player_hashes
+        ):
+            self.envs[self.player_data[player_hash].env_id].environment.perform_action(
+                action
+            )
+
 
 class PlayerConnectionManager:
     def __init__(self):
         self.player_connections: dict[str, WebSocket] = {}
 
-    async def connect_player(self, websocket: WebSocket, player_id: str) -> bool:
-        if player_id not in self.player_connections:
+    async def connect_player(self, websocket: WebSocket, client_id: str) -> bool:
+        if client_id not in self.player_connections:
             await websocket.accept()
-            self.player_connections[player_id] = websocket
+            self.player_connections[client_id] = websocket
             return True
         return False
 
-    def disconnect(self, id_: str):
-        if id_ in self.player_connections:
-            del self.player_connections[id_]
+    def disconnect(self, client_id: str):
+        if client_id in self.player_connections:
+            del self.player_connections[client_id]
 
     @staticmethod
     async def send_personal_message(message: str, websocket: WebSocket):
@@ -278,49 +302,43 @@ class PlayerConnectionManager:
 
 
 manager = PlayerConnectionManager()
-oc_api: GameServer = GameServer()
-
-
-def parse_websocket_action(message: str) -> Action:
-    if message.replace('"', "") != "get_state":
-        message_dict = json.loads(message)
-        if message_dict["act_type"] == "movement":
-            if isinstance(message_dict["value"], list):
-                x, y = message_dict["value"]
-            elif isinstance(message_dict["value"], str):
-                x, y = (
-                    message_dict["value"]
-                    .replace(" ", "")
-                    .replace("[", "")
-                    .replace("]", "")
-                    .split(",")
-                )
-            else:
-                x, y = 0, 0
-            value = np.array([x, y], dtype=float)
-        else:
-            value = None
-        action = Action(
-            message_dict["player_name"],
-            message_dict["act_type"],
-            value,
-            duration=message_dict["duration"],
-        )
-        return action
-
+environment_handler: EnvironmentHandler = EnvironmentHandler()
 
-def manage_websocket_message(message: str):
-    if "get_state" in message:
-        return oc_api.get_state()
 
-    if "reset_game" in message:
-        oc_api.reset_game()
-        return "Reset game."
-
-    action = parse_websocket_action(message)
-    oc_api.simulator.enter_action(action)
-    answer = oc_api.get_state()
-    return answer
+@dataclasses.dataclass
+class PlayerAction:
+    player_hash: str
+    action: Action
+
+
+def manage_websocket_message(message: str, client_id: str):
+    message_dict = json.loads(message)
+
+    assert "type" in message_dict, "message needs a type"
+
+    match message_dict["type"]:
+        case "ready":
+            assert "player_hash" in message_dict, "needs player hash for ready"
+            environment_handler.set_player_ready(message_dict["player_hash"])
+            return {
+                "status": "ready accepted",
+                "player_hash": message_dict["player_hash"],
+            }
+        case "get_state":
+            assert "player_hash" in message_dict, "needs player hash for environment"
+            return environment_handler.get_state(message_dict["player_hash"])
+        case "action":
+            assert "action" in message_dict, "action type needs action data"
+            assert "player_hash" in message_dict, "action type needs player hash"
+            environment_handler.player_action(
+                message_dict["player_hash"], Action(**message_dict["action"])
+            )
+            return {
+                "status": "action accepted",
+                "player_hash": message_dict["player_hash"],
+            }
+    # TODO setup error enums or class
+    return {"status": "error", "info": "unknown message type"}
 
 
 @app.get("/")
@@ -345,44 +363,55 @@ class AdditionalPlayer(BaseModel):
     existing_websocket: str | None = None
 
 
-@app.post("/manage/create_env")
-async def register_manger(creation: CreateEnvironmentConfig):
-    result = oc_api.create_env(creation)
+@app.post("/manage/create_env/")
+async def create_env(creation: CreateEnvironmentConfig):
+    print(creation)
+    result = environment_handler.create_env(creation)
     return result
 
 
-@app.post("/manage/additional_player")
+@app.post("/manage/additional_player/")
 async def additional_player(creation: AdditionalPlayer):
-    result = oc_api.add_player(creation)
+    result = environment_handler.add_player(creation)
     return result
 
 
-@app.post("manage/stop_env")
+@app.post("/manage/stop_env/")
 async def stop_env(manager_id: str, env_id: str, reason: str):
-    result = oc_api.stop_env(manager_id, env_id, reason)
+    result = environment_handler.stop_env(manager_id, env_id, reason)
     return result
 
 
+# pause / unpause
 # control access / functions / data
 
 
 @app.websocket("/ws/player/{client_id}")
-async def websocket_player_endpoint(websocket: WebSocket, client_id: int):
-    await manager.connect(websocket)
+async def websocket_player_endpoint(websocket: WebSocket, client_id: str):
+    if client_id not in environment_handler.is_known_client_id(client_id):
+        return
+    await manager.connect_player(websocket, client_id)
     log.debug(f"Client #{client_id} connected")
+    environment_handler.set_player_connected(client_id)
     try:
         while True:
             message = await websocket.receive_text()
-            answer = manage_websocket_message(message)
+            answer = manage_websocket_message(message, client_id)
             await manager.send_personal_message(answer, websocket)
 
     except WebSocketDisconnect:
-        manager.disconnect(websocket)
+        manager.disconnect(client_id)
+        environment_handler.set_player_disconnected(client_id)
         log.debug(f"Client #{client_id} disconnected")
 
 
 def main():
-    uvicorn.run(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT)
+    loop = asyncio.new_event_loop()
+    asyncio.set_event_loop(loop)
+    loop.create_task(environment_handler.environment_steps())
+    config = uvicorn.Config(app, host=WEBSOCKET_URL, port=WEBSOCKET_PORT, loop=loop)
+    server = uvicorn.Server(config)
+    loop.run_until_complete(server.serve())
 
 
 if __name__ == "__main__":
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index 65e55d49..0c70cac4 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -3,6 +3,7 @@ import json
 import logging
 import math
 import sys
+import time
 from datetime import timedelta
 from enum import Enum
 
@@ -10,6 +11,7 @@ import numpy as np
 import numpy.typing as npt
 import pygame
 import pygame_gui
+import requests
 import yaml
 from scipy.spatial import KDTree
 from websockets.sync.client import connect
@@ -20,6 +22,7 @@ from overcooked_simulator.game_items import (
     CookingEquipment,
     Plate,
 )
+from overcooked_simulator.game_server import CreateEnvironmentConfig
 from overcooked_simulator.gui_2d_vis.game_colors import BLUE
 from overcooked_simulator.gui_2d_vis.game_colors import colors, Color
 from overcooked_simulator.order import Order
@@ -36,6 +39,9 @@ class MenuStates(Enum):
     End = "End"
 
 
+MANAGER_ID = "1233245425"
+
+
 def create_polygon(n, length):
     if n == 1:
         return np.array([0, 0])
@@ -107,7 +113,8 @@ class PyGameGUI:
         ]
 
         # self.websocket_url = "ws://localhost:8765"
-        self.websocket_url = "ws://localhost:8000/ws/29"
+        self.websocket_url = "ws://localhost:8000/ws/"
+        self.websockets = {}
 
         # TODO cache loaded images?
         with open(ROOT_DIR / "gui_2d_vis" / "visualization.yaml", "r") as file:
@@ -821,8 +828,50 @@ class PyGameGUI:
     def start_button_press(self):
         self.menu_state = MenuStates.Game
 
-        with connect(self.websocket_url) as websocket:
-            state = self.request_state()
+        environment_config_path = ROOT_DIR / "game_content" / "environment_config.yaml"
+        layout_path = ROOT_DIR / "game_content" / "layouts" / "basic.layout"
+        item_info_path = ROOT_DIR / "game_content" / "item_info.yaml"
+        with open(item_info_path, "r") as file:
+            item_info = file.read()
+        with open(layout_path, "r") as file:
+            layout = file.read()
+        with open(environment_config_path, "r") as file:
+            environment_config = file.read()
+        print(
+            CreateEnvironmentConfig(
+                manager_id=MANAGER_ID,
+                number_players=2,
+                environment_settings={"all_player_can_pause_game": False},
+                item_info_config=item_info,
+                environment_config=environment_config,
+                layout_config=layout,
+            ).model_dump_json()
+        )
+        env_info = requests.post(
+            "http://localhost:8000/manage/create_env/",
+            json=CreateEnvironmentConfig(
+                manager_id=MANAGER_ID,
+                number_players=2,
+                environment_settings={"all_player_can_pause_game": False},
+                item_info_config=item_info,
+                environment_config=environment_config,
+                layout_config=layout,
+            ).model_dump_json(),
+        )
+        print(env_info)
+        assert isinstance(env_info, dict), "Env info must be a dictionary"
+        self.current_env_id = env_info["env_id"]
+        self.player_info = env_info["player_info"]
+        for player_id, player_info in env_info["player_info"].items():
+            websocket = connect(self.websocket_url + player_info["client_id"])
+            websocket.send({"type": "ready", "player_hash": player_info["player_hash"]})
+            self.websockets[player_id] = websocket
+        time.sleep(0.1)
+
+        state = websocket.send(
+            {"type": "get_state", "player_hash": player_info["player_hash"]}
+        )
+        self.state_player_id = player_id
 
         (
             self.window_width,
@@ -849,7 +898,14 @@ class PyGameGUI:
         log.debug("Pressed quit button")
 
     def reset_button_press(self):
-        _ = self.websocket_communicate("reset_game")
+        requests.post(
+            "http://localhost:8000/manage/stop_env",
+            json={
+                "manager_id": MANAGER_ID,
+                "env_id": self.current_env_id,
+                "reason": "reset button pressed",
+            },
+        )
 
         # self.websocket.send(json.dumps("reset_game"))
         # answer = self.websocket.recv()
@@ -867,31 +923,28 @@ class PyGameGUI:
             action: The action to be sent. Contains the player, action type and move direction if action is a movement.
         """
         if isinstance(action.action, np.ndarray):
-            value = [float(action.action[0]), float(action.action[1])]
+            action.action = [float(action.action[0]), float(action.action[1])]
         else:
-            value = action.action
-        message_dict = {
-            "player_name": action.player,
-            "act_type": action.act_type,
-            "value": value,
-            "duration": action.duration,
-        }
-        _ = self.websocket_communicate(message_dict)
-
-    def websocket_communicate(self, message_dict: dict | str):
-        self.websocket.send(json.dumps(message_dict))
-        answer = self.websocket.recv()
-        try:
-            answer = json.loads(answer)
-        except json.decoder.JSONDecodeError:
-            answer = None
-        return answer
+            action.action = action.action
+        ret = self.websockets[action.player].send(
+            {
+                "type": "action",
+                "action": action,
+                "player_hash": self.player_info[action.player]["player_hash"],
+            }
+        )
+        print(ret)
 
     def request_state(self):
-        state_dict = self.websocket_communicate("get_state")
+        state_dict = self.websockets[self.state_player_id].send(
+            {
+                "type": "get_state",
+                "player_hash": self.player_info[self.state_player_id]["player_hash"],
+            }
+        )
         # self.websocket.send(json.dumps("get_state"))
         # state_dict = json.loads(self.websocket.recv())
-        return state_dict
+        return json.loads(state_dict)
 
     def start_pygame(self):
         """Starts pygame and the gui loop. Each frame the game state is visualized and keyboard inputs are read."""
@@ -907,97 +960,91 @@ class PyGameGUI:
         self.init_ui_elements()
         self.manage_button_visibility()
 
-        with connect(self.websocket_url) as websocket:
-            self.websocket = websocket
-            # Game loop
-            self.running = True
-            while self.running:
-                try:
-                    time_delta = clock.tick(self.FPS) / 1000.0
-
-                    for event in pygame.event.get():
-                        if event.type == pygame.QUIT:
-                            self.running = False
-
-                        # UI Buttons:
-                        if event.type == pygame_gui.UI_BUTTON_PRESSED:
-                            match event.ui_element:
-                                case self.start_button:
-                                    self.start_button_press()
-                                case self.back_button:
-                                    self.start_button_press()
-                                case self.finished_button:
-                                    self.finished_button_press()
-                                case self.quit_button:
-                                    self.quit_button_press()
-                                case self.reset_button:
-                                    self.reset_button_press()
-                                    self.start_button_press()
+        # Game loop
+        self.running = True
+        while self.running:
+            try:
+                time_delta = clock.tick(self.FPS) / 1000.0
+
+                for event in pygame.event.get():
+                    if event.type == pygame.QUIT:
+                        self.running = False
+
+                    # UI Buttons:
+                    if event.type == pygame_gui.UI_BUTTON_PRESSED:
+                        match event.ui_element:
+                            case self.start_button:
+                                self.start_button_press()
+                            case self.back_button:
+                                self.start_button_press()
+                            case self.finished_button:
+                                self.finished_button_press()
+                            case self.quit_button:
+                                self.quit_button_press()
+                            case self.reset_button:
+                                self.reset_button_press()
+                                self.start_button_press()
 
-                            self.manage_button_visibility()
+                        self.manage_button_visibility()
 
-                        if (
-                            event.type in [pygame.KEYDOWN, pygame.KEYUP]
-                            and self.menu_state == MenuStates.Game
-                        ):
-                            pass
-                            self.handle_key_event(event)
+                    if (
+                        event.type in [pygame.KEYDOWN, pygame.KEYUP]
+                        and self.menu_state == MenuStates.Game
+                    ):
+                        pass
+                        self.handle_key_event(event)
 
-                        self.manager.process_events(event)
+                    self.manager.process_events(event)
 
-                    # drawing:
+                # drawing:
 
-                    # state = self.simulator.get_state()
+                # state = self.simulator.get_state()
 
-                    self.main_window.fill(
-                        colors[
-                            self.visualization_config["GameWindow"]["background_color"]
-                        ]
-                    )
-                    self.manager.draw_ui(self.main_window)
+                self.main_window.fill(
+                    colors[self.visualization_config["GameWindow"]["background_color"]]
+                )
+                self.manager.draw_ui(self.main_window)
 
-                    match self.menu_state:
-                        case MenuStates.Start:
-                            pass
+                match self.menu_state:
+                    case MenuStates.Start:
+                        pass
 
-                        case MenuStates.Game:
-                            state = self.request_state()
+                    case MenuStates.Game:
+                        state = self.request_state()
 
-                            self.draw_background()
+                        self.draw_background()
 
-                            self.handle_keys()
+                        self.handle_keys()
 
-                            # state = self.simulator.get_state()
-                            self.draw(state)
+                        # state = self.simulator.get_state()
+                        self.draw(state)
 
-                            if state["ended"]:
-                                self.finished_button_press()
-                                self.manage_button_visibility()
-                            else:
-                                self.draw(state)
+                        if state["ended"]:
+                            self.finished_button_press()
+                            self.manage_button_visibility()
+                        else:
+                            self.draw(state)
 
-                                game_screen_rect = self.game_screen.get_rect()
-                                game_screen_rect.center = [
-                                    self.window_width // 2,
-                                    self.window_height // 2,
-                                ]
+                            game_screen_rect = self.game_screen.get_rect()
+                            game_screen_rect.center = [
+                                self.window_width // 2,
+                                self.window_height // 2,
+                            ]
 
-                                self.main_window.blit(
-                                    self.game_screen, game_screen_rect
-                                )
+                            self.main_window.blit(self.game_screen, game_screen_rect)
 
-                        case MenuStates.End:
-                            self.update_conclusion_label(state)
+                    case MenuStates.End:
+                        self.update_conclusion_label(state)
 
-                    self.manager.update(time_delta)
-                    pygame.display.flip()
+                self.manager.update(time_delta)
+                pygame.display.flip()
 
-                except KeyboardInterrupt:
-                    pygame.quit()
-                    sys.exit()
+            except KeyboardInterrupt:
+                pygame.quit()
+                sys.exit()
 
-            pygame.quit()
-            sys.exit()
+        pygame.quit()
+        sys.exit()
 
 
 def main():
diff --git a/setup.py b/setup.py
index 55c09597..42a86447 100644
--- a/setup.py
+++ b/setup.py
@@ -20,6 +20,7 @@ requirements = [
     "fastapi",
     "uvicorn",
     "websockets",
+    "requests",
 ]
 
 test_requirements = [
-- 
GitLab