    import json
    import random
    import time
    
    from collections import deque
    from datetime import timedelta
    from enum import Enum
    from pathlib import Path
    
    
    import cv2
    import numpy as np
    import wandb
    import yaml
    from gymnasium import spaces, Env
    from stable_baselines3 import A2C
    from stable_baselines3 import DQN
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
    from stable_baselines3.common.env_checker import check_env
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import VecVideoRecorder
    from wandb.integration.sb3 import WandbCallback
    
    
    from cooperative_cuisine import ROOT_DIR
    
    from cooperative_cuisine.action import ActionType, InterActionData, Action
    
    from cooperative_cuisine.counters import Counter, CookingCounter, Dispenser
    from cooperative_cuisine.environment import Environment
    
    from cooperative_cuisine.items import CookingEquipment
    
    from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer, CacheFlags
    
    
    
    class SimpleActionSpace(Enum):
    
        """Enumeration of actions for simple action spaces for an RL agent."""
    
    
        Up = "Up"
        Down = "Down"
        Left = "Left"
        Right = "Right"
        Interact = "Interact"
        Put = "Put"
    
    
    
    def get_env_action(player_id, simple_action, duration):
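        """Translate a SimpleActionSpace value into the corresponding
        cooperative_cuisine Action for the given player and duration."""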
        match simple_action:
            case SimpleActionSpace.Up:
                return Action(
                    player_id,
                    ActionType.MOVEMENT,
                    np.array([0, -1]),
                    duration,
                )
    
            case SimpleActionSpace.Left:
                return Action(
                    player_id,
                    ActionType.MOVEMENT,
                    np.array([-1, 0]),
                    duration,
                )
            case SimpleActionSpace.Down:
                return Action(
                    player_id,
                    ActionType.MOVEMENT,
                    np.array([0, 1]),
                    duration,
                )
            case SimpleActionSpace.Right:
                return Action(
                    player_id,
                    ActionType.MOVEMENT,
                    np.array([1, 0]),
                    duration,
                )
            case SimpleActionSpace.Put:
                return Action(
                    player_id,
                    ActionType.PICK_UP_DROP,
                    InterActionData.START,
                    duration,
                )
            case SimpleActionSpace.Interact:
                return Action(
                    player_id,
                    ActionType.INTERACT,
                    InterActionData.START,
                    duration,
                )
    
    
    environment_config_path = (
        ROOT_DIR / "reinforcement_learning" / "environment_config_rl.yaml"
    )
    layout_path: Path = ROOT_DIR / "reinforcement_learning" / "rl_small.layout"
    item_info_path = ROOT_DIR / "reinforcement_learning" / "item_info_rl.yaml"
    with open(item_info_path, "r") as file:
        item_info = file.read()
    with open(layout_path, "r") as file:
        layout = file.read()
    with open(environment_config_path, "r") as file:
        environment_config = file.read()
    
    with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r") as file:
        visualization_config = yaml.safe_load(file)
    
    
    visualizer: Visualizer = Visualizer(config=visualization_config)
    visualizer.create_player_colors(1)
    visualizer.set_grid_size(40)
    
    
    
    def shuffle_counters(env):
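        """Randomly permute the positions of all special counters (dispensers,
        stoves, cutting boards, ...) while keeping plain counters in place."""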
        sample_counter = []
        other_counters = []

        for counter in env.counters:
            if counter.__class__ != Counter:
                sample_counter.append(counter)
            else:
                other_counters.append(counter)

        new_counter_pos = [c.pos for c in sample_counter]
        random.shuffle(new_counter_pos)
        for counter, new_pos in zip(sample_counter, new_counter_pos):
            counter.pos = new_pos

        sample_counter.extend(other_counters)
        env.overwrite_counters(sample_counter)
    
    
    
    class EnvGymWrapper(Env):
        """Should enable this:
        observation, reward, terminated, truncated, info = env.step(action)
        """
    
    
        metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}
    
    
        def __init__(self):
            super().__init__()
    
    
    
            self.randomize_counter_placement = False
    
            self.use_rgb_obs = True  # if False uses simple vectorized state
    
            self.full_vector_state = True
    
            self.onehot_state = True
    
    
            self.env: Environment = Environment(
                env_config=environment_config,
                layout_config=layout,
                item_info=item_info,
                as_files=False,
    
            )
    
            if self.randomize_counter_placement:
                shuffle_counters(self.env)
    
            self.player_name = str(0)
            self.env.add_player(self.player_name)
            self.player_id = list(self.env.players.keys())[0]
    
    
            self.visualizer = visualizer
    
            # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)}
            self.action_space_map = {}
            for idx, item in enumerate(SimpleActionSpace):
                self.action_space_map[idx] = item
            self.global_step_time = 1
            self.in_between_steps = 1
    
            self.action_space = spaces.Discrete(len(self.action_space_map))
    
    
            dummy_obs = self.get_observation()
    
            min_obs_val = -1 if not self.use_rgb_obs else 0
            max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
    
            self.observation_space = spaces.Box(
                low=min_obs_val,
                high=max_obs_val,
                shape=dummy_obs.shape,
    
                dtype=np.uint8 if self.use_rgb_obs else int,
    
            )
    
            self.last_obs = dummy_obs
    
            self.step_counter = 0
            self.prev_score = 0
    
    
        def vectorize_item(self, item, item_list):
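            """Return a one-hot vector and index of `item` within `item_list`,
            mapping pots and plates to content-specific names (e.g. "PotDone",
            "PlateSalad")."""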
            item_one_hot = np.zeros(len(item_list))
            if item is None:
                item_name = "None"
            elif isinstance(item, deque):
                if len(item) > 0:
                    item_name = item[0].name
                else:
                    item_name = "None"
            else:
                item_name = item.name
    
            if isinstance(item, CookingEquipment):
                if item.name == "Pot":
                    if len(item.content_list) > 0:
                        if item.content_list[0].name == "TomatoSoup":
                            item_name = "PotDone"
                        elif len(item.content_list) == 1:
                            item_name = "PotOne"
                        elif len(item.content_list) == 2:
                            item_name = "PotTwo"
                        elif len(item.content_list) == 3:
                            item_name = "PotThree"
    
                if "Plate" in item.name:
                    content_list = [i.name for i in item.content_list]
                    match content_list:
                        case ["TomatoSoup"]:
                            item_name = "PlateTomatoSoup"
                        case ["ChoppedTomato"]:
                            item_name = "PlateChoppedTomato"
                        case ["ChoppedLettuce"]:
                            item_name = "PlateChoppedLettuce"
                        case []:
                            item_name = "Plate"
                        case ["ChoppedLettuce", "ChoppedTomato"]:
                            item_name = "PlateSalad"
                        case other:
                            assert False, f"Should not happen. {item}"
    
            assert item_name in item_list, f"Unknown item {item_name}."
            item_idx = item_list.index(item_name)
            item_one_hot[item_idx] = 1
    
    
            # if item_name not in self.seen_items:
            #     print(item, item_name)
            #     self.seen_items.append(item_name)
    
    
            return item_one_hot, item_idx
    
        @staticmethod
        def vectorize_counter(counter, counter_list):
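            """Return a one-hot vector and index of the counter's type within
            `counter_list`, resolving dispensers to item-specific names such as
            "TomatoDispenser"."""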
            counter_name = (
                counter.name
                if isinstance(counter, CookingCounter)
                else (
                    repr(counter)
                    if isinstance(counter, Dispenser)
                    else counter.__class__.__name__
                )
            )
            if counter_name == "Dispenser":
                counter_name = f"{counter.occupied_by.name}Dispenser"
            assert counter_name in counter_list, f"Unknown Counter {counter}"
    
            counter_oh_idx = counter_list.index("Empty")
            if counter_name in counter_list:
                counter_oh_idx = counter_list.index(counter_name)
    
            counter_one_hot = np.zeros(len(counter_list), dtype=int)
            counter_one_hot[counter_oh_idx] = 1
            return counter_one_hot, counter_oh_idx
    
    
        def get_vectorized_state_simple(self, player, onehot=True):
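            """Build a flat vector observation for `player`: per-cell counter
            types, per-cell counter contents, the player's position and facing
            direction, and the held item, one-hot encoded or as indices."""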
            counter_list = [
                "Empty",
                "Counter",
                "PlateDispenser",
                "TomatoDispenser",
                "ServingWindow",
                "PlateReturn",
                "Trashcan",
                "Stove",
                "CuttingBoard",
                "LettuceDispenser",
            ]
            item_list = [
                "None",
                "Pot",
                "PotOne",
                "PotTwo",
                "PotThree",
                "PotDone",
                "Tomato",
                "ChoppedTomato",
                "Plate",
                "PlateTomatoSoup",
                "PlateSalad",
                "Lettuce",
                "PlateChoppedTomato",
                "PlateChoppedLettuce",
                "ChoppedLettuce",
            ]
            grid_width, grid_height = int(self.env.kitchen_width), int(
                self.env.kitchen_height
            )
    
            grid_base_array = np.zeros(
                (
                    grid_width,
                    grid_height,
                ),
                dtype=int,
            )
            grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)]
    
            if onehot:
    
                item_one_hot_length = len(item_list)
                counter_items = np.zeros(
                    (grid_width, grid_height, item_one_hot_length), dtype=int
                )
                counter_one_hot_length = len(counter_list)
                counters = np.zeros(
                    (grid_width, grid_height, counter_one_hot_length), dtype=int
                )
    
            else:
    
                counter_items = np.zeros((grid_width, grid_height), dtype=int)
                counters = np.zeros((grid_width, grid_height), dtype=int)
    
    
            for counter in self.env.counters:
                grid_idx = np.floor(counter.pos).astype(int)
    
    
                counter_one_hot, counter_oh_idx = self.vectorize_counter(
                    counter, counter_list
                )
                grid_base_array[grid_idx[0], grid_idx[1]] = counter_oh_idx
    
                grid_idxs.remove((int(grid_idx[0]), int(grid_idx[1])))
    
    
                counter_item_one_hot, counter_item_oh_idx = self.vectorize_item(
                    counter.occupied_by, item_list
                )
                counter_items[grid_idx[0], grid_idx[1]] = (
                    counter_item_one_hot if onehot else counter_item_oh_idx
                )
                counters[grid_idx[0], grid_idx[1]] = (
                    counter_one_hot if onehot else counter_oh_idx
                )
    
    
            for free_idx in grid_idxs:
    
                grid_base_array[free_idx[0], free_idx[1]] = counter_list.index("Empty")
    
            player_pos = self.env.players[player].pos.astype(int)
            player_dir = self.env.players[player].facing_direction.astype(int)
            player_data = np.concatenate((player_pos, player_dir), axis=0)
    
            player_item_one_hot, player_item_idx = self.vectorize_item(
                self.env.players[player].holding, item_list
            )
            player_item = player_item_one_hot if onehot else [player_item_idx]

            final = np.concatenate(
                (
                    counters.flatten(),
                    counter_items.flatten(),
                    player_data.flatten(),
                    player_item,
                ),
                axis=0,
            )
            return final
    
    
        def step(self, action):
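            """Apply one discrete action, advance the simulation by the global
            step time, and return (observation, reward, terminated, truncated,
            info), where reward is the score gained during this step."""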
            simple_action = self.action_space_map[action]
            env_action = get_env_action(
                self.player_id, simple_action, self.global_step_time
            )
            self.env.perform_action(env_action)
    
            for i in range(self.in_between_steps):
                self.env.step(
                    timedelta(seconds=self.global_step_time / self.in_between_steps)
                )
    
            observation = self.get_observation()
    
            reward = self.env.score - self.prev_score
            self.prev_score = self.env.score
    
    
            if reward > 0.6:
                print("- - - - - - - - - - - - - - - - SCORED", reward)
    
    
            terminated = self.env.game_ended
            truncated = self.env.game_ended
            info = {}
    
            return observation, reward, terminated, truncated, info
    
        def reset(self, seed=None, options=None):
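            """Re-create the environment (optionally shuffling counters),
            re-add the player, and return the initial observation and info."""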
    
            # Drop cached rendering surfaces of the previous environment instance, if any.
            visualizer.surface_cache_dict.pop(self.env.env_name, None)
    
            self.env: Environment = Environment(
                env_config=environment_config,
                layout_config=layout,
                item_info=item_info,
                as_files=False,
    
            )
    
            if self.randomize_counter_placement:
                shuffle_counters(self.env)
    
            self.player_name = str(0)
            self.env.add_player(self.player_name)
            self.player_id = list(self.env.players.keys())[0]
    
            info = {}
            obs = self.get_observation()
    
            self.prev_score = 0
    
            return obs, info
    
        def get_observation(self):
            if self.use_rgb_obs:
    
                obs = self.get_env_img()
    
            else:
                obs = self.get_vector_state()
            return obs
    
        def render(self):
            # Return the current state as an RGB image (used with render_mode="rgb_array").
            return self.get_env_img()

        def close(self):
            pass
    
    
        def get_env_img(self):
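            """Render the current state via the shared Visualizer and return it
            as an RGB uint8 image array."""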
    
            state = self.env.get_json_state(player_id=self.player_id)
            json_dict = json.loads(state)
    
            observation = self.visualizer.get_state_image(
                state=json_dict, env_id_ref=self.env.env_name
            ).astype(np.uint8)
            return observation
    
    
        def get_vector_state(self):
    
            obs = self.get_vectorized_state_simple("0", self.onehot_state)
    
            return obs
    
        def sample_random_action(self):
            act = self.action_space.sample()
            return act
    
    
    def main():
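        """Train an RL agent (A2C, DQN or PPO) on EnvGymWrapper, optionally with
        wandb logging and checkpointing, then reload the saved model and roll it
        out with a live OpenCV view."""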
        rl_agent_checkpoints = Path("rl_agent_checkpoints")
        rl_agent_checkpoints.mkdir(exist_ok=True)
    
        config = {
            "policy_type": "MlpPolicy",
            "total_timesteps": 3_000_000,  # Hendric suggests more like 300_000_000 steps
            "env_id": "overcooked",
            "number_envs_parallel": 4,
        }

        do_training = True
        debug = False  # assumed default; when True, trains without wandb logging and checkpoints
        vec_env = True
        number_envs_parallel = config["number_envs_parallel"]
    
        model_classes = [A2C, DQN, PPO]
    
        model_class = model_classes[1]
    
    
        if vec_env:
            env = make_vec_env(EnvGymWrapper, n_envs=number_envs_parallel)
        else:
            env = EnvGymWrapper()
    
        env.render_mode = "rgb_array"
    
        if not debug:
            run = wandb.init(
                project="overcooked",
                config=config,
                sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
                monitor_gym=True,
                # save_code=True,  # optional
            )
    
            env = VecVideoRecorder(
                env,
                f"videos/{run.id}",
                record_video_trigger=lambda x: x % 200_000 == 0,
                video_length=300,
            )
    
        model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"
    
        if do_training:
            model = model_class(
                config["policy_type"],
                env,
                verbose=1,
                tensorboard_log=f"runs/{0}",
                device="cpu"
                # n_steps=2048,
                # n_epochs=10,
            )
            if debug:
                model.learn(
                    total_timesteps=config["total_timesteps"],
                    log_interval=1,
                    progress_bar=True,
                )
            else:
                checkpoint_callback = CheckpointCallback(
                    save_freq=50_000,
                    save_path="logs",
                    name_prefix="rl_model",
                    save_replay_buffer=True,
                    save_vecnormalize=True,
                )
                wandb_callback = WandbCallback(
                    model_save_path=f"models/{run.id}",
                    verbose=0,
                )
    
                callback = CallbackList([checkpoint_callback, wandb_callback])
                model.learn(
                    total_timesteps=config["total_timesteps"],
                    callback=callback,
                    log_interval=1,
                    progress_bar=True,
                )
                run.finish()
            model.save(model_save_path)
    
            del model
        print("LEARNING DONE.")
    
        model = model_class.load(model_save_path)
        env = EnvGymWrapper()
    
        check_env(env)
        obs, info = env.reset()
        while True:
            action, _states = model.predict(obs, deterministic=False)
            obs, reward, terminated, truncated, info = env.step(int(action))
            print(reward)
            rgb_img = env.render()
            cv2.imshow("env", rgb_img)
            cv2.waitKey(0)
            if terminated or truncated:
                obs, info = env.reset()
            time.sleep(1 / env.metadata["render_fps"])
    
    
    if __name__ == "__main__":
        main()