import json
import random
import uuid
from abc import abstractmethod
from copy import deepcopy
from datetime import timedelta
from enum import Enum
from pathlib import Path

import numpy as np
import yaml
from gymnasium import Env, spaces
from hydra.utils import instantiate
from omegaconf import OmegaConf

from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.action import Action, ActionType, InterActionData
from cooperative_cuisine.counters import Counter
from cooperative_cuisine.environment import Environment
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer


class SimpleActionSpace(Enum):
    """Enumeration of actions for simple action spaces for an RL agent."""

    Up = "Up"
    Down = "Down"
    Left = "Left"
    Right = "Right"
    Interact = "Interact"
    Put = "Put"


def get_env_action(player_id, simple_action, duration):
    """Convert a SimpleActionSpace member into a concrete environment action.

    Args:
        player_id: id of the player.
        simple_action: an action in the form of a SimpleActionSpace member.
        duration: for how long the action should be performed.

    Returns:
        A concrete Action for the environment.
    """
    match simple_action:
        case SimpleActionSpace.Up:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([0, -1]),
                duration,
            )
        case SimpleActionSpace.Left:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([-1, 0]),
                duration,
            )
        case SimpleActionSpace.Down:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([0, 1]),
                duration,
            )
        case SimpleActionSpace.Right:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([1, 0]),
                duration,
            )
        case SimpleActionSpace.Put:
            return Action(
                player_id,
                ActionType.PICK_UP_DROP,
                InterActionData.START,
                duration,
            )
        case SimpleActionSpace.Interact:
            return Action(
                player_id,
                ActionType.INTERACT,
                InterActionData.START,
                duration,
            )


# Module-level visualizer shared by all wrapper instances.
with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r") as file:
    visualization_config = yaml.safe_load(file)

visualizer: Visualizer = Visualizer(config=visualization_config)
visualizer.create_player_colors(1)
visualizer.set_grid_size(40)


def shuffle_counters(env):
    """Shuffle the counters of an environment.

    Only non-plain counters (everything that is not a bare Counter) swap
    positions among themselves; plain counters stay where they are.

    Args:
        env: the environment object.
    """
    sample_counter = []
    other_counters = []
    for counter in env.counters:
        if counter.__class__ != Counter:
            sample_counter.append(counter)
        else:
            other_counters.append(counter)
    new_counter_pos = [c.pos for c in sample_counter]
    random.shuffle(new_counter_pos)
    for counter, new_pos in zip(sample_counter, new_counter_pos):
        counter.pos = new_pos
    sample_counter.extend(other_counters)
    env.overwrite_counters(sample_counter)


class StateToObservationConverter:
    """Abstract definition of a class that gets an environment and outputs a
    state representation for RL."""

    @abstractmethod
    def setup(self, env, item_info):
        """Prepare the converter for a concrete environment and item config."""
        ...

    @abstractmethod
    def convert_state_to_observation(self, state) -> np.ndarray:
        """Convert the current environment state into an observation array."""
        ...
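
# A minimal sketch of a concrete converter, assuming players expose a `pos`
# attribute like the counters used above; the position-only encoding and the
# class name are illustrative, not the project's actual observation format.
class PlayerPositionConverter(StateToObservationConverter):
    """Toy converter that encodes only the first player's grid position."""

    def setup(self, env, item_info):
        # Nothing to precompute for this toy encoding; a real converter would
        # build an item vocabulary from item_info here.
        self.onehot = False  # read by EnvGymWrapper to bound the obs space

    def convert_state_to_observation(self, state) -> np.ndarray:
        # `state` is the Environment instance, as passed in by
        # EnvGymWrapper.get_vector_state below.
        player = list(state.players.values())[0]
        return np.array(player.pos, dtype=int)  # `pos` is an assumed attribute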
class EnvGymWrapper(Env):
    """Gymnasium wrapper around the Cooperative Cuisine environment, enabling
    the standard API:
    observation, reward, terminated, truncated, info = env.step(action)
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

    def __init__(self, config):
        """Initialize all necessary variables.

        Args:
            config: the RL and environment configuration from Hydra.
        """
        super().__init__()
        self.randomize_counter_placement = False
        self.use_rgb_obs = False  # if False, a simple vectorized state is used
        self.full_vector_state = True

        config_env = OmegaConf.to_container(config.environment, resolve=True)
        config_item_info = OmegaConf.to_container(config.item_info, resolve=True)
        # Instantiate the hook callback classes and the order generator that
        # the config references by target name.
        for val in config_env["hook_callbacks"]:
            config_env["hook_callbacks"][val]["callback_class"] = instantiate(
                config_env["hook_callbacks"][val]["callback_class"]
            )
        config_env["orders"]["order_gen_class"] = instantiate(
            config_env["orders"]["order_generator"]
        )
        self.config_env = config_env
        self.config_item_info = config_item_info

        layout_file = config_env["layout_name"]
        layout_path: Path = ROOT_DIR / layout_file
        with open(layout_path, "r") as file:
            self.layout = file.read()

        self.env: Environment = Environment(
            env_config=deepcopy(config_env),
            layout_config=self.layout,
            item_info=deepcopy(config_item_info),
            as_files=False,
            yaml_already_loaded=True,
            env_name=uuid.uuid4().hex,
        )

        if self.randomize_counter_placement:
            shuffle_counters(self.env)

        self.player_name = str(0)
        self.env.add_player(self.player_name)
        self.player_id = list(self.env.players.keys())[0]

        self.visualizer = visualizer

        self.action_space_map = {
            idx: action for idx, action in enumerate(SimpleActionSpace)
        }
        self.global_step_time = 1
        self.in_between_steps = 1

        self.action_space = spaces.Discrete(len(self.action_space_map))
        self.seen_items = []

        self.converter = instantiate(config.additional_configs.state_converter)
        # The converter setup could also receive the item info config in order
        # to know all possible items.
        self.converter.setup(self.env, self.config_item_info)
        if hasattr(self.converter, "onehot"):
            self.onehot_state = self.converter.onehot
        else:
            # Fall back to checking the configured converter name; str() keeps
            # this working whether the config node is a string or a mapping.
            self.onehot_state = (
                "onehot" in str(config.additional_configs.state_converter).lower()
            )

        dummy_obs = self.get_observation()
        min_obs_val = 0 if self.use_rgb_obs else -1
        max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
        self.observation_space = spaces.Box(
            low=min_obs_val,
            high=max_obs_val,
            shape=dummy_obs.shape,
            dtype=np.uint8 if self.use_rgb_obs else int,
        )
        self.last_obs = dummy_obs
        self.step_counter = 0
        self.prev_score = 0

    def step(self, action):
        """Take one step in the environment.

        Returns:
            The observation, reward, whether terminated, whether truncated,
            and additional info.
        """
        # Work-around to enable a no-op action, which play_gym.py needs.
        if action == 8:
            observation = self.get_observation()
            reward = self.env.score - self.prev_score
            terminated = self.env.game_ended
            truncated = self.env.game_ended
            info = {}
            return observation, reward, terminated, truncated, info

        simple_action = self.action_space_map[action]
        env_action = get_env_action(
            self.player_id, simple_action, self.global_step_time
        )
        self.env.perform_action(env_action)

        # Advance the simulation in sub-steps that add up to one global step.
        for _ in range(self.in_between_steps):
            self.env.step(
                timedelta(seconds=self.global_step_time / self.in_between_steps)
            )

        observation = self.get_observation()
        reward = self.env.score - self.prev_score
        self.prev_score = self.env.score
        if reward > 0.6:
            print("- - - - - - - - - - - - - - - - SCORED", reward)
        terminated = self.env.game_ended
        truncated = self.env.game_ended
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the environment according to the configs."""
        # Drop the cached rendering surfaces of the old environment instance.
        if self.env.env_name in visualizer.surface_cache_dict:
            del visualizer.surface_cache_dict[self.env.env_name]

        self.env: Environment = Environment(
            env_config=deepcopy(self.config_env),
            layout_config=self.layout,
            item_info=deepcopy(self.config_item_info),
            as_files=False,
            env_name=uuid.uuid4().hex,
            yaml_already_loaded=True,
        )

        if self.randomize_counter_placement:
            shuffle_counters(self.env)

        self.player_name = str(0)
        self.env.add_player(self.player_name)
        self.player_id = list(self.env.players.keys())[0]

        info = {}
        obs = self.get_observation()
        self.prev_score = 0
        return obs, info

    def get_observation(self):
        if self.use_rgb_obs:
            obs = self.get_env_img()
        else:
            obs = self.get_vector_state()
        return obs

    def render(self):
        return self.get_env_img()

    def close(self):
        pass

    def get_env_img(self):
        state = self.env.get_json_state(player_id=self.player_id)
        json_dict = json.loads(state)
        observation = self.visualizer.get_state_image(
            state=json_dict, env_id_ref=self.env.env_name
        ).astype(np.uint8)
        return observation

    def get_vector_state(self):
        obs = self.converter.convert_state_to_observation(self.env)
        return obs

    def sample_random_action(self):
        act = self.action_space.sample()
        return act
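
# A minimal usage sketch, assuming a Hydra config that provides the keys read
# in __init__ (`environment`, `item_info`, `additional_configs.state_converter`);
# the config path and name below are hypothetical placeholders.
if __name__ == "__main__":
    from hydra import compose, initialize

    with initialize(version_base=None, config_path="config"):  # hypothetical path
        config = compose(config_name="rl_config")  # hypothetical config name

    env = EnvGymWrapper(config)
    obs, info = env.reset()
    terminated = truncated = False
    while not (terminated or truncated):
        # Drive the wrapper with uniformly random actions.
        obs, reward, terminated, truncated, info = env.step(
            env.sample_random_action()
        )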