From 3bf758346cd74a9c17cacd14a9d5f2f7bbf66337 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Florian=20Schr=C3=B6der?=
 <fschroeder@techfak.uni-bielefeld.de>
Date: Tue, 20 Aug 2024 15:58:10 +0200
Subject: [PATCH] Refactor gym_env and drawing for environment-specific caching

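Update `gym_env.py` and `drawing.py` so that the visualizer's surface
cache is keyed per environment instead of by the single shared key
"counters+background". Each `Environment` created by `EnvGymWrapper` now
gets a unique `env_name` (a UUID4 hex string), which is passed to the
drawing code as `env_id_ref` and used as the cache key for the cached
counters/background surface.

`draw_gamescreen`, `get_state_image` and `get_state_image_by_size`
(renamed from `get_state_image_by_size_direct`) accept the new optional
`env_id_ref` argument. `EnvGymWrapper.reset` deletes the cache entry of
the environment it replaces, and `render`/`get_env_img` now return the
image from `get_state_image` without the previous transposes and resize,
ensuring accurate and efficient cache invalidation and retrieval.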
---
 cooperative_cuisine/pygame_2d_vis/drawing.py  | 46 +++++++++++--------
 .../reinforcement_learning/gym_env.py         | 15 +++---
 2 files changed, 34 insertions(+), 27 deletions(-)

diff --git a/cooperative_cuisine/pygame_2d_vis/drawing.py b/cooperative_cuisine/pygame_2d_vis/drawing.py
index c98ce4b2..eb4869fd 100644
--- a/cooperative_cuisine/pygame_2d_vis/drawing.py
+++ b/cooperative_cuisine/pygame_2d_vis/drawing.py
@@ -125,7 +125,8 @@ class Visualizer:
         self.cache_flags = reduce(lambda s, x: s | x, configs_cache_flags, CacheFlags.NONE)
 
         self.grid_size_lock = Lock()
-        self.reduced_background = config["GameWindow"]["reduced_background"] if "GameWindow" in config and "reduced_background" in config["GameWindow"] else True
+        self.reduced_background = config["GameWindow"][
+            "reduced_background"] if "GameWindow" in config and "reduced_background" in config["GameWindow"] else True
 
     def invalidate_surface_cache(self):
         self.surface_cache_dict = {}
@@ -162,6 +163,7 @@ class Visualizer:
         self,
         state: dict,
         controlled_player_idxs: list[int],
+        env_id_ref=None,
     ):
         """Draws the game state on the given surface.
 
@@ -174,8 +176,9 @@ class Visualizer:
         width = int(np.ceil(state["kitchen"]["width"] * self.grid_size))
         height = int(np.ceil(state["kitchen"]["height"] * self.grid_size))
         screen = pygame.Surface((width, height), pygame.SRCALPHA)
-
-        if "counters+background" not in self.surface_cache_dict:
+        if env_id_ref in self.surface_cache_dict:
+            screen.blit(self.surface_cache_dict[env_id_ref], (0, 0))
+        else:
             if CacheFlags.BACKGROUND in self.cache_flags:
                 self.draw_background(
                     surface=screen,
@@ -188,9 +191,7 @@ class Visualizer:
                     screen,
                     state["counters"],
                 )
-            self.surface_cache_dict["counters+background"] = screen.copy()
-        else:
-            screen.blit(self.surface_cache_dict["counters+background"], (0, 0))
+            self.surface_cache_dict[env_id_ref] = screen.copy()
 
         if CacheFlags.BACKGROUND not in self.cache_flags:
             self.draw_background(
@@ -216,7 +217,7 @@ class Visualizer:
                 col,
                 (np.array(state["players"][int(idx)]["pos"]) + 0.5) * self.grid_size,
                 (self.grid_size / 2),
-                )
+            )
 
         self.draw_players(
             screen,
@@ -277,7 +278,7 @@ class Visualizer:
                     + (direction.rotate(-90) * rect_scale),
                     right_beam + (direction.rotate(-90) * rect_scale),
                     right_beam - offset_front,
-                    ]
+                ]
                 light_cone_points = [pos - offset_front, left_beam, right_beam]
                 pygame.draw.polygon(
                     cone_mask,
@@ -311,7 +312,7 @@ class Visualizer:
                     pos
                     - (direction * rect_scale)
                     + (direction.rotate(-90) * rect_scale),
-                    ]
+                ]
 
                 pygame.draw.polygon(cone_mask, mask_color, corners)
 
@@ -360,7 +361,6 @@ class Visualizer:
                         1,
                     )
 
-
     def draw_image(
         self,
         screen: pygame.Surface,
@@ -418,7 +418,7 @@ class Visualizer:
             color,
             self.model_to_world_coords(pos - facing * 0.25),
             self.grid_size * 0.2,
-            )
+        )
         self.draw_thing(screen, pos, self.config["Cook"]["parts"], scale=1.0, orientation=facing.tolist())
 
     def draw_players(
@@ -498,7 +498,7 @@ class Visualizer:
                     pos + (facing * self.grid_size * 0.4),
                     1.6 * self.grid_size,
                     width=1,
-                    )
+                )
                 pygame.draw.circle(
                     screen, colors["red1"], pos + (facing * self.grid_size * 0.4), 4
                 )
@@ -582,7 +582,7 @@ class Visualizer:
                         draw_pos[1] - (width / 2),
                         height,
                         width,
-                        )
+                    )
                     pygame.draw.rect(screen, color, rect)
 
                 case "circle":
@@ -723,7 +723,7 @@ class Visualizer:
             bar_pos[1] + size - bar_height,
             progress_width,
             bar_height,
-            )
+        )
         pygame.draw.rect(screen, colors["red" if attention else "green1"], progress_bar)
 
     def draw_counter(
@@ -904,7 +904,7 @@ class Visualizer:
             order_upper_left = [
                 order_rects_start + idx * grid_size * 1.2,
                 order_rects_start,
-                ]
+            ]
             pygame.draw.rect(
                 screen,
                 colors["red"],
@@ -958,24 +958,30 @@ class Visualizer:
     def get_state_image(self, state: dict,
                         controlled_players: list[int] = None,
                         grid_size: int | None = None,
+                        env_id_ref=None,
                         ) -> npt.NDArray[np.uint8]:
         if grid_size is None:
             return pygame.surfarray.pixels3d(
-                self.draw_gamescreen(state, [0] if controlled_players is None else controlled_players)).transpose((1, 0, 2))
+                self.draw_gamescreen(state,
+                                     [0] if controlled_players is None else controlled_players,
+                                     env_id_ref=env_id_ref
+                                     )).transpose((1, 0, 2))
         with self.grid_size_lock:
             pre_grid_size = self.grid_size
             try:
                 self.set_grid_size(grid_size)
-                screen = self.draw_gamescreen(state, [0] if controlled_players is None else controlled_players)
+                screen = self.draw_gamescreen(state, [0] if controlled_players is None else controlled_players,
+                                              env_id_ref=env_id_ref)
             finally:
                 self.set_grid_size(pre_grid_size)
             return pygame.surfarray.pixels3d(screen).transpose((1, 0, 2))
 
-    def get_state_image_by_size_direct(self, state: dict,
+    def get_state_image_by_size(self, state: dict,
                                 max_size: int,
-                                controlled_players: list[int] = None):
+                                controlled_players: list[int] = None,
+                                env_id_ref=None):
         grid_size = max_size / max(state["kitchen"]["width"], state["kitchen"]["height"])
-        image =  self.get_state_image(state, controlled_players, grid_size)
+        image = self.get_state_image(state, controlled_players, grid_size, env_id_ref=env_id_ref)
         if state["kitchen"]["width"] == state["kitchen"]["height"]:
             return image
         squared = np.zeros((max_size, max_size, 3), dtype=np.uint8)
diff --git a/cooperative_cuisine/reinforcement_learning/gym_env.py b/cooperative_cuisine/reinforcement_learning/gym_env.py
index 74245dd2..5caef787 100644
--- a/cooperative_cuisine/reinforcement_learning/gym_env.py
+++ b/cooperative_cuisine/reinforcement_learning/gym_env.py
@@ -1,10 +1,12 @@
 import json
 import random
 import time
+import uuid
 from collections import deque
 from datetime import timedelta
 from enum import Enum
 from pathlib import Path
+from uuid import uuid4
 
 import cv2
 import numpy as np
@@ -144,6 +146,7 @@ class EnvGymWrapper(Env):
             layout_config=layout,
             item_info=item_info,
             as_files=False,
+            env_name=uuid.uuid4().hex,
         )
 
         if self.randomize_counter_placement:
@@ -377,11 +380,13 @@ class EnvGymWrapper(Env):
         return observation, reward, terminated, truncated, info
 
     def reset(self, seed=None, options=None):
+        del visualizer.surface_cache_dict[self.env.env_name]
         self.env: Environment = Environment(
             env_config=environment_config,
             layout_config=layout,
             item_info=item_info,
             as_files=False,
+            env_name=uuid.uuid4().hex,
         )
 
         if self.randomize_counter_placement:
@@ -406,11 +411,7 @@ class EnvGymWrapper(Env):
         return obs
 
     def render(self):
-        observation = self.get_env_img()
-        img = observation.astype(np.uint8)
-        img = img.transpose((1, 2, 0))
-        img = cv2.resize(img, (img.shape[1], img.shape[0]))
-        return img
+        return self.get_env_img()
 
     def close(self):
         pass
@@ -418,8 +419,8 @@ class EnvGymWrapper(Env):
     def get_env_img(self):
         state = self.env.get_json_state(player_id=self.player_id)
         json_dict = json.loads(state)
-        observation = self.visualizer.get_state_image(state=json_dict).transpose((1, 0, 2))
-        return (observation.transpose((2, 0, 1))).astype(np.uint8)
+        observation = self.visualizer.get_state_image(state=json_dict, env_id_ref=self.env.env_name).astype(np.uint8)
+        return observation
 
     def get_vector_state(self):
         obs = self.get_vectorized_state_simple("0", self.onehot_state)
-- 
GitLab