From 3bf758346cd74a9c17cacd14a9d5f2f7bbf66337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Schr=C3=B6der?= <fschroeder@techfak.uni-bielefeld.de> Date: Tue, 20 Aug 2024 15:58:10 +0200 Subject: [PATCH] Refactor gym_env and drawing for environment-specific caching Updated `gym_env.py` and `drawing.py` to use unique environment names (`env_name`) for caching purposes, leveraging UUIDs for better cache management. Revised render and image fetch methods to incorporate the unique environment identifier, ensuring accurate and efficient cache invalidation and retrieval. --- cooperative_cuisine/pygame_2d_vis/drawing.py | 46 +++++++++++-------- .../reinforcement_learning/gym_env.py | 15 +++--- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/cooperative_cuisine/pygame_2d_vis/drawing.py b/cooperative_cuisine/pygame_2d_vis/drawing.py index c98ce4b2..eb4869fd 100644 --- a/cooperative_cuisine/pygame_2d_vis/drawing.py +++ b/cooperative_cuisine/pygame_2d_vis/drawing.py @@ -125,7 +125,8 @@ class Visualizer: self.cache_flags = reduce(lambda s, x: s | x, configs_cache_flags, CacheFlags.NONE) self.grid_size_lock = Lock() - self.reduced_background = config["GameWindow"]["reduced_background"] if "GameWindow" in config and "reduced_background" in config["GameWindow"] else True + self.reduced_background = config["GameWindow"][ + "reduced_background"] if "GameWindow" in config and "reduced_background" in config["GameWindow"] else True def invalidate_surface_cache(self): self.surface_cache_dict = {} @@ -162,6 +163,7 @@ class Visualizer: self, state: dict, controlled_player_idxs: list[int], + env_id_ref=None, ): """Draws the game state on the given surface. @@ -174,8 +176,9 @@ class Visualizer: width = int(np.ceil(state["kitchen"]["width"] * self.grid_size)) height = int(np.ceil(state["kitchen"]["height"] * self.grid_size)) screen = pygame.Surface((width, height), pygame.SRCALPHA) - - if "counters+background" not in self.surface_cache_dict: + if env_id_ref in self.surface_cache_dict: + screen.blit(self.surface_cache_dict[env_id_ref], (0, 0)) + else: if CacheFlags.BACKGROUND in self.cache_flags: self.draw_background( surface=screen, @@ -188,9 +191,7 @@ class Visualizer: screen, state["counters"], ) - self.surface_cache_dict["counters+background"] = screen.copy() - else: - screen.blit(self.surface_cache_dict["counters+background"], (0, 0)) + self.surface_cache_dict[env_id_ref] = screen.copy() if CacheFlags.BACKGROUND not in self.cache_flags: self.draw_background( @@ -216,7 +217,7 @@ class Visualizer: col, (np.array(state["players"][int(idx)]["pos"]) + 0.5) * self.grid_size, (self.grid_size / 2), - ) + ) self.draw_players( screen, @@ -277,7 +278,7 @@ class Visualizer: + (direction.rotate(-90) * rect_scale), right_beam + (direction.rotate(-90) * rect_scale), right_beam - offset_front, - ] + ] light_cone_points = [pos - offset_front, left_beam, right_beam] pygame.draw.polygon( cone_mask, @@ -311,7 +312,7 @@ class Visualizer: pos - (direction * rect_scale) + (direction.rotate(-90) * rect_scale), - ] + ] pygame.draw.polygon(cone_mask, mask_color, corners) @@ -360,7 +361,6 @@ class Visualizer: 1, ) - def draw_image( self, screen: pygame.Surface, @@ -418,7 +418,7 @@ class Visualizer: color, self.model_to_world_coords(pos - facing * 0.25), self.grid_size * 0.2, - ) + ) self.draw_thing(screen, pos, self.config["Cook"]["parts"], scale=1.0, orientation=facing.tolist()) def draw_players( @@ -498,7 +498,7 @@ class Visualizer: pos + (facing * self.grid_size * 0.4), 1.6 * self.grid_size, width=1, - ) + ) pygame.draw.circle( screen, colors["red1"], pos + (facing * self.grid_size * 0.4), 4 ) @@ -582,7 +582,7 @@ class Visualizer: draw_pos[1] - (width / 2), height, width, - ) + ) pygame.draw.rect(screen, color, rect) case "circle": @@ -723,7 +723,7 @@ class Visualizer: bar_pos[1] + size - bar_height, progress_width, bar_height, - ) + ) pygame.draw.rect(screen, colors["red" if attention else "green1"], progress_bar) def draw_counter( @@ -904,7 +904,7 @@ class Visualizer: order_upper_left = [ order_rects_start + idx * grid_size * 1.2, order_rects_start, - ] + ] pygame.draw.rect( screen, colors["red"], @@ -958,24 +958,30 @@ class Visualizer: def get_state_image(self, state: dict, controlled_players: list[int] = None, grid_size: int | None = None, + env_id_ref=None, ) -> npt.NDArray[np.uint8]: if grid_size is None: return pygame.surfarray.pixels3d( - self.draw_gamescreen(state, [0] if controlled_players is None else controlled_players)).transpose((1, 0, 2)) + self.draw_gamescreen(state, + [0] if controlled_players is None else controlled_players, + env_id_ref=env_id_ref + )).transpose((1, 0, 2)) with self.grid_size_lock: pre_grid_size = self.grid_size try: self.set_grid_size(grid_size) - screen = self.draw_gamescreen(state, [0] if controlled_players is None else controlled_players) + screen = self.draw_gamescreen(state, [0] if controlled_players is None else controlled_players, + env_id_ref=env_id_ref) finally: self.set_grid_size(pre_grid_size) return pygame.surfarray.pixels3d(screen).transpose((1, 0, 2)) - def get_state_image_by_size_direct(self, state: dict, + def get_state_image_by_size(self, state: dict, max_size: int, - controlled_players: list[int] = None): + controlled_players: list[int] = None, + env_id_ref=None): grid_size = max_size / max(state["kitchen"]["width"], state["kitchen"]["height"]) - image = self.get_state_image(state, controlled_players, grid_size) + image = self.get_state_image(state, controlled_players, grid_size, env_id_ref=env_id_ref) if state["kitchen"]["width"] == state["kitchen"]["height"]: return image squared = np.zeros((max_size, max_size, 3), dtype=np.uint8) diff --git a/cooperative_cuisine/reinforcement_learning/gym_env.py b/cooperative_cuisine/reinforcement_learning/gym_env.py index 74245dd2..5caef787 100644 --- a/cooperative_cuisine/reinforcement_learning/gym_env.py +++ b/cooperative_cuisine/reinforcement_learning/gym_env.py @@ -1,10 +1,12 @@ import json import random import time +import uuid from collections import deque from datetime import timedelta from enum import Enum from pathlib import Path +from uuid import uuid4 import cv2 import numpy as np @@ -144,6 +146,7 @@ class EnvGymWrapper(Env): layout_config=layout, item_info=item_info, as_files=False, + env_name=uuid.uuid4().hex, ) if self.randomize_counter_placement: @@ -377,11 +380,13 @@ class EnvGymWrapper(Env): return observation, reward, terminated, truncated, info def reset(self, seed=None, options=None): + del visualizer.surface_cache_dict[self.env.env_name] self.env: Environment = Environment( env_config=environment_config, layout_config=layout, item_info=item_info, as_files=False, + env_name=uuid.uuid4().hex, ) if self.randomize_counter_placement: @@ -406,11 +411,7 @@ class EnvGymWrapper(Env): return obs def render(self): - observation = self.get_env_img() - img = observation.astype(np.uint8) - img = img.transpose((1, 2, 0)) - img = cv2.resize(img, (img.shape[1], img.shape[0])) - return img + return self.get_env_img() def close(self): pass @@ -418,8 +419,8 @@ class EnvGymWrapper(Env): def get_env_img(self): state = self.env.get_json_state(player_id=self.player_id) json_dict = json.loads(state) - observation = self.visualizer.get_state_image(state=json_dict).transpose((1, 0, 2)) - return (observation.transpose((2, 0, 1))).astype(np.uint8) + observation = self.visualizer.get_state_image(state=json_dict, env_id_ref=self.env.env_name).astype(np.uint8) + return observation def get_vector_state(self): obs = self.get_vectorized_state_simple("0", self.onehot_state) -- GitLab