# gym_env.py (~9.6 KiB) — Gymnasium wrapper for the cooperative_cuisine environment.
# NOTE(review): this file was recovered from a web-rendered snippet page; the
# original page chrome ("Skip to content", revision hints) was removed here.
import json
import random
from abc import abstractmethod
from copy import deepcopy
from datetime import timedelta
from enum import Enum
from pathlib import Path

import cv2
import numpy as np
import yaml
from gymnasium import spaces, Env

from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.action import Action, ActionType, InterActionData
from cooperative_cuisine.environment import Environment
from cooperative_cuisine.pygame_2d_vis.drawing import CacheFlags, Visualizer

# NOTE(review): this module also references OmegaConf (omegaconf), instantiate
# (hydra.utils) and Counter (presumably cooperative_cuisine.counters) without
# importing them — those imports were likely lost when the snippet was
# extracted; restore them from the original repository file.
    
    
    
    class SimpleActionSpace(Enum):
    
        """Enumeration of actions for simple action spaces for an RL agent."""
    
    
        Up = "Up"
        Down = "Down"
        Left = "Left"
        Right = "Right"
        Interact = "Interact"
        Put = "Put"
    
    
    
    def get_env_action(player_id, simple_action, duration):
    
    
        Args:
            player_id: id of the player
            simple_action: an action in the form of a SimpleActionSpace
            duration: for how long an action should be conducted
    
        Returns: a concrete action
    
    
        match simple_action:
            case SimpleActionSpace.Up:
                return Action(
                    player_id,
                    ActionType.MOVEMENT,
                    np.array([0, -1]),
                    duration,
                )
    
            case SimpleActionSpace.Left:
                return Action(
                    player_id,
                    ActionType.MOVEMENT,
                    np.array([-1, 0]),
                    duration,
                )
            case SimpleActionSpace.Down:
                return Action(
                    player_id,
                    ActionType.MOVEMENT,
                    np.array([0, 1]),
                    duration,
                )
            case SimpleActionSpace.Right:
                return Action(
                    player_id,
                    ActionType.MOVEMENT,
                    np.array([1, 0]),
                    duration,
                )
            case SimpleActionSpace.Put:
                return Action(
                    player_id,
    
                    ActionType.PICK_UP_DROP,
    
                    InterActionData.START,
                    duration,
                )
            case SimpleActionSpace.Interact:
                return Action(
                    player_id,
                    ActionType.INTERACT,
                    InterActionData.START,
                    duration,
                )
    
    
    
    with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r") as file:
    
        visualization_config = yaml.safe_load(file)
    
    
    visualizer: Visualizer = Visualizer(config=visualization_config)
    visualizer.create_player_colors(1)
    visualizer.set_grid_size(40)
    
    
    
    def shuffle_counters(env):
    
        """
        Shuffles the counters of an environment
    
        sample_counter = []
    
        for counter in env.counters:
            if counter.__class__ != Counter:
                sample_counter.append(counter)
    
                other_counters.append(counter)
    
        new_counter_pos = [c.pos for c in sample_counter]
        random.shuffle(new_counter_pos)
        for counter, new_pos in zip(sample_counter, new_counter_pos):
            counter.pos = new_pos
    
        sample_counter.extend(other_counters)
        env.overwrite_counters(sample_counter)
    
    class StateToObservationConverter:
    
        """
        Abstract definition of a class that gets and environment and outputs a state representation for rl
        """
    
        def setup(self, env, item_info):
    
            ...
    
        @abstractmethod
        def convert_state_to_observation(self, state) -> np.ndarray:
            ...
    
    
    
    class EnvGymWrapper(Env):
        """Should enable this:
        observation, reward, terminated, truncated, info = env.step(action)
        """
    
    
        metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}
    
            """
            Initializes all necessary variables
    
    
            Args:
                config:gets the rl and environment configuration from hydra
    
            super().__init__()
    
    
            self.randomize_counter_placement = False
    
            self.use_rgb_obs = False  # if False uses simple vectorized state
    
            self.full_vector_state = True
    
            config_env = OmegaConf.to_container(config.environment, resolve=True)
            config_item_info = OmegaConf.to_container(config.item_info, resolve=True)
    
            for val in config_env['hook_callbacks']:
                config_env['hook_callbacks'][val]["callback_class"] = instantiate(config_env['hook_callbacks'][val]["callback_class"])
            config_env["orders"]["order_gen_class"] = instantiate(config_env["orders"]["order_generator"])
    
            self.config_env = config_env
            self.config_item_info = config_item_info
    
            layout_file = config_env["layout_name"]
    
            layout_path: Path = ROOT_DIR / layout_file
    
            with open(layout_path, "r") as file:
                self.layout = file.read()
    
            self.env: Environment = Environment(
    
                env_config=deepcopy(config_env),
    
                layout_config=self.layout,
    
                item_info=deepcopy(config_item_info),
    
                as_files=False,
    
                yaml_already_loaded=True,
    
            )
    
            if self.randomize_counter_placement:
                shuffle_counters(self.env)
    
            self.player_name = str(0)
            self.env.add_player(self.player_name)
            self.player_id = list(self.env.players.keys())[0]
    
    
            self.visualizer = visualizer
    
            # self.action_space = {idx: value for idx, value in enumerate(SimpleActionSpace)}
            self.action_space_map = {}
            for idx, item in enumerate(SimpleActionSpace):
                self.action_space_map[idx] = item
            self.global_step_time = 1
            self.in_between_steps = 1
    
            self.action_space = spaces.Discrete(len(self.action_space_map))
    
    
            self.seen_items = []
    
            self.converter = instantiate(config.additional_configs.state_converter)
    
            # self.converter.setup could also get the item info config in order to get all the possible items.
            self.converter.setup(self.env, self.config_item_info)
    
            if hasattr(self.converter, "onehot"):
    
            else:
                self.onehot_state = 'onehot' in config.additional_configs.state_converter.lower()
    
            dummy_obs = self.get_observation()
    
            min_obs_val = -1 if not self.use_rgb_obs else 0
            max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
    
            self.observation_space = spaces.Box(
                low=min_obs_val,
                high=max_obs_val,
                shape=dummy_obs.shape,
    
                dtype=np.uint8 if self.use_rgb_obs else int,
    
            )
    
            self.last_obs = dummy_obs
            self.step_counter = 0
            self.prev_score = 0
    
        def step(self, action):
    
            """
            takes one step in the environment and returns the observation, reward, info whether terminated, truncated
            and additional information
            """
    
            # this is simply a work-around to enable no action which is necessary for the play_gym.py
            if action == 8:
                observation = self.get_observation()
                reward = self.env.score - self.prev_score
                terminated = self.env.game_ended
                truncated = self.env.game_ended
                info = {}
                return observation, reward, terminated, truncated, info
    
            simple_action = self.action_space_map[action]
            env_action = get_env_action(
                self.player_id, simple_action, self.global_step_time
            )
            self.env.perform_action(env_action)
    
            for i in range(self.in_between_steps):
                self.env.step(
                    timedelta(seconds=self.global_step_time / self.in_between_steps)
                )
    
            observation = self.get_observation()
    
            reward = self.env.score - self.prev_score
            self.prev_score = self.env.score
    
    
            if reward > 0.6:
                print("- - - - - - - - - - - - - - - - SCORED", reward)
    
    
            terminated = self.env.game_ended
            truncated = self.env.game_ended
            info = {}
    
            return observation, reward, terminated, truncated, info
    
        def reset(self, seed=None, options=None):
    
            """
            Resets the environment according to the configs
            """
    
            if self.env.env_name in visualizer.surface_cache_dict:
                del visualizer.surface_cache_dict[self.env.env_name]
    
            self.env: Environment = Environment(
    
                layout_config=self.layout,
    
                item_info=deepcopy(self.config_item_info),
    
                as_files=False,
    
                yaml_already_loaded=True,
    
            )
    
            if self.randomize_counter_placement:
                shuffle_counters(self.env)
    
            self.player_name = str(0)
            self.env.add_player(self.player_name)
            self.player_id = list(self.env.players.keys())[0]
    
            info = {}
            obs = self.get_observation()
    
            self.prev_score = 0
    
            return obs, info
    
        def get_observation(self):
            if self.use_rgb_obs:
    
                obs = self.get_env_img()
    
            else:
                obs = self.get_vector_state()
            return obs
    
        def render(self):
    
    
        def close(self):
            pass
    
    
        def get_env_img(self):
    
            state = self.env.get_json_state(player_id=self.player_id)
            json_dict = json.loads(state)
    
            observation = self.visualizer.get_state_image(state=json_dict, env_id_ref=self.env.env_name).astype(np.uint8)
            return observation
    
    
        def get_vector_state(self):
    
            obs = self.converter.convert_state_to_observation(self.env)
    
            return obs
    
        def sample_random_action(self):
            act = self.action_space.sample()
            return act