import json
import random
import time
import uuid
from collections import deque
from datetime import timedelta
from enum import Enum
from pathlib import Path

import cv2
import numpy as np
import wandb
import yaml
from gymnasium import spaces, Env
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.counters import Counter, CookingCounter, Dispenser
from cooperative_cuisine.environment import Environment
from cooperative_cuisine.items import CookingEquipment
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer


class SimpleActionSpace(Enum):
    """Enumeration of actions for simple action spaces for an RL agent."""

    Up = "Up"
    Down = "Down"
    Left = "Left"
    Right = "Right"
    Interact = "Interact"
    Put = "Put"


def get_env_action(player_id, simple_action, duration):
    """Translate a discrete SimpleActionSpace value into an environment Action."""
    match simple_action:
        case SimpleActionSpace.Up:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([0, -1]),
                duration,
            )
        case SimpleActionSpace.Left:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([-1, 0]),
                duration,
            )
        case SimpleActionSpace.Down:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([0, 1]),
                duration,
            )
        case SimpleActionSpace.Right:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([1, 0]),
                duration,
            )
        case SimpleActionSpace.Put:
            return Action(
                player_id,
                ActionType.PICK_UP_DROP,
                InterActionData.START,
                duration,
            )
        case SimpleActionSpace.Interact:
            return Action(
                player_id,
                ActionType.INTERACT,
                InterActionData.START,
                duration,
            )
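

# Example (illustrative only): map the discrete action "Up" for player "0"
# into a one-second movement Action, which the wrapper below passes to
# Environment.perform_action:
#
#   action = get_env_action("0", SimpleActionSpace.Up, 1)
#   env.perform_action(action)
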

environment_config_path = (
    ROOT_DIR / "reinforcement_learning" / "environment_config_rl.yaml"
)
layout_path: Path = ROOT_DIR / "reinforcement_learning" / "rl_small.layout"
item_info_path = ROOT_DIR / "reinforcement_learning" / "item_info_rl.yaml"

with open(item_info_path, "r") as file:
    item_info = file.read()
with open(layout_path, "r") as file:
    layout = file.read()
with open(environment_config_path, "r") as file:
    environment_config = file.read()
with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r") as file:
    visualization_config = yaml.safe_load(file)

visualizer: Visualizer = Visualizer(config=visualization_config)
visualizer.create_player_colors(1)
visualizer.set_grid_size(40)


def shuffle_counters(env):
    """Randomly permute the positions of all special (non-plain) counters."""
    sample_counter = []
    other_counters = []
    for counter in env.counters:
        if counter.__class__ != Counter:
            sample_counter.append(counter)
        else:
            other_counters.append(counter)
    new_counter_pos = [c.pos for c in sample_counter]
    random.shuffle(new_counter_pos)
    for counter, new_pos in zip(sample_counter, new_counter_pos):
        counter.pos = new_pos
    sample_counter.extend(other_counters)
    env.overwrite_counters(sample_counter)


class EnvGymWrapper(Env):
    """Gymnasium wrapper around the cooperative cuisine environment.

    Enables the standard Gymnasium interface:
    observation, reward, terminated, truncated, info = env.step(action)
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

    def __init__(self):
        super().__init__()

        self.randomize_counter_placement = False
        self.use_rgb_obs = True  # if False, uses a simple vectorized state
        self.full_vector_state = True
        self.onehot_state = True

        self.env: Environment = Environment(
            env_config=environment_config,
            layout_config=layout,
            item_info=item_info,
            as_files=False,
            env_name=uuid.uuid4().hex,
        )
        if self.randomize_counter_placement:
            shuffle_counters(self.env)

        self.player_name = str(0)
        self.env.add_player(self.player_name)
        self.player_id = list(self.env.players.keys())[0]

        self.visualizer = visualizer

        self.action_space_map = {}
        for idx, item in enumerate(SimpleActionSpace):
            self.action_space_map[idx] = item

        self.global_step_time = 1
        self.in_between_steps = 1

        self.action_space = spaces.Discrete(len(self.action_space_map))
        self.seen_items = []

        dummy_obs = self.get_observation()
        min_obs_val = 0 if self.use_rgb_obs else -1
        max_obs_val = 255 if self.use_rgb_obs else (1 if self.onehot_state else 20)
        self.observation_space = spaces.Box(
            low=min_obs_val,
            high=max_obs_val,
            shape=dummy_obs.shape,
            dtype=np.uint8 if self.use_rgb_obs else int,
        )

        self.last_obs = dummy_obs
        self.step_counter = 0
        self.prev_score = 0

    def vectorize_item(self, item, item_list):
        item_one_hot = np.zeros(len(item_list))
        if item is None:
            item_name = "None"
        elif isinstance(item, deque):
            item_name = item[0].name if len(item) > 0 else "None"
        else:
            item_name = item.name

        if isinstance(item, CookingEquipment):
            if item.name == "Pot":
                if len(item.content_list) > 0:
                    if item.content_list[0].name == "TomatoSoup":
                        item_name = "PotDone"
                    elif len(item.content_list) == 1:
                        item_name = "PotOne"
                    elif len(item.content_list) == 2:
                        item_name = "PotTwo"
                    elif len(item.content_list) == 3:
                        item_name = "PotThree"
            if "Plate" in item.name:
                content_list = [i.name for i in item.content_list]
                match content_list:
                    case ["TomatoSoup"]:
                        item_name = "PlateTomatoSoup"
                    case ["ChoppedTomato"]:
                        item_name = "PlateChoppedTomato"
                    case ["ChoppedLettuce"]:
                        item_name = "PlateChoppedLettuce"
                    case []:
                        item_name = "Plate"
                    case ["ChoppedLettuce", "ChoppedTomato"]:
                        item_name = "PlateSalad"
                    case _:
                        assert False, f"Should not happen. {item}"

        assert item_name in item_list, f"Unknown item {item_name}."
        item_idx = item_list.index(item_name)
        item_one_hot[item_idx] = 1
        return item_one_hot, item_idx
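    # Example (illustrative): with the item_list used in
    # get_vectorized_state_simple below, a held "Tomato" maps to index 6,
    # so vectorize_item returns a one-hot vector with a single 1 at
    # position 6, together with that index.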
    @staticmethod
    def vectorize_counter(counter, counter_list):
        counter_name = (
            counter.name
            if isinstance(counter, CookingCounter)
            else (
                repr(counter)
                if isinstance(counter, Dispenser)
                else counter.__class__.__name__
            )
        )
        if counter_name == "Dispenser":
            counter_name = f"{counter.occupied_by.name}Dispenser"

        assert counter_name in counter_list, f"Unknown counter {counter}"
        counter_oh_idx = counter_list.index(counter_name)

        counter_one_hot = np.zeros(len(counter_list), dtype=int)
        counter_one_hot[counter_oh_idx] = 1
        return counter_one_hot, counter_oh_idx

    def get_vectorized_state_simple(self, player, onehot=True):
        counter_list = [
            "Empty",
            "Counter",
            "PlateDispenser",
            "TomatoDispenser",
            "ServingWindow",
            "PlateReturn",
            "Trashcan",
            "Stove",
            "CuttingBoard",
            "LettuceDispenser",
        ]
        item_list = [
            "None",
            "Pot",
            "PotOne",
            "PotTwo",
            "PotThree",
            "PotDone",
            "Tomato",
            "ChoppedTomato",
            "Plate",
            "PlateTomatoSoup",
            "PlateSalad",
            "Lettuce",
            "PlateChoppedTomato",
            "PlateChoppedLettuce",
            "ChoppedLettuce",
        ]

        grid_width, grid_height = int(self.env.kitchen_width), int(
            self.env.kitchen_height
        )

        grid_base_array = np.zeros((grid_width, grid_height), dtype=int)
        grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)]

        if onehot:
            counter_items = np.zeros(
                (grid_width, grid_height, len(item_list)), dtype=int
            )
            counters = np.zeros(
                (grid_width, grid_height, len(counter_list)), dtype=int
            )
        else:
            counter_items = np.zeros((grid_width, grid_height), dtype=int)
            counters = np.zeros((grid_width, grid_height), dtype=int)

        for counter in self.env.counters:
            grid_idx = np.floor(counter.pos).astype(int)
            counter_one_hot, counter_oh_idx = self.vectorize_counter(
                counter, counter_list
            )
            grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
            grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))

            counter_item_one_hot, counter_item_oh_idx = self.vectorize_item(
                counter.occupied_by, item_list
            )
            # Index the cell explicitly; indexing with the raw coordinate
            # array would select whole rows instead of a single grid cell.
            counter_items[grid_idx[0], grid_idx[1]] = (
                counter_item_one_hot if onehot else counter_item_oh_idx
            )
            counters[grid_idx[0], grid_idx[1]] = (
                counter_one_hot if onehot else counter_oh_idx
            )

        for free_idx in grid_idxs:
            grid_base_array[free_idx[0], free_idx[1]] = counter_list.index("Empty")

        player_pos = self.env.players[player].pos.astype(int)
        player_dir = self.env.players[player].facing_direction.astype(int)
        player_data = np.concatenate((player_pos, player_dir), axis=0)
        player_item_one_hot, player_item_idx = self.vectorize_item(
            self.env.players[player].holding, item_list
        )
        player_item = player_item_one_hot if onehot else [player_item_idx]

        final = np.concatenate(
            (
                counters.flatten(),
                counter_items.flatten(),
                player_data.flatten(),
                player_item,
            ),
            axis=0,
        )
        return final

    def step(self, action):
        simple_action = self.action_space_map[action]
        env_action = get_env_action(
            self.player_id, simple_action, self.global_step_time
        )
        self.env.perform_action(env_action)

        for _ in range(self.in_between_steps):
            self.env.step(
                timedelta(seconds=self.global_step_time / self.in_between_steps)
            )

        observation = self.get_observation()
        reward = self.env.score - self.prev_score
        self.prev_score = self.env.score
        if reward > 0.6:
            print("- - - - - - - - - - - - - - - - SCORED", reward)

        terminated = self.env.game_ended
        truncated = self.env.game_ended
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        # Drop the cached surfaces of the old environment before replacing it;
        # pop with a default avoids a KeyError if nothing was rendered yet.
        visualizer.surface_cache_dict.pop(self.env.env_name, None)
        self.env: Environment = Environment(
            env_config=environment_config,
            layout_config=layout,
            item_info=item_info,
            as_files=False,
            env_name=uuid.uuid4().hex,
        )
        if self.randomize_counter_placement:
            shuffle_counters(self.env)

        self.player_name = str(0)
        self.env.add_player(self.player_name)
        self.player_id = list(self.env.players.keys())[0]

        info = {}
        obs = self.get_observation()
        self.prev_score = 0
        return obs, info

    def get_observation(self):
        if self.use_rgb_obs:
            obs = self.get_env_img()
        else:
            obs = self.get_vector_state()
        return obs

    def render(self):
        return self.get_env_img()

    def close(self):
        pass

    def get_env_img(self):
        state = self.env.get_json_state(player_id=self.player_id)
        json_dict = json.loads(state)
        observation = self.visualizer.get_state_image(
            state=json_dict, env_id_ref=self.env.env_name
        ).astype(np.uint8)
        return observation

    def get_vector_state(self):
        obs = self.get_vectorized_state_simple("0", self.onehot_state)
        return obs

    def sample_random_action(self):
        act = self.action_space.sample()
        return act
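

# A minimal smoke test of the wrapper (a sketch, not part of the training
# flow below): check_env validates the Gymnasium interface, then one random
# action is stepped through the environment.
#
#   env = EnvGymWrapper()
#   check_env(env)
#   obs, info = env.reset()
#   obs, reward, terminated, truncated, info = env.step(
#       env.sample_random_action()
#   )
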

def main():
    rl_agent_checkpoints = Path("rl_agent_checkpoints")
    rl_agent_checkpoints.mkdir(exist_ok=True)

    config = {
        "policy_type": "MlpPolicy",
        "total_timesteps": 3_000_000,  # Hendric suggests more like 300,000,000 steps
        "env_id": "overcooked",
        "number_envs_parallel": 4,
    }

    debug = True
    do_training = True
    vec_env = True
    number_envs_parallel = config["number_envs_parallel"]

    model_classes = [A2C, DQN, PPO]
    model_class = model_classes[1]

    if vec_env:
        env = make_vec_env(EnvGymWrapper, n_envs=number_envs_parallel)
    else:
        env = EnvGymWrapper()
    env.render_mode = "rgb_array"

    if not debug:
        run = wandb.init(
            project="overcooked",
            config=config,
            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
            monitor_gym=True,
            # save_code=True,  # optional
        )
        env = VecVideoRecorder(
            env,
            f"videos/{run.id}",
            record_video_trigger=lambda x: x % 200_000 == 0,
            video_length=300,
        )

    model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"

    if do_training:
        model = model_class(
            config["policy_type"],
            env,
            verbose=1,
            tensorboard_log=f"runs/{0}",
            device="cpu",
            # n_steps=2048,
            # n_epochs=10,
        )
        if debug:
            model.learn(
                total_timesteps=config["total_timesteps"],
                log_interval=1,
                progress_bar=True,
            )
        else:
            checkpoint_callback = CheckpointCallback(
                save_freq=50_000,
                save_path="logs",
                name_prefix="rl_model",
                save_replay_buffer=True,
                save_vecnormalize=True,
            )
            wandb_callback = WandbCallback(
                model_save_path=f"models/{run.id}",
                verbose=0,
            )
            callback = CallbackList([checkpoint_callback, wandb_callback])
            model.learn(
                total_timesteps=config["total_timesteps"],
                callback=callback,
                log_interval=1,
                progress_bar=True,
            )
            run.finish()
        model.save(model_save_path)
        del model
        print("LEARNING DONE.")

    model = model_class.load(model_save_path)

    env = EnvGymWrapper()
    check_env(env)
    obs, info = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=False)
        obs, reward, terminated, truncated, info = env.step(int(action))
        print(reward)
        rgb_img = env.render()
        cv2.imshow("env", rgb_img)
        cv2.waitKey(1)  # non-blocking; waitKey(0) would stall until a keypress
        if terminated or truncated:
            obs, info = env.reset()
        time.sleep(1 / env.metadata["render_fps"])


if __name__ == "__main__":
    main()
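
# Optional (a sketch, not used above): when training a single environment
# without make_vec_env, SB3's Monitor wrapper records per-episode returns
# and lengths for logging.
#
#   from stable_baselines3.common.monitor import Monitor
#   env = Monitor(EnvGymWrapper())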