From 75952d068fec6f756b116305a7233c1ef461bfa8 Mon Sep 17 00:00:00 2001 From: fheinrich <fheinrich@techfak.uni-bielefeld.de> Date: Mon, 12 Feb 2024 11:54:06 +0100 Subject: [PATCH] Update to rl --- .../game_content/environment_config_rl.yaml | 18 +- .../game_content/layouts/rl.layout | 9 +- .../gui_2d_vis/overcooked_gui.py | 6 +- overcooked_simulator/gym_env.py | 157 +++++++++------- .../overcooked_environment.py | 169 +++++++++++++++++- overcooked_simulator/utils.py | 21 +++ 6 files changed, 295 insertions(+), 85 deletions(-) diff --git a/overcooked_simulator/game_content/environment_config_rl.yaml b/overcooked_simulator/game_content/environment_config_rl.yaml index 45c61914..0b3b4570 100644 --- a/overcooked_simulator/game_content/environment_config_rl.yaml +++ b/overcooked_simulator/game_content/environment_config_rl.yaml @@ -1,12 +1,12 @@ plates: - clean_plates: 2 + clean_plates: 1 dirty_plates: 0 - plate_delay: [ 5, 10 ] + plate_delay: [ 2, 4 ] return_dirty: False # range of seconds until the dirty plate arrives. game: - time_limit_seconds: 400 + time_limit_seconds: 300 meals: all: true @@ -93,7 +93,7 @@ extra_setup_functions: hooks: [ completed_order ] callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class_kwargs: - static_score: 100 + static_score: 1 serve_not_ordered_meals: func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' @@ -101,35 +101,35 @@ extra_setup_functions: hooks: [ serve_not_ordered_meal ] callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class_kwargs: - static_score: 100 + static_score: 1 trashcan_usages: func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ trashcan_usage ] callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class_kwargs: - static_score: -10 + static_score: -0.15 item_cut: func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ cutting_board_100 ] callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class_kwargs: - static_score: 10 + static_score: 0.10 stepped: func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ post_step ] callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class_kwargs: - static_score: -1 + static_score: -0.01 combine: func: !!python/name:overcooked_simulator.hooks.hooks_via_callback_class '' kwargs: hooks: [ drop_off_on_cooking_equipment ] callback_class: !!python/name:overcooked_simulator.scores.ScoreViaHooks '' callback_class_kwargs: - static_score: 1 + static_score: 0.10 # json_states: # func: !!python/name:overcooked_simulator.recording.class_recording_with_hooks '' # kwargs: diff --git a/overcooked_simulator/game_content/layouts/rl.layout b/overcooked_simulator/game_content/layouts/rl.layout index fd3cd8db..5115af28 100644 --- a/overcooked_simulator/game_content/layouts/rl.layout +++ b/overcooked_simulator/game_content/layouts/rl.layout @@ -1,5 +1,4 @@ -#X### -T___# -#___# -U___P -#C#W# +#X## +T__W +U__P +#C## diff --git a/overcooked_simulator/gui_2d_vis/overcooked_gui.py b/overcooked_simulator/gui_2d_vis/overcooked_gui.py index a4d28c5c..8bf25de7 100644 --- a/overcooked_simulator/gui_2d_vis/overcooked_gui.py +++ b/overcooked_simulator/gui_2d_vis/overcooked_gui.py @@ -12,6 +12,7 @@ import pygame import pygame_gui import requests import yaml +from pygame._sdl2 import get_drivers from websockets.sync.client import connect from overcooked_simulator import ROOT_DIR @@ -30,6 +31,9 @@ from overcooked_simulator.utils import ( add_list_of_manager_ids_arguments, ) +for driver in get_drivers(): + print(driver) + class MenuStates(Enum): Start = "Start" @@ -970,8 +974,8 @@ class PyGameGUI: clock = pygame.time.Clock() - self.reset_window_size() self.init_ui_elements() + self.reset_window_size() self.manage_button_visibility() self.update_selection_elements() diff --git a/overcooked_simulator/gym_env.py b/overcooked_simulator/gym_env.py index 41c3474d..47dbc108 100644 --- a/overcooked_simulator/gym_env.py +++ b/overcooked_simulator/gym_env.py @@ -1,7 +1,6 @@ import json import random import time -from copy import deepcopy from datetime import timedelta from enum import Enum from pathlib import Path @@ -17,6 +16,7 @@ from stable_baselines3 import PPO from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.vec_env import VecVideoRecorder from wandb.integration.sb3 import WandbCallback from overcooked_simulator import ROOT_DIR @@ -134,12 +134,18 @@ with open(environment_config_path, "r") as file: with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file: visualization_config = yaml.safe_load(file) -vanilla_env: Environment = Environment( - env_config=environment_config, - layout_config=layout, - item_info=item_info, - as_files=False, -) + +def shuffle_counters(env): + sample_counter = [] + sample_counter = [] + for counter in env.counters: + if counter.__class__ != Counter: + sample_counter.append(counter) + new_counter_pos = [c.pos for c in sample_counter] + random.shuffle(new_counter_pos) + for counter, new_pos in zip(sample_counter, new_counter_pos): + counter.pos = new_pos + env.vector_state_generation = env.setup_vectorization() class EnvGymWrapper(Env): @@ -147,29 +153,33 @@ class EnvGymWrapper(Env): observation, reward, terminated, truncated, info = env.step(action) """ - metadata = {"render_modes": ["human"], "render_fps": 30} + metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 30} def __init__(self): super().__init__() self.gridsize = 20 - self.env = deepcopy(vanilla_env) - # sample_counter = [] - # for counter in self.env.counters: - # if counter.__class__ != Counter: - # sample_counter.append(counter) - # new_counter_pos = [c.pos for c in sample_counter] - # random.shuffle(new_counter_pos) - # for counter, new_pos in zip(sample_counter, new_counter_pos): - # counter.pos = new_pos - # self.env.vector_state_generation = self.env.setup_vectorization() + self.randomize_counter_placement = False + self.use_rgb_obs = False + + self.env: Environment = Environment( + env_config=environment_config, + layout_config=layout, + item_info=item_info, + as_files=False, + ) + + if self.randomize_counter_placement: + shuffle_counters(self.env) self.visualizer: Visualizer = Visualizer(config=visualization_config) self.player_name = str(0) self.env.add_player(self.player_name) self.player_id = list(self.env.players.keys())[0] + self.env.setup_vectorization() + self.visualizer.create_player_colors(1) # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)} @@ -184,9 +194,8 @@ class EnvGymWrapper(Env): # Example for using image as input (channel-first; channel-last also works): dummy_obs = self.get_observation() - # dummy_obs = self.get_vector_state() self.observation_space = spaces.Box( - low=0, high=1, shape=dummy_obs.shape, dtype=float + low=-1, high=8, shape=dummy_obs.shape, dtype=int ) self.last_obs = dummy_obs @@ -219,27 +228,22 @@ class EnvGymWrapper(Env): return observation, reward, terminated, truncated, info def reset(self, seed=None, options=None): - # self.env: Environment = Environment( - # env_config=environment_config, - # layout_config=layout, - # item_info=item_info, - # as_files=False - # ) + self.env: Environment = Environment( + env_config=environment_config, + layout_config=layout, + item_info=item_info, + as_files=False, + ) - self.env = deepcopy(vanilla_env) - # sample_counter = [] - # for counter in self.env.counters: - # if counter.__class__ != Counter: - # sample_counter.append(counter) - # new_counter_pos = [c.pos for c in sample_counter] - # random.shuffle(new_counter_pos) - # for counter, new_pos in zip(sample_counter, new_counter_pos): - # counter.pos = new_pos - # self.env.vector_state_generation = self.env.setup_vectorization() + if self.randomize_counter_placement: + shuffle_counters(self.env) self.player_name = str(0) self.env.add_player(self.player_name) self.player_id = list(self.env.players.keys())[0] + + self.env.setup_vectorization() + info = {} obs = self.get_observation() @@ -248,16 +252,18 @@ class EnvGymWrapper(Env): return obs, info def get_observation(self): - # obs = self.get_env_img(self.gridsize) - obs = self.get_vector_state() + if self.use_rgb_obs: + obs = self.get_env_img(self.gridsize) + else: + obs = self.get_vector_state() return obs + def render(self): observation = self.get_env_img(self.gridsize) - img = observation.transpose((1, 2, 0))[:, :, ::-1] - # print(img.shape) - img = cv2.resize(img, (img.shape[1] * 5, img.shape[0] * 5)) - cv2.imshow("Overcooked", img) - cv2.waitKey(1) + img = (observation * 255.0).astype(np.uint8) + img = img.transpose((1, 2, 0)) + img = cv2.resize(img, (img.shape[1], img.shape[0])) + return img def close(self): pass @@ -268,20 +274,23 @@ class EnvGymWrapper(Env): observation = self.visualizer.get_state_image( grid_size=gridsize, state=json_dict ).transpose((1, 0, 2)) - return observation.transpose((2, 0, 1)) / 255.0 + return (observation.transpose((2, 0, 1)) / 255.0).astype(np.float32) def get_vector_state(self): - grid, player, env_time, orders = self.env.get_vectorized_state("0") + # grid, player, env_time, orders = self.env.get_vectorized_state_full("0") + # + # + # obs = np.concatenate( + # [grid.flatten(), player.flatten()], axis=0, dtype=np.float32 + # ) + # return obs - obs = np.concatenate([grid.flatten(), player.flatten()], axis=0) + obs = self.env.get_vectorized_state_simple("0") return obs - # flatten: grid + player - # concatenate all (env_time to array) def sample_random_action(self): act = self.action_space.sample() return act - # return np.random.randint(len(self.action_space_map)) def main(): @@ -290,14 +299,15 @@ def main(): config = { "policy_type": "MlpPolicy", - "total_timesteps": 100_000, # hendric sagt eher so 300_000_000 schritte + "total_timesteps": 30_000_000, # hendric sagt eher so 300_000_000 schritte "env_id": "overcooked", + "number_envs_parallel": 16, } - debug = True + debug = False do_training = True vec_env = True - number_envs_parallel = 8 + number_envs_parallel = config["number_envs_parallel"] model_classes = [A2C, DQN, PPO] model_class = model_classes[2] @@ -307,12 +317,34 @@ def main(): else: env = EnvGymWrapper() - model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}" + env.render_mode = "rgb_array" + if not debug: + run = wandb.init( + project="overcooked", + config=config, + sync_tensorboard=True, # auto-upload sb3's tensorboard metrics + monitor_gym=True + # save_code=True, # optional + ) + + env = VecVideoRecorder( + env, + f"videos/{run.id}", + record_video_trigger=lambda x: x % 100_000 == 0, + video_length=300, + ) + + model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}" if do_training: model = model_class( - config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{0}" + config["policy_type"], + env, + verbose=1, + tensorboard_log=f"runs/{0}", + # n_steps=2048, + # n_epochs=10, ) if debug: model.learn( @@ -321,16 +353,8 @@ def main(): progress_bar=True, ) else: - run = wandb.init( - project="overcooked", - config=config, - sync_tensorboard=True, # auto-upload sb3's tensorboard metrics - monitor_gym=True - # save_code=True, # optional - ) - checkpoint_callback = CheckpointCallback( - save_freq=1000, + save_freq=50_000, save_path="./logs/", name_prefix="rl_model", save_replay_buffer=True, @@ -356,14 +380,17 @@ def main(): model = model_class.load(model_save_path) env = EnvGymWrapper() + check_env(env) obs, info = env.reset() while True: - time.sleep(1 / 10) + time.sleep(1 / 30) action, _states = model.predict(obs, deterministic=False) obs, reward, terminated, truncated, info = env.step(int(action)) print(reward) - env.render() + rgb_img = env.render() + cv2.imshow("env", rgb_img) + cv2.waitKey(0) if terminated or truncated: obs, info = env.reset() diff --git a/overcooked_simulator/overcooked_environment.py b/overcooked_simulator/overcooked_environment.py index 54de3fb1..c394624a 100644 --- a/overcooked_simulator/overcooked_environment.py +++ b/overcooked_simulator/overcooked_environment.py @@ -48,7 +48,8 @@ from overcooked_simulator.hooks import ( ACTION_ON_NOT_REACHABLE_COUNTER, ACTION_PUT, ACTION_INTERACT_START, - ITEM_INFO_CONFIG, POST_STEP, + ITEM_INFO_CONFIG, + POST_STEP, ) from overcooked_simulator.order import ( OrderManager, @@ -906,9 +907,11 @@ class Environment: if name in self.vector_state_generation.meals: idx = 3 + self.vector_state_generation.meals.index(name) elif name in self.vector_state_generation.ingredients: - idx = 3 + len( - self.vector_state_generation.meals - ) + self.vector_state_generation.ingredients.index(name) + idx = ( + 3 + + len(self.vector_state_generation.meals) + + self.vector_state_generation.ingredients.index(name) + ) else: raise ValueError(f"Unknown item {name} - {item}") array[idx] = 1.0 @@ -952,7 +955,7 @@ class Environment: return item_array - def get_vectorized_state( + def get_vectorized_state_full( self, player_id: str ) -> Tuple[ npt.NDArray[npt.NDArray[float]], @@ -1044,6 +1047,162 @@ class Environment: order_array, ) + # def setup_vectorization_simple(self) -> VectorStateGenerationDataSimple: + # num_per_item = 114 + # num_per_counter = 12 + # num_players = 4 + # grid_base_array = np.zeros( + # ( + # int(self.kitchen_width), + # int(self.kitchen_height), + # num_per_item + # + num_per_counter + # + num_players, # TODO calc based on item info + # ), + # dtype=np.float32, + # ) + # counter_list = [ + # "Counter", + # "CuttingBoard", + # "ServingWindow", + # "Trashcan", + # "Sink", + # "SinkAddon", + # "Stove", + # "DeepFryer", + # "Oven", + # ] + # grid_idxs = [ + # (x, y) + # for x in range(int(self.kitchen_width)) + # for y in range(int(self.kitchen_height)) + # ] + # # counters do not move + # for counter in self.counters: + # grid_idx = np.floor(counter.pos).astype(int) + # counter_name = ( + # counter.name + # if isinstance(counter, CookingCounter) + # else ( + # repr(counter) + # if isinstance(Counter, Dispenser) + # else counter.__class__.__name__ + # ) + # ) + # assert counter_name in counter_list or counter_name.endswith( + # "Dispenser" + # ), f"Unknown Counter {counter}" + # oh_idx = len(counter_list) + # if counter_name in counter_list: + # oh_idx = counter_list.index(counter_name) + # + # one_hot = [0] * (len(counter_list) + 2) + # one_hot[oh_idx] = 1 + # grid_base_array[ + # grid_idx[0], grid_idx[1], 4 : 4 + (len(counter_list) + 2) + # ] = np.array(one_hot, dtype=np.float32) + # + # grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1]))) + # + # for free_idx in grid_idxs: + # one_hot = [0] * (len(counter_list) + 2) + # one_hot[len(counter_list) + 1] = 1 + # grid_base_array[ + # free_idx[0], free_idx[1], 4 : 4 + (len(counter_list) + 2) + # ] = np.array(one_hot, dtype=np.float32) + # + # player_info_base_array = np.zeros( + # ( + # 4, + # 4 + 114, + # ), + # dtype=np.float32, + # ) + # order_base_array = np.zeros((10 * (8 + 1)), dtype=np.float32) + # + # return VectorStateGenerationData( + # grid_base_array=grid_base_array, + # oh_len=12, + # ) + + def get_vectorized_state_simple(self, player): + item_list = ["Pot", "Tomato", "ChoppedTomato", "Plate"] + counter_list = [ + "Counter", + "PlateDispenser", + "TomatoDispenser", + "ServingWindow", + "PlateReturn", + "Trashcan", + "Stove", + "CuttingBoard", + ] + player_pos = self.players[player].pos + player_dir = self.players[player].facing_direction + + grid_width, grid_height = int(self.kitchen_width), int(self.kitchen_height) + + counter_one_hot_length = len(counter_list) + 1 # one for empty field + grid_base_array = np.zeros( + ( + grid_width, + grid_height, + ), + dtype=int, + ) + + grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)] + + # counters do not move + for counter in self.counters: + grid_idx = np.floor(counter.pos).astype(int) + counter_name = ( + counter.name + if isinstance(counter, CookingCounter) + else ( + repr(counter) + if isinstance(Counter, Dispenser) + else counter.__class__.__name__ + ) + ) + if counter_name == "Dispenser": + counter_name = f"{counter.occupied_by.name}Dispenser" + assert counter_name in counter_list, f"Unknown Counter {counter}" + + counter_oh_idx = counter_one_hot_length + if counter_name in counter_list: + counter_oh_idx = counter_list.index(counter_name) + + grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx + grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1]))) + + for free_idx in grid_idxs: + grid_base_array[free_idx[0], free_idx[1]] = counter_one_hot_length - 1 + + counter_grid_one_hot = np.zeros( + (grid_width, grid_height, counter_one_hot_length), dtype=int + ) + for x in range(grid_width): + for y in range(grid_height): + counter_type_idx = grid_base_array[x, y] + counter_grid_one_hot[x, y, counter_type_idx] = 1 + + player_data = np.concatenate((player_pos, player_dir), axis=0) + + items_one_hot_length = len(item_list) + 1 + item_one_hot = np.zeros(items_one_hot_length, dtype=int) + player_item = self.players[player].holding + player_item_idx = items_one_hot_length - 1 + if player_item: + if player_item.name in item_list: + player_item_idx = item_list.index(player_item.name) + item_one_hot[player_item_idx] = 1 + + final = np.concatenate( + (counter_grid_one_hot.flatten(), player_data, item_one_hot), axis=0 + ) + return final + def reset_env_time(self): """Reset the env time to the initial time, defined by `create_init_env_time`.""" self.hook(PRE_RESET_ENV_TIME) diff --git a/overcooked_simulator/utils.py b/overcooked_simulator/utils.py index 357d2679..4ef8b868 100644 --- a/overcooked_simulator/utils.py +++ b/overcooked_simulator/utils.py @@ -65,6 +65,27 @@ class VectorStateGenerationData: ] +@dataclasses.dataclass +class VectorStateGenerationDataSimple: + grid_base_array: npt.NDArray[npt.NDArray[float]] + oh_len: int + + number_normal_ingredients = 1 + + meals = [ + "TomatoSoup", + ] + equipments = [ + "Pot", + "Plate", + "DirtyPlate", + "Extinguisher", + ] + ingredients = [ + "Tomato", + ] + + def create_init_env_time(): """Init time of the environment time, because all environments should have the same internal time.""" return datetime( -- GitLab