From 85bbed68b37a674308e9109802a45a50009c2d33 Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.uni-bielefeld.de>
Date: Tue, 5 Mar 2024 09:20:57 +0100
Subject: [PATCH] TypedDict for websocket messages

---
 cooperative_cuisine/action.py            | 10 +++++-
 cooperative_cuisine/game_server.py       |  8 ++++-
 cooperative_cuisine/pygame_2d_vis/gui.py | 43 ++++++++++--------------
 cooperative_cuisine/study_server.py      | 37 +++++++++++---------
 4 files changed, 55 insertions(+), 43 deletions(-)

diff --git a/cooperative_cuisine/action.py b/cooperative_cuisine/action.py
index a080070b..b35819ef 100644
--- a/cooperative_cuisine/action.py
+++ b/cooperative_cuisine/action.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 
 import dataclasses
 from enum import Enum
-from typing import Literal
+from typing import Literal, TypedDict
 
 from numpy import typing as npt
 
@@ -49,3 +49,11 @@ class Action:
             self.action_type = ActionType(self.action_type)
         if isinstance(self.action_data, str) and self.action_data != "pickup":
             self.action_data = InterActionData(self.action_data)
+
+
+class ActionDict(TypedDict):
+    """Typed dict corresponding to the Action dataclass."""
+    player: str
+    action_type: ActionType
+    action_data: list[float] | InterActionData | Literal["pickup"]
+    duration: float
diff --git a/cooperative_cuisine/game_server.py b/cooperative_cuisine/game_server.py
index b8df259a..66922d2e 100644
--- a/cooperative_cuisine/game_server.py
+++ b/cooperative_cuisine/game_server.py
@@ -29,7 +29,7 @@ from pydantic import BaseModel
 from starlette.websockets import WebSocketDisconnect
 from typing_extensions import TypedDict
 
