Skip to content
Snippets Groups Projects
Commit 3bf75834 authored by Florian Schröder's avatar Florian Schröder
Browse files

Refactor gym_env and drawing for environment-specific caching

Updated `gym_env.py` and `drawing.py` to use unique environment names (`env_name`) for caching purposes, leveraging UUIDs for better cache management. Revised render and image fetch methods to incorporate the unique environment identifier, ensuring accurate and efficient cache invalidation and retrieval.
parent 461bf384
No related branches found
No related tags found
2 merge requests!110V1.2.0 changes,!104Resolve "Faster Drawing"
Pipeline #59076 passed
......@@ -125,7 +125,8 @@ class Visualizer:
self.cache_flags = reduce(lambda s, x: s | x, configs_cache_flags, CacheFlags.NONE)
self.grid_size_lock = Lock()
self.reduced_background = config["GameWindow"]["reduced_background"] if "GameWindow" in config and "reduced_background" in config["GameWindow"] else True
self.reduced_background = config["GameWindow"][
"reduced_background"] if "GameWindow" in config and "reduced_background" in config["GameWindow"] else True
def invalidate_surface_cache(self):
self.surface_cache_dict = {}
......@@ -162,6 +163,7 @@ class Visualizer:
self,
state: dict,
controlled_player_idxs: list[int],
env_id_ref=None,
):
"""Draws the game state on the given surface.
......@@ -174,8 +176,9 @@ class Visualizer:
width = int(np.ceil(state["kitchen"]["width"] * self.grid_size))
height = int(np.ceil(state["kitchen"]["height"] * self.grid_size))
screen = pygame.Surface((width, height), pygame.SRCALPHA)
if "counters+background" not in self.surface_cache_dict:
if env_id_ref in self.surface_cache_dict:
screen.blit(self.surface_cache_dict[env_id_ref], (0, 0))
else:
if CacheFlags.BACKGROUND in self.cache_flags:
self.draw_background(
surface=screen,
......@@ -188,9 +191,7 @@ class Visualizer:
screen,
state["counters"],
)
self.surface_cache_dict["counters+background"] = screen.copy()
else:
screen.blit(self.surface_cache_dict["counters+background"], (0, 0))
self.surface_cache_dict[env_id_ref] = screen.copy()
if CacheFlags.BACKGROUND not in self.cache_flags:
self.draw_background(
......@@ -216,7 +217,7 @@ class Visualizer:
col,
(np.array(state["players"][int(idx)]["pos"]) + 0.5) * self.grid_size,
(self.grid_size / 2),
)
)
self.draw_players(
screen,
......@@ -277,7 +278,7 @@ class Visualizer:
+ (direction.rotate(-90) * rect_scale),
right_beam + (direction.rotate(-90) * rect_scale),
right_beam - offset_front,
]
]
light_cone_points = [pos - offset_front, left_beam, right_beam]
pygame.draw.polygon(
cone_mask,
......@@ -311,7 +312,7 @@ class Visualizer:
pos
- (direction * rect_scale)
+ (direction.rotate(-90) * rect_scale),
]
]
pygame.draw.polygon(cone_mask, mask_color, corners)
......@@ -360,7 +361,6 @@ class Visualizer:
1,
)
def draw_image(
self,
screen: pygame.Surface,
......@@ -418,7 +418,7 @@ class Visualizer:
color,
self.model_to_world_coords(pos - facing * 0.25),
self.grid_size * 0.2,
)
)
self.draw_thing(screen, pos, self.config["Cook"]["parts"], scale=1.0, orientation=facing.tolist())
def draw_players(
......@@ -498,7 +498,7 @@ class Visualizer:
pos + (facing * self.grid_size * 0.4),
1.6 * self.grid_size,
width=1,
)
)
pygame.draw.circle(
screen, colors["red1"], pos + (facing * self.grid_size * 0.4), 4
)
......@@ -582,7 +582,7 @@ class Visualizer:
draw_pos[1] - (width / 2),
height,
width,
)
)
pygame.draw.rect(screen, color, rect)
case "circle":
......@@ -723,7 +723,7 @@ class Visualizer:
bar_pos[1] + size - bar_height,
progress_width,
bar_height,
)
)
pygame.draw.rect(screen, colors["red" if attention else "green1"], progress_bar)
def draw_counter(
......@@ -904,7 +904,7 @@ class Visualizer:
order_upper_left = [
order_rects_start + idx * grid_size * 1.2,
order_rects_start,
]
]
pygame.draw.rect(
screen,
colors["red"],
......@@ -958,24 +958,30 @@ class Visualizer:
def get_state_image(self, state: dict,
controlled_players: list[int] = None,
grid_size: int | None = None,
env_id_ref=None,
) -> npt.NDArray[np.uint8]:
if grid_size is None:
return pygame.surfarray.pixels3d(
self.draw_gamescreen(state, [0] if controlled_players is None else controlled_players)).transpose((1, 0, 2))
self.draw_gamescreen(state,
[0] if controlled_players is None else controlled_players,
env_id_ref=env_id_ref
)).transpose((1, 0, 2))
with self.grid_size_lock:
pre_grid_size = self.grid_size
try:
self.set_grid_size(grid_size)
screen = self.draw_gamescreen(state, [0] if controlled_players is None else controlled_players)
screen = self.draw_gamescreen(state, [0] if controlled_players is None else controlled_players,
env_id_ref=env_id_ref)
finally:
self.set_grid_size(pre_grid_size)
return pygame.surfarray.pixels3d(screen).transpose((1, 0, 2))
def get_state_image_by_size_direct(self, state: dict,
def get_state_image_by_size(self, state: dict,
max_size: int,
controlled_players: list[int] = None):
controlled_players: list[int] = None,
env_id_ref=None):
grid_size = max_size / max(state["kitchen"]["width"], state["kitchen"]["height"])
image = self.get_state_image(state, controlled_players, grid_size)
image = self.get_state_image(state, controlled_players, grid_size, env_id_ref=env_id_ref)
if state["kitchen"]["width"] == state["kitchen"]["height"]:
return image
squared = np.zeros((max_size, max_size, 3), dtype=np.uint8)
......
import json
import random
import time
import uuid
from collections import deque
from datetime import timedelta
from enum import Enum
from pathlib import Path
from uuid import uuid4
import cv2
import numpy as np
......@@ -144,6 +146,7 @@ class EnvGymWrapper(Env):
layout_config=layout,
item_info=item_info,
as_files=False,
env_name=uuid.uuid4().hex,
)
if self.randomize_counter_placement:
......@@ -377,11 +380,13 @@ class EnvGymWrapper(Env):
return observation, reward, terminated, truncated, info
def reset(self, seed=None, options=None):
del visualizer.surface_cache_dict[self.env.env_name]
self.env: Environment = Environment(
env_config=environment_config,
layout_config=layout,
item_info=item_info,
as_files=False,
env_name=uuid.uuid4().hex,
)
if self.randomize_counter_placement:
......@@ -406,11 +411,7 @@ class EnvGymWrapper(Env):
return obs
def render(self):
observation = self.get_env_img()
img = observation.astype(np.uint8)
img = img.transpose((1, 2, 0))
img = cv2.resize(img, (img.shape[1], img.shape[0]))
return img
return self.get_env_img()
def close(self):
pass
......@@ -418,8 +419,8 @@ class EnvGymWrapper(Env):
def get_env_img(self):
state = self.env.get_json_state(player_id=self.player_id)
json_dict = json.loads(state)
observation = self.visualizer.get_state_image(state=json_dict).transpose((1, 0, 2))
return (observation.transpose((2, 0, 1))).astype(np.uint8)
observation = self.visualizer.get_state_image(state=json_dict, env_id_ref=self.env.env_name).astype(np.uint8)
return observation
def get_vector_state(self):
obs = self.get_vectorized_state_simple("0", self.onehot_state)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment