From 47edb3eaa56ae5b2e8989bdce57debbeb54eb1ef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Thu, 25 Jan 2024 20:57:32 +0100
Subject: [PATCH] Refactor player id type and remove global constants in UI

The player id type is updated from 'int | str' to 'str' across different files to maintain code uniformity and avoid type-related errors. Removed the global constants (USE_PLAYER_COOK_SPRITES, SHOW_INTERACTION_RANGE, SHOW_COUNTER_CENTERS) from gui files for better code structure. Moreover, function annotations are improved by providing more specific types which helps in understanding the codebase. The diff adds additional details in the form of questions or short descriptions in the comments as well.
---
 overcooked_simulator/game_server.py           | 11 +++----
 overcooked_simulator/gui_2d_vis/drawing.py    | 29 +++++++++++--------
 .../gui_2d_vis/overcooked_gui.py              |  7 -----
 .../overcooked_environment.py                 |  4 +--
 overcooked_simulator/player.py                |  7 +++--
 5 files changed, 29 insertions(+), 29 deletions(-)

diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py
index 92aaad5a..cedf5560 100644
--- a/overcooked_simulator/game_server.py
+++ b/overcooked_simulator/game_server.py
@@ -35,7 +35,7 @@ WEBSOCKET_PORT = 8000
 
 @dataclasses.dataclass
 class PlayerData:
-    player_id: int
+    player_id: str
     env_id: str
     websocket_id: str | None = None
     connected: bool = False
@@ -99,7 +99,7 @@ class EnvironmentHandler:
 
         return {"env_id": env_id, "player_info": player_info}
 
-    def create_player(self, env, env_id, player_id):
+    def create_player(self, env: Environment, env_id: str, player_id: str) -> dict:
         player_hash = uuid.uuid4().hex
         client_id = uuid.uuid4().hex
         player_data = PlayerData(
@@ -126,6 +126,7 @@ class EnvironmentHandler:
         ):
             n_players = len(self.envs[config.env_id].player_hashes)
             for player_id in range(n_players, n_players + config.number_players):
+                player_id = str(player_id)
                 new_player_info[player_id] = self.create_player(
                     env=self.envs[config.env_id].environment,
                     env_id=config.env_id,
@@ -217,7 +218,7 @@ class EnvironmentHandler:
             for player_hash in self.envs[env_id].player_hashes
         )
 
-    def get_not_connected_players(self, env_id: str) -> list[int]:
+    def get_not_connected_players(self, env_id: str) -> list[str]:
         if env_id in self.envs:
             return [
                 self.player_data[player_hash].player_id
@@ -225,7 +226,7 @@ class EnvironmentHandler:
                 if not self.player_data[player_hash].connected
             ]
 
-    def get_not_ready_players(self, env_id: str) -> list[int]:
+    def get_not_ready_players(self, env_id: str) -> list[str]:
         if env_id in self.envs:
             return [
                 self.player_data[player_hash].player_id
@@ -365,7 +366,7 @@ def read_root():
 class CreateEnvironmentConfig(BaseModel):
     manager_id: str
     number_players: int
-    same_websocket_player: list[list[int]] | None = None
+    same_websocket_player: list[list[str]] | None = None
     environment_settings: EnvironmentSettings
     item_info_config: str
     environment_config: str
diff --git a/overcooked_simulator/gui_2d_vis/drawing.py b/overcooked_simulator/gui_2d_vis/drawing.py
index 00f5990c..11ee0c93 100644
--- a/overcooked_simulator/gui_2d_vis/drawing.py
+++ b/overcooked_simulator/gui_2d_vis/drawing.py
@@ -10,6 +10,15 @@ from scipy.spatial import KDTree
 
 from overcooked_simulator import ROOT_DIR
 from overcooked_simulator.gui_2d_vis.game_colors import colors
+from overcooked_simulator.state_representation import (
+    PlayerState,
+    CookingEquipmentState,
+    ItemState,
+)
+
+USE_PLAYER_COOK_SPRITES = True
+SHOW_INTERACTION_RANGE = False
+SHOW_COUNTER_CENTERS = False
 
 
 def create_polygon(n, length):
@@ -60,9 +69,6 @@ class Visualizer:
         width,
         height,
         grid_size,
-        SHOW_COUNTER_CENTERS=False,
-        USE_PLAYER_COOK_SPRITES=False,
-        SHOW_INTERACTION_RANGE=False,
     ):
         self.draw_background(
             surface=screen,
@@ -74,15 +80,12 @@ class Visualizer:
             screen,
             state,
             grid_size,
-            SHOW_COUNTER_CENTERS,
         )
 
         self.draw_players(
             screen,
             state,
             grid_size,
-            USE_PLAYER_COOK_SPRITES=USE_PLAYER_COOK_SPRITES,
-            SHOW_INTERACTION_RANGE=SHOW_INTERACTION_RANGE,
         )
 
     def draw_background(self, surface, width, height, grid_size):
