From 481c52d936fa44aaf4c6bba4145399a8383b08b6 Mon Sep 17 00:00:00 2001
From: fheinrich <fheinrich@techfak.uni-bielefeld.de>
Date: Mon, 18 Dec 2023 13:36:14 +0100
Subject: [PATCH] Pygame images are stored in a dict to avoid loading multiple
 times

---
 overcooked_simulator/pygame_gui/pygame_gui.py | 47 ++++++++++---------
 1 file changed, 25 insertions(+), 22 deletions(-)

diff --git a/overcooked_simulator/pygame_gui/pygame_gui.py b/overcooked_simulator/pygame_gui/pygame_gui.py
index 06310948..3f41cb3d 100644
--- a/overcooked_simulator/pygame_gui/pygame_gui.py
+++ b/overcooked_simulator/pygame_gui/pygame_gui.py
@@ -19,7 +19,7 @@ from overcooked_simulator.pygame_gui.game_colors import BLUE
 from overcooked_simulator.pygame_gui.game_colors import colors, Color
 from overcooked_simulator.simulation_runner import Simulator
 
-PLAYER_DEBUG_VIZ = True
+USE_PLAYER_COOK_SPRITES = False
 SHOW_INTERACTION_RANGE = False
 
 
@@ -104,6 +104,8 @@ class PyGameGUI:
 
         self.player_colors = self.create_player_colors()
 
+        self.image_cache_dict = {}
+
     def create_player_colors(self) -> list[Color]:
         number_player = len(self.simulator.env.players)
         hue_values = np.linspace(0, 1, number_player + 1)
@@ -183,6 +185,24 @@ class PyGameGUI:
                     1,
                 )
 
+    def draw_image(self, img_path, size, pos, rot_angle=0):
+        cache_entry = f"{img_path}_{size}_{rot_angle}"
+        if cache_entry in self.image_cache_dict.keys():
+            image = self.image_cache_dict[cache_entry]
+        else:
+            image = pygame.image.load(
+                ROOT_DIR / "pygame_gui" / img_path
+            ).convert_alpha()
+            image = pygame.transform.scale(image, (size, size))
+            if rot_angle != 0:
+                image = pygame.transform.rotate(image, rot_angle)
+
+            self.image_cache_dict[cache_entry] = image
+
+        rect = image.get_rect()
+        rect.center = pos
+        self.screen.blit(image, rect)
+
     def draw_players(self, state):
         """Visualizes the players as circles with a triangle for the facing direction.
         If the player holds something in their hands, it is displayed
@@ -191,7 +211,7 @@ class PyGameGUI:
             state: The game state returned by the environment.
         """
         for p_idx, player in enumerate(state["players"].values()):
-            if PLAYER_DEBUG_VIZ:
+            if USE_PLAYER_COOK_SPRITES:
                 pos = player.pos
                 size = player.radius
                 color1 = self.player_colors[p_idx]
@@ -214,18 +234,10 @@ class PyGameGUI:
                 )
             else:
                 img_path = self.visualization_config["Cook"]["parts"][0]["path"]
-                image = pygame.image.load(
-                    ROOT_DIR / "pygame_gui" / img_path
-                ).convert_alpha()
                 rel_x, rel_y = player.facing_direction
                 angle = -np.rad2deg(math.atan2(rel_y, rel_x)) + 90
-
                 size = self.visualization_config["Cook"]["parts"][0]["size"]
-                image = pygame.transform.scale(image, (size, size))
-                image = pygame.transform.rotate(image, angle)
-                rect = image.get_rect()
-                rect.center = player.pos
-                self.screen.blit(image, rect)
+                self.draw_image(img_path, size, player.pos, angle)
 
             if SHOW_INTERACTION_RANGE:
                 pygame.draw.circle(
@@ -246,16 +258,7 @@ class PyGameGUI:
         for part in parts:
             part_type = part["type"]
             if part_type == "image":
-                image = pygame.image.load(
-                    ROOT_DIR / "pygame_gui" / parts[0]["path"]
-                ).convert_alpha()
-
-                size = parts[0]["size"]
-                image = pygame.transform.scale(image, (size * scale, size * scale))
-
-                rect = image.get_rect()
-                rect.center = pos
-                self.screen.blit(image, rect)
+                self.draw_image(parts[0]["path"], parts[0]["size"], pos)
             elif part_type == "rect":
                 height = part["height"]
                 width = part["width"]
@@ -285,7 +288,7 @@ class PyGameGUI:
                 else:
                     pygame.draw.circle(self.screen, color, pos, radius)
 
-    def draw_item(self, pos: npt.NDArray[float], item: Item, scale=1.0):
+    def draw_item(self, pos: npt.NDArray[float], item: Item, scale: float = 1.0):
         """Visualization of an item at the specified position. On a counter or in the hands of the player.
         The visual composition of the item is read in from visualization.yaml file, where it is specified as
         different parts to be drawn.
-- 
GitLab