From 7d86dd706bc1b3e6f6583f374b66114295e6be54 Mon Sep 17 00:00:00 2001 From: Christoph Kowalski <christoph.kowalski@titus-research.eu> Date: Wed, 16 Oct 2024 11:03:16 +0200 Subject: [PATCH] Adapted docstring --- .../reinforcement_learning/gym_env.py | 15 +++------------ .../obs_converter/advanced_converter_array.py | 2 +- .../reinforcement_learning/train_single_agent.py | 1 - 3 files changed, 4 insertions(+), 14 deletions(-) diff --git a/cooperative_cuisine/reinforcement_learning/gym_env.py b/cooperative_cuisine/reinforcement_learning/gym_env.py index 401654f2..3c77b6d7 100644 --- a/cooperative_cuisine/reinforcement_learning/gym_env.py +++ b/cooperative_cuisine/reinforcement_learning/gym_env.py @@ -36,7 +36,6 @@ class SimpleActionSpace(Enum): def get_env_action(player_id, simple_action, duration): - """ Args: @@ -103,7 +102,6 @@ visualizer.set_grid_size(40) def shuffle_counters(env): - """ Shuffles the counters of an environment Args: @@ -160,7 +158,8 @@ class EnvGymWrapper(Env): config_env = OmegaConf.to_container(config.environment, resolve=True) config_item_info = OmegaConf.to_container(config.item_info, resolve=True) for val in config_env['hook_callbacks']: - config_env['hook_callbacks'][val]["callback_class"] = instantiate(config_env['hook_callbacks'][val]["callback_class"]) + config_env['hook_callbacks'][val]["callback_class"] = instantiate( + config_env['hook_callbacks'][val]["callback_class"]) config_env["orders"]["order_gen_class"] = instantiate(config_env["orders"]["order_generator"]) self.config_env = config_env self.config_item_info = config_item_info @@ -223,6 +222,7 @@ class EnvGymWrapper(Env): and additional information """ # this is simply a work-around to enable no action which is necessary for the play_gym.py + # but not for the rl agent if action == 8: observation = self.get_observation() reward = self.env.score - self.prev_score @@ -240,19 +240,14 @@ class EnvGymWrapper(Env): self.env.step( timedelta(seconds=self.global_step_time / self.in_between_steps) ) - observation = self.get_observation() - reward = self.env.score - self.prev_score self.prev_score = self.env.score - if reward > 0: print("- - - - - - - - - - - - - - - - SCORED", reward) - terminated = self.env.game_ended truncated = self.env.game_ended info = {} - return observation, reward, terminated, truncated, info def reset(self, seed=None, options=None): @@ -272,16 +267,12 @@ class EnvGymWrapper(Env): if self.randomize_counter_placement: shuffle_counters(self.env) - self.player_name = str(0) self.env.add_player(self.player_name) self.player_id = list(self.env.players.keys())[0] - info = {} obs = self.get_observation() - self.prev_score = 0 - return obs, info def get_observation(self): diff --git a/cooperative_cuisine/reinforcement_learning/obs_converter/advanced_converter_array.py b/cooperative_cuisine/reinforcement_learning/obs_converter/advanced_converter_array.py index be7dbcad..e4bb5f65 100644 --- a/cooperative_cuisine/reinforcement_learning/obs_converter/advanced_converter_array.py +++ b/cooperative_cuisine/reinforcement_learning/obs_converter/advanced_converter_array.py @@ -52,7 +52,7 @@ class AdvancedStateConverterArray(StateToObservationConverter): setup is chosen. - Returns: An encoding for the environment state that is not onehot + Returns: An encoding for the environment state """ if player is not None: diff --git a/cooperative_cuisine/reinforcement_learning/train_single_agent.py b/cooperative_cuisine/reinforcement_learning/train_single_agent.py index d78e0f74..0b9647be 100644 --- a/cooperative_cuisine/reinforcement_learning/train_single_agent.py +++ b/cooperative_cuisine/reinforcement_learning/train_single_agent.py @@ -27,7 +27,6 @@ def main(cfg: DictConfig): config: dict[str, Any] = OmegaConf.to_container(cfg.model, resolve=True) env_info: dict[str, Any] = OmegaConf.to_container(cfg.environment, resolve=True) debug: bool = additional_configs["debug_mode"] - vec_env = additional_configs["vec_env"] number_envs_parallel = config["number_envs_parallel"] model_class = instantiate(cfg.model.model_type) -- GitLab