@@ -129,14 +132,13 @@ class Visualizer:
         screen: pygame.Surface,
         state_dict: dict,
         grid_size: float,
-        USE_PLAYER_COOK_SPRITES: bool = True,
-        SHOW_INTERACTION_RANGE: bool = False,
     ):
         """Visualizes the players as circles with a triangle for the facing direction.
         If the player holds something in their hands, it is displayed
         Args:            state: The game state returned by the environment.
         """
         for p_idx, player_dict in enumerate(state_dict["players"]):
+            player_dict: PlayerState
             pos = np.array(player_dict["pos"]) * grid_size
 
             facing = np.array(player_dict["facing_direction"])
@@ -225,6 +227,8 @@ class Visualizer:
         """Draws an item, based on its visual parts specified in the visualization config.
 
         Args:
+            screen: the game screen to draw on.
+            grid_size: size of a grid cell.
             pos: Where to draw the item parts.
             parts: The visual parts to draw.
             scale: Rescale the item by this factor.
@@ -275,9 +279,9 @@ class Visualizer:
 
     def draw_item(
         self,
-        pos: npt.NDArray[float],
-        grid_size,
-        item,
+        pos: npt.NDArray[float] | list[float],
+        grid_size: float,
+        item: ItemState | CookingEquipmentState,
         scale: float = 1.0,
         plate=False,
         screen=None,
@@ -287,6 +291,7 @@ class Visualizer:
         different parts to be drawn.
 
         Args:
+            grid_size: size of a grid cell.
             pos: The position of the item to draw.
             item: The item do be drawn in the game.
             scale: Rescale the item by this factor.
@@ -294,7 +299,7 @@ class Visualizer:
             plate: item is on a plate (soup are is different on a plate and pot)
         """
 
-        if not isinstance(item, list):
+        if not isinstance(item, list):  # can we remove this check?
             if item["type"] in self.config:
                 item_key = item["type"]
                 if "Soup" in item_key and plate:
diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
index e8ee5df8..5dd4326f 100644
--- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py
+++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py
@@ -18,10 +18,6 @@ from overcooked_simulator.gui_2d_vis.drawing import Visualizer
 from overcooked_simulator.gui_2d_vis.game_colors import colors
 from overcooked_simulator.overcooked_environment import Action
 
-USE_PLAYER_COOK_SPRITES = True
-SHOW_INTERACTION_RANGE = False
-SHOW_COUNTER_CENTERS = False
-
 
 class MenuStates(Enum):
     Start = "Start"
@@ -332,9 +328,6 @@ class PyGameGUI:
             self.game_width,
             self.game_height,
             self.grid_size,
-            SHOW_COUNTER_CENTERS,
-            USE_PLAYER_COOK_SPRITES,
-            SHOW_INTERACTION_RANGE,
         )
 
         # self.manager.draw_ui(self.main_window)
diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py
index d4b2bdee..84336327 100644
--- a/overcooked_simulator/overcooked_environment.py
+++ b/overcooked_simulator/overcooked_environment.py
@@ -550,7 +550,7 @@ class Environment:
         distance = np.linalg.norm([dx, dy])
         return distance < (player.radius)
 
-    def add_player(self, player_name: int | str, pos: npt.NDArray = None):
+    def add_player(self, player_name: str, pos: npt.NDArray = None):
         log.debug(f"Add player {player_name} to the game")
         player = Player(
             player_name, player_config=self.environment_config["player_config"], pos=pos
@@ -618,7 +618,7 @@ class Environment:
             "remaining_time": max(self.env_time_end - self.env_time, timedelta(0)),
         }
 
-    def get_json_state(self, player_id: str | int = None):
+    def get_json_state(self, player_id: str = None):
         state = {
             "players": [p.to_dict() for p in self.players.values()],
             "counters": [c.to_dict() for c in self.counters],
diff --git a/overcooked_simulator/player.py b/overcooked_simulator/player.py
index 4283e76f..e75be946 100644
--- a/overcooked_simulator/player.py
+++ b/overcooked_simulator/player.py
@@ -8,6 +8,7 @@ import numpy.typing as npt
 
 from overcooked_simulator.counters import Counter
 from overcooked_simulator.game_items import Item, Plate
+from overcooked_simulator.state_representation import PlayerState
 
 log = logging.getLogger(__name__)
 
@@ -21,11 +22,11 @@ class Player:
 
     def __init__(
         self,
-        name: int | str,
+        name: str,
         player_config: dict[str, Any],
         pos: Optional[npt.NDArray[float]] = None,
     ):
-        self.name: int | str = name
+        self.name: str = name
         self.player_config = player_config
         if pos is not None:
             self.pos: npt.NDArray[float] = np.array(pos, dtype=float)
@@ -145,7 +146,7 @@ class Player:
     def __repr__(self):
         return f"Player(name:{self.name},pos:{str(self.pos)},holds:{self.holding})"
 
-    def to_dict(self):
+    def to_dict(self) -> PlayerState:
         # TODO add color to player class for vis independent player color
         return {
             "id": self.name,
-- 
GitLab