From 47cc026f5f513f9aff8a3dc97cab15ff6762b709 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Florian=20Schr=C3=B6der?= <fschroeder@techfak.uni-bielefeld.de> Date: Thu, 8 Feb 2024 19:05:35 +0100 Subject: [PATCH] Implement vector state representation in game environment The game environment now supports vectorized states for reinforcement learning agents. This includes updates on various classes such as OrderManager, Counter, and Player. Several new utility functions are added to facilitate the conversion of the game state to vector form. The new vector state includes players, counters, orders, and game status. --- overcooked_simulator/game_server.py | 7 + overcooked_simulator/gym_env.py | 100 ++++--- .../overcooked_environment.py | 267 +++++++++++++++++- overcooked_simulator/utils.py | 40 +++ 4 files changed, 359 insertions(+), 55 deletions(-) diff --git a/overcooked_simulator/game_server.py b/overcooked_simulator/game_server.py index c5bb5d25..1db99598 100644 --- a/overcooked_simulator/game_server.py +++ b/overcooked_simulator/game_server.py @@ -432,6 +432,13 @@ class EnvironmentHandler: / 1_000_000_000 ) ) + + ( + grid, + player, + env_time, + orders, + ) = env_data.environment.get_vectorized_state("0") env_data.last_step_time = step_start if env_data.environment.game_ended: log.info(f"Env {env_id} ended. 
Set env to STOPPED.") diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py index 9a0c8094..26d5da99 100644 --- a/overcooked_simulator/gym_env.py +++ b/overcooked_simulator/gym_env.py @@ -1,5 +1,4 @@ import json -import os.path import time from datetime import timedelta from enum import Enum @@ -8,6 +7,12 @@ from pathlib import Path import cv2 import numpy as np import yaml +from gymnasium import spaces, Env +from stable_baselines3 import A2C +from stable_baselines3 import DQN +from stable_baselines3 import PPO +from stable_baselines3.common.env_checker import check_env +from stable_baselines3.common.env_util import make_vec_env from overcooked_simulator import ROOT_DIR from overcooked_simulator.gui_2d_vis.drawing import Visualizer @@ -17,29 +22,22 @@ from overcooked_simulator.overcooked_environment import ( ActionType, InterActionData, ) -import wandb -from wandb.integration.sb3 import WandbCallback - -import gymnasium as gym -import numpy as np -from gymnasium import spaces, Env - -from stable_baselines3.common.env_checker import check_env -from stable_baselines3.common.env_util import make_vec_env -from stable_baselines3 import A2C -from stable_baselines3 import DQN -from stable_baselines3 import PPO -SimpleActionSpace = Enum("SimpleActionSpace", ["Up", - # "Up_Left", - "Left", - # "Down_Left", - "Down", - # "Down_Right", - "Right", - # "Right_Up", - "Interact", - "Put"]) +SimpleActionSpace = Enum( + "SimpleActionSpace", + [ + "Up", + # "Up_Left", + "Left", + # "Down_Left", + "Down", + # "Down_Right", + "Right", + # "Right_Up", + "Interact", + "Put", + ], +) def get_env_action(player_id, simple_action, duration): @@ -118,9 +116,7 @@ def get_env_action(player_id, simple_action, duration): print("FAIL", simple_action) -environment_config_path: Path = ( - ROOT_DIR / "game_content" / "environment_config_rl.yaml" -) +environment_config_path: Path = ROOT_DIR / "game_content" / "environment_config_rl.yaml" item_info_path: Path = ROOT_DIR / 
"game_content" / "item_info_rl.yaml" layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout" with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file: @@ -131,6 +127,7 @@ class EnvGymWrapper(Env): """Should enable this: observation, reward, terminated, truncated, info = env.step(action) """ + metadata = {"render_modes": ["human"], "render_fps": 30} def __init__(self): @@ -142,7 +139,7 @@ class EnvGymWrapper(Env): env_config=environment_config_path, layout_config=layout_path, item_info=item_info_path, - as_files=True + as_files=True, ) self.visualizer: Visualizer = Visualizer(config=visualization_config) @@ -153,11 +150,10 @@ class EnvGymWrapper(Env): self.visualizer.create_player_colors(1) # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)} - self.action_space_map ={} + self.action_space_map = {} for idx, item in enumerate(SimpleActionSpace): self.action_space_map[idx] = item - self.global_step_time = 1 self.in_between_steps = 1 @@ -165,8 +161,9 @@ class EnvGymWrapper(Env): # Example for using image as input (channel-first; channel-last also works): dummy_obs = self.get_env_img(self.gridsize) - self.observation_space = spaces.Box(low=0, high=255, - shape=dummy_obs.shape, dtype=np.uint8) + self.observation_space = spaces.Box( + low=0, high=255, shape=dummy_obs.shape, dtype=np.uint8 + ) self.last_obs = dummy_obs @@ -175,7 +172,9 @@ class EnvGymWrapper(Env): def step(self, action): simple_action = self.action_space_map[action] - env_action = get_env_action(self.player_id, simple_action, self.global_step_time) + env_action = get_env_action( + self.player_id, simple_action, self.global_step_time + ) self.env.perform_action(env_action) for i in range(self.in_between_steps): @@ -186,10 +185,13 @@ class EnvGymWrapper(Env): observation = self.get_env_img(self.gridsize) reward = -1 - if self.env.order_and_score.score > self.prev_score and self.env.order_and_score.score != 0: - self.prev_score = 
self.env.order_and_score.score + if ( + self.env.order_manager.score > self.prev_score + and self.env.order_manager.score != 0 + ): + self.prev_score = self.env.order_manager.score reward = 100 - elif self.env.order_and_score.score < self.prev_score: + elif self.env.order_manager.score < self.prev_score: self.prev_score = 0 reward = -1 @@ -200,14 +202,12 @@ class EnvGymWrapper(Env): # self.render(self.gridsize) return observation, reward, terminated, truncated, info - def reset(self, seed=None, options=None): - self.env: Environment = Environment( env_config=environment_config_path, layout_config=layout_path, item_info=item_info_path, - as_files=True + as_files=True, ) self.player_name = str(0) @@ -219,10 +219,10 @@ class EnvGymWrapper(Env): def render(self): observation = self.get_env_img(self.gridsize) - img = observation.transpose((1,2,0))[:,:,::-1] + img = observation.transpose((1, 2, 0))[:, :, ::-1] print(img.shape) - img = cv2.resize(img, (img.shape[1]*5, img.shape[0]*5)) - cv2.imshow("Overcooked",img) + img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5)) + cv2.imshow("Overcooked", img) cv2.waitKey(1) def close(self): @@ -234,7 +234,13 @@ class EnvGymWrapper(Env): observation = self.visualizer.get_state_image( grid_size=gridsize, state=json_dict ).transpose((1, 0, 2)) - return observation.transpose((2,0,1)) + return observation.transpose((2, 0, 1)) + + def get_vector_state(self): + grid, player, env_time, orders = self.env.get_vectorized_state("0") + + # flatten: grid + player + # concatenate all (env_time to array) def sample_random_action(self): act = self.action_space.sample() @@ -246,7 +252,6 @@ def main(): rl_agent_checkpoints = Path("./rl_agent_checkpoints") rl_agent_checkpoints.mkdir(exist_ok=True) - config = { "policy_type": "CnnPolicy", "total_timesteps": 1000000, # hendric sagt eher so 300_000_000 schritte @@ -266,8 +271,9 @@ def main(): model_class = model_classes[2] # model = model_class(config["policy_type"], env, verbose=1, 
tensorboard_log=f"runs/{run.id}") - model = model_class(config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{0}") - + model = model_class( + config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{0}" + ) model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}" # if os.path.exists(model_save_path): @@ -280,7 +286,7 @@ def main(): # verbose=0, # ), log_interval=1, - progress_bar=True + progress_bar=True, ) # run.finish() model.save(model_save_path) @@ -293,7 +299,7 @@ def main(): check_env(env) obs, info = env.reset() while True: - time.sleep(1/30) + time.sleep(1 / 30) action, _states = model.predict(obs, deterministic=False) obs, reward, terminated, truncated, info = env.step(int(action)) env.render() diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index 6d8dd883..00a5794c 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -5,6 +5,7 @@ import inspect import json import logging import sys +from collections import deque from datetime import timedelta, datetime from enum import Enum from pathlib import Path @@ -20,11 +21,15 @@ from overcooked_simulator.counter_factory import CounterFactory from overcooked_simulator.counters import ( Counter, PlateConfig, + CookingCounter, + Dispenser, ) from overcooked_simulator.effect_manager import EffectManager from overcooked_simulator.game_items import ( ItemInfo, ItemType, + CookingEquipment, + Item, ) from overcooked_simulator.hooks import ( ITEM_INFO_LOADED, @@ -50,7 +55,11 @@ from overcooked_simulator.order import ( OrderConfig, ) from overcooked_simulator.player import Player, PlayerConfig -from overcooked_simulator.utils import create_init_env_time, get_closest +from overcooked_simulator.utils import ( + create_init_env_time, + get_closest, + VectorStateGenerationData, +) log = logging.getLogger(__name__) @@ -194,7 +203,7 @@ class Environment: """The allowed 
meals depend on the `environment_config.yml` configured behaviour. Either all meals that are possible or only a limited subset.""" - self.order_and_score = OrderManager( + self.order_manager = OrderManager( order_config=self.environment_config["orders"], available_meals={ item: info @@ -225,7 +234,7 @@ class Environment: else {} ) ), - order_manager=self.order_and_score, + order_manager=self.order_manager, effect_manager_config=self.environment_config["effect_manager"], hook=self.hook, random=self.random, @@ -268,7 +277,7 @@ class Environment: ) """Counters that needs to be called in the step function via the `progress` method.""" - self.order_and_score.create_init_orders(self.env_time) + self.order_manager.create_init_orders(self.env_time) self.start_time = self.env_time """The relative env time when it started.""" self.env_time_end = self.env_time + timedelta( @@ -281,6 +290,8 @@ class Environment: str, EffectManager ] = self.counter_factory.setup_effect_manger(self.counters) + self.vector_state_generation = self.setup_vectorization() + self.hook( ENV_INITIALIZED, environment_config=env_config, @@ -665,8 +676,8 @@ class Environment: for idx, p in enumerate(self.players.values()): if not (new_positions[idx] == player_positions[idx]).all(): - p.turn(player_movement_vectors[idx]) p.move_abs(new_positions[idx]) + p.turn(player_movement_vectors[idx]) def add_player(self, player_name: str, pos: npt.NDArray = None): """Add a player to the environment. 
@@ -742,7 +753,7 @@ class Environment: for counter in self.progressing_counters: counter.progress(passed_time=passed_time, now=self.env_time) - self.order_and_score.progress(passed_time=passed_time, now=self.env_time) + self.order_manager.progress(passed_time=passed_time, now=self.env_time) for effect_manager in self.effect_manager.values(): effect_manager.progress(passed_time=passed_time, now=self.env_time) # self.hook(POST_STEP, passed_time=passed_time) @@ -757,7 +768,7 @@ class Environment: "players": self.players, "counters": self.counters, "score": self.score, - "orders": self.order_and_score.open_orders, + "orders": self.order_manager.open_orders, "ended": self.game_ended, "env_time": self.env_time, "remaining_time": max(self.env_time_end - self.env_time, timedelta(0)), @@ -771,7 +782,7 @@ class Environment: "counters": [c.to_dict() for c in self.counters], "kitchen": {"width": self.kitchen_width, "height": self.kitchen_height}, "score": self.score, - "orders": self.order_and_score.order_state(), + "orders": self.order_manager.order_state(), "ended": self.game_ended, "env_time": self.env_time.isoformat(), "remaining_time": max( @@ -793,6 +804,246 @@ class Environment: return json_data raise ValueError(f"No valid {player_id=}") + def setup_vectorization(self) -> VectorStateGenerationData: + grid_base_array = np.zeros( + ( + int(self.kitchen_width), + int(self.kitchen_height), + 114 + 12 + 4, # TODO calc based on item info + ), + dtype=np.float32, + ) + counter_list = [ + "Counter", + "CuttingBoard", + "ServingWindow", + "Trashcan", + "Sink", + "SinkAddon", + "Stove", + "DeepFryer", + "Oven", + ] + grid_idxs = [ + (x, y) + for x in range(int(self.kitchen_width)) + for y in range(int(self.kitchen_height)) + ] + # counters do not move + for counter in self.counters: + grid_idx = np.floor(counter.pos).astype(int) + counter_name = ( + counter.name + if isinstance(counter, CookingCounter) + else ( + repr(counter) + if isinstance(counter, Dispenser) # test the instance, the Counter class itself is never a Dispenser; NOTE(review): confirm repr of a Dispenser ends with Dispenser, the assert that follows relies on it + else 
counter.__class__.__name__ + ) + ) + assert counter_name in counter_list or counter_name.endswith( + "Dispenser" + ), f"Unknown Counter {counter}" + oh_idx = len(counter_list) + if counter_name in counter_list: + oh_idx = counter_list.index(counter_name) + + one_hot = [0] * (len(counter_list) + 2) + one_hot[oh_idx] = 1 + grid_base_array[ + grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2) + ] = np.array(one_hot, dtype=np.float32) + + grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1]))) + + for free_idx in grid_idxs: + one_hot = [0] * (len(counter_list) + 2) + one_hot[len(counter_list) + 1] = 1 + grid_base_array[ + free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2) + ] = np.array(one_hot, dtype=np.float32) + + player_info_base_array = np.zeros( + ( + 4, + 4 + 114, + ), + dtype=np.float32, + ) + order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32) + + return VectorStateGenerationData( + grid_base_array=grid_base_array, + oh_len=12, + ) + + def get_simple_vectorized_item(self, item: Item) -> npt.NDArray[float]: + name = item.name + array = np.zeros(21, dtype=np.float32) + if item.name.startswith("Burnt"): + name = name[len("Burnt") :] + array[0] = 1.0 + if name.startswith("Chopped"): + array[1] = 1.0 + name = name[len("Chopped") :] + if name in [ + "PizzaBase", + "GratedCheese", + "RawChips", + "RawPatty", + ]: + array[1] = 1.0 + name = { + "PizzaBase": "Dough", + "GratedCheese": "Cheese", + "RawChips": "Potato", + "RawPatty": "Meat", + }[name] + if name == "CookedPatty": + array[2] = 1.0 + name = "Meat" + + if name in self.vector_state_generation.meals: + idx = self.vector_state_generation.meals.index(name) + elif name in self.vector_state_generation.ingredients: + idx = len( + self.vector_state_generation.meals + ) + self.vector_state_generation.ingredients.index(name) + else: + raise ValueError(f"Unknown item {name} - {item}") + array[3 + idx] = 1.0 # slots 0-2 are the flag bits set above, the meal/ingredient one-hot follows them + return array + + def get_vectorized_item(self, item: Item) -> npt.NDArray[float]: + 
item_array = np.zeros(114, dtype=np.float32) + + if isinstance(item, CookingEquipment) or item.item_info.type == ItemType.Tool: + assert ( + item.name in self.vector_state_generation.equipments + ), f"unknown equipment {item}" + idx = self.vector_state_generation.equipments.index(item.name) + item_array[idx] = 1.0 + if isinstance(item, CookingEquipment): + for s_idx, sub_item in enumerate(item.content_list): + if s_idx > 3: + print("Too much content in the content list, info dropped") + break + start_idx = len(self.vector_state_generation.equipments) + 21 + 2 + item_array[ + start_idx + (s_idx * (21)) : start_idx + ((s_idx + 1) * (21)) + ] = self.get_simple_vectorized_item(sub_item) + + else: + item_array[ + len(self.vector_state_generation.equipments) : len( + self.vector_state_generation.equipments + ) + + 21 + ] = self.get_simple_vectorized_item(item) + + item_array[ + len(self.vector_state_generation.equipments) + 21 # first scalar slot, directly behind the 21-wide item vector + ] = item.progress_percentage + + if item.active_effects: + item_array[ + len(self.vector_state_generation.equipments) + 21 + 1 # second scalar slot; the next index is start_idx, where sub-item vectors begin + ] = 1.0 # TODO percentage of fire... + + return item_array + + def get_vectorized_state( + self, player_id: str + ) -> Tuple[ + npt.NDArray[npt.NDArray[float]], + npt.NDArray[npt.NDArray[float]], + float, + npt.NDArray[float], + ]: + grid_array = self.vector_state_generation.grid_base_array.copy() + for counter in self.counters: + grid_idx = np.floor(counter.pos).astype(int) # store in counter? + if counter.occupied_by: + if isinstance(counter.occupied_by, deque): + ... + else: + item = counter.occupied_by + grid_array[ + grid_idx[0], + grid_idx[1], + 4 + self.vector_state_generation.oh_len :, + ] = self.get_vectorized_item(item) + if counter.active_effects: + grid_array[ + grid_idx[0], + grid_idx[1], + 4 + self.vector_state_generation.oh_len - 1, + ] = 1.0 # TODO percentage of fire... 
+ + assert len(self.players) <= 4, "To much players for vector representation" + player_vec = np.zeros( + ( + 4, + 4 + 114, + ), + dtype=np.float32, + ) + player_pos = 1 + for player in self.players.values(): + if player.name == player_id: + idx = 0 + player_vec[0, :4] = np.array( + [ + player.pos[0], + player.pos[1], + player.facing_point[0], + player.facing_point[1], + ], + dtype=np.float32, + ) + else: + idx = player_pos + + if idx: # only the other players consume one of the slots 1-3; the requesting player always uses slot 0 + player_pos += 1 + grid_idx = np.floor(player.pos).astype(int) # store in counter? + player_vec[idx, :4] = np.array( + [ + player.pos[0] - grid_idx[0], + player.pos[1] - grid_idx[1], + player.facing_point[0] / np.linalg.norm(player.facing_point), + player.facing_point[1] / np.linalg.norm(player.facing_point), + ], + dtype=np.float32, + ) + grid_array[grid_idx[0], grid_idx[1], idx] = 1.0 + + if player.holding: + player_vec[idx, 4:] = self.get_vectorized_item(player.holding) + + order_array = np.zeros((10 * (8 + 1)), dtype=np.float32) + + for i, order in enumerate(self.order_manager.open_orders): + if i > 9: + print("some orders are not represented in the vectorized state") + break + assert ( + order.meal.name in self.vector_state_generation.meals + ), "unknown meal in order" + idx = self.vector_state_generation.meals.index(order.meal.name) + order_array[(i * 9) + idx] = 1.0 + order_array[(i * 9) + 8] = ( + self.env_time - order.start_time + ).total_seconds() / order.max_duration.total_seconds() + + return ( + grid_array, + player_vec, + (self.env_time - self.start_time).total_seconds() + / (self.env_time_end - self.start_time).total_seconds(), + order_array, + ) + def reset_env_time(self): """Reset the env time to the initial time, defined by `create_init_env_time`.""" self.hook(PRE_RESET_ENV_TIME) diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index b78d44af..357d2679 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -25,6 +25,46 @@ if TYPE_CHECKING: from 
overcooked_simulator.player import Player +@dataclasses.dataclass +class VectorStateGenerationData: + grid_base_array: npt.NDArray[npt.NDArray[float]] + oh_len: int + + number_normal_ingredients = 10 + + meals = [ + "Chips", + "FriedFish", + "Burger", + "Salad", + "TomatoSoup", + "OnionSoup", + "FishAndChips", + "Pizza", + ] + equipments = [ + "Pot", + "Pan", + "Basket", + "Peel", + "Plate", + "DirtyPlate", + "Extinguisher", + ] + ingredients = [ + "Tomato", + "Lettuce", + "Onion", + "Meat", + "Bun", + "Potato", + "Fish", + "Dough", + "Cheese", + "Sausage", + ] + + def create_init_env_time(): """Init time of the environment time, because all environments should have the same internal time.""" return datetime( -- GitLab