-from cooperative_cuisine.action import Action
+from cooperative_cuisine.action import Action, ActionDict
 from cooperative_cuisine.environment import Environment
 from cooperative_cuisine.server_results import (
     CreateEnvResult,
@@ -591,6 +591,12 @@ class PlayerRequestType(Enum):
     """Indicates a request to pass an action of a player to the environment."""
 
 
+class WebsocketMessage(TypedDict):
+    type: str
+    action: None | ActionDict
+    player_hash: str
+
+
 def manage_websocket_message(message: str, client_id: str) -> PlayerRequestResult | str:
     """Manage WebSocket Message by validating the message and passing it to the environment.
 
diff --git a/cooperative_cuisine/pygame_2d_vis/gui.py b/cooperative_cuisine/pygame_2d_vis/gui.py
index 804315bb..dee2e037 100644
--- a/cooperative_cuisine/pygame_2d_vis/gui.py
+++ b/cooperative_cuisine/pygame_2d_vis/gui.py
@@ -21,7 +21,11 @@ from websockets.sync.client import connect
 
 from cooperative_cuisine import ROOT_DIR
 from cooperative_cuisine.action import ActionType, InterActionData, Action
-from cooperative_cuisine.game_server import CreateEnvironmentConfig
+from cooperative_cuisine.game_server import (
+    CreateEnvironmentConfig,
+    WebsocketMessage,
+    PlayerRequestType,
+)
 from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer
 from cooperative_cuisine.pygame_2d_vis.game_colors import colors
 from cooperative_cuisine.state_representation import StateRepresentation
@@ -1366,12 +1370,11 @@ class PyGameGUI:
             if p < self.number_humans_to_be_added:
                 # add player websockets
                 websocket = connect(self.websocket_url + player_info["client_id"])
-
-                websocket.send(
-                    json.dumps(
-                        {"type": "ready", "player_hash": player_info["player_hash"]}
-                    )
+                ws_message = WebsocketMessage(
+                    type=PlayerRequestType.READY.value,
+                    player_hash=player_info["player_hash"],
                 )
+                websocket.send(json.dumps(ws_message))
                 assert (
                     json.loads(websocket.recv())["status"] == 200
                 ), "not accepted player"
@@ -1486,30 +1489,20 @@ class PyGameGUI:
                 float(action.action_data[1]),
             ]
 
-        self.websockets[action.player].send(
-            json.dumps(
-                {
-                    "type": "action",
-                    "action": dataclasses.asdict(
-                        action, dict_factory=custom_asdict_factory
-                    ),
-                    "player_hash": self.player_info[action.player]["player_hash"],
-                }
-            )
+        ws_message = WebsocketMessage(
+            type=PlayerRequestType.ACTION.value,
+            action=dataclasses.asdict(action, dict_factory=custom_asdict_factory),
+            player_hash=self.player_info[action.player]["player_hash"],
         )
+        self.websockets[action.player].send(json.dumps(ws_message))
         self.websockets[action.player].recv()
 
     def request_state(self):
-        self.websockets[self.state_player_id].send(
-            json.dumps(
-                {
-                    "type": "get_state",
-                    "player_hash": self.player_info[self.state_player_id][
-                        "player_hash"
-                    ],
-                }
-            )
+        ws_message = WebsocketMessage(
+            type=PlayerRequestType.GET_STATE.value,
+            player_hash=self.player_info[self.state_player_id]["player_hash"],
         )
+        self.websockets[self.state_player_id].send(json.dumps(ws_message))
         state = json.loads(self.websockets[self.state_player_id].recv())
         return state
 
diff --git a/cooperative_cuisine/study_server.py b/cooperative_cuisine/study_server.py
index 793776a9..06b7b2cc 100644
--- a/cooperative_cuisine/study_server.py
+++ b/cooperative_cuisine/study_server.py
@@ -31,7 +31,7 @@ from typing_extensions import TypedDict
 
 from cooperative_cuisine import ROOT_DIR
 from cooperative_cuisine.environment import EnvironmentConfig
-from cooperative_cuisine.game_server import CreateEnvironmentConfig
+from cooperative_cuisine.game_server import CreateEnvironmentConfig, EnvironmentData
 from cooperative_cuisine.server_results import PlayerInfo
 from cooperative_cuisine.utils import (
     url_and_port_arguments,
@@ -86,16 +86,20 @@ class StudyState:
         self.study_config: StudyConfig = yaml.load(
             str(env_config_f), Loader=yaml.SafeLoader
         )
+        """Configuration for the study which layouts, env_configs and item infos are used for the study levels."""
         self.levels: list[LevelConfig] = self.study_config["levels"]
+        """List of level configs for each of the levels which the study runs through."""
         self.current_level_idx: int = 0
-
+        """Counter of which level is currently run in the config."""
         self.participant_id_to_player_info = {}
-        self.player_ids = {}
+        """A dictionary which maps participants to player infos."""
         self.num_connected_players: int = 0
+        """Number of currently connected players."""
 
-        self.current_running_env = None
-        self.next_level_env = None
+        self.current_running_env: EnvironmentData | None = None
+        """Information about the current running environment."""
         self.players_done = {}
+        """A dictionary which saves which player has sent ready."""
 
         self.use_aaambos_agent = False
 
@@ -119,7 +123,7 @@ class StudyState:
             len(self.participant_id_to_player_info) == self.study_config["num_players"]
         )
 
-    def can_add_participant(self, num_participants: int) -> bool:
+    def can_add_participants(self, num_participants: int) -> bool:
         filled = (
             self.num_connected_players + num_participants
             <= self.study_config["num_players"]
@@ -324,12 +328,12 @@ class StudyManager:
 
     def add_participant(self, participant_id: str, number_players: int):
         if not self.running_studies or all(
-            [not s.can_add_participant(number_players) for s in self.running_studies]
+            [not s.can_add_participants(number_players) for s in self.running_studies]
         ):
             self.create_study()
 
         for study in self.running_studies:
-            if study.can_add_participant(number_players):
+            if study.can_add_participants(number_players):
                 player_info = study.add_participant(participant_id, number_players)
                 self.participant_id_to_study_map[participant_id] = study
                 return True, player_info
@@ -345,13 +349,15 @@ class StudyManager:
 
     def get_participant_game_connection(
         self, participant_id: str
-    ) -> Tuple[PlayerInfo | None, LevelInfo | None]:
+    ) -> Tuple[bool, None | Tuple[PlayerInfo, LevelInfo]]:
+        can_connect = False
         if participant_id in self.participant_id_to_study_map.keys():
             assigned_study = self.participant_id_to_study_map[participant_id]
             player_info, level_info = assigned_study.get_connection(participant_id)
-            return player_info, level_info
+            can_connect = True
+            return can_connect, (player_info, level_info)
         else:
-            return None, None
+            return can_connect, None
 
     def set_game_server_url(self, game_host: str, game_port: str):
         self.game_host = game_host
@@ -393,15 +399,14 @@ async def level_done(participant_id: str) -> JSONResponse:
 
 @app.post("/get_game_connection/{participant_id}")
 async def get_game_connection(participant_id: str) -> JSONResponse:
-    player_info, level_info = study_manager.get_participant_game_connection(
-        participant_id
-    )
-    if player_info and level_info:
+    can_connect, data = study_manager.get_participant_game_connection(participant_id)
+    if can_connect:
+        player_info, level_info = data
         return JSONResponse(
             content={"player_info": player_info, "level_info": level_info}
         )
     else:
-        raise HTTPException(status_code=409, detail="Not valid game connection.")
+        raise HTTPException(status_code=409, detail="No valid game connection.")
 
 
 @app.post("/connect_to_tutorial/{participant_id}")
-- 
GitLab