diff --git a/cooperative_cuisine/environment.py b/cooperative_cuisine/environment.py
index f95ba69eb1ce9666d5ec77fa5b92794ede5df1de..10b17bf8ff1e74c0535c4c763c1ea2746cb97646 100644
--- a/cooperative_cuisine/environment.py
+++ b/cooperative_cuisine/environment.py
@@ -125,6 +125,7 @@ class Environment:
         as_files: bool = True,
         env_name: str = "cooperative_cuisine_1",
         seed: int = 56789223842348,
+        yaml_already_loaded: bool = False
     ):
         """Constructor of the Environment.
@@ -163,10 +164,13 @@ class Environment:
                 layout_config = layout_file.read()
             with open(item_info, "r") as file:
                 item_info = file.read()
+        if not yaml_already_loaded:
+            self.environment_config: EnvironmentConfig = yaml.load(
+                env_config, Loader=yaml.Loader
+            )
+        else:
+            self.environment_config = env_config
-        self.environment_config: EnvironmentConfig = yaml.load(
-            env_config, Loader=yaml.Loader
-        )
         """The config of the environment. All environment specific attributes is configured here."""
         self.environment_config["player_config"] = PlayerConfig(
             **(
@@ -248,7 +252,8 @@ class Environment:
             )
         )
-        self.recipe_validation.update_plate_config(plate_config, self.environment_config["layout_chars"], self.layout_config)
+        self.recipe_validation.update_plate_config(plate_config, self.environment_config["layout_chars"],
+                                                   self.layout_config)
 
         self.counter_factory: CounterFactory = CounterFactory(
             layout_chars_config=self.environment_config["layout_chars"],
@@ -382,17 +387,17 @@ class Environment:
         Utility method to pass a reference to the serving window."""
         return self.env_time
 
-    def load_item_info(self, item_info: str) -> dict[str, ItemInfo]:
+    def load_item_info(self, item_info: str | dict[str, ItemInfo]) -> dict[str, ItemInfo]:
         """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos."""
         self.hook(ITEM_INFO_CONFIG, item_info_config=item_info)
-        item_lookup = yaml.safe_load(item_info)
-        for item_name in item_lookup:
-            item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name])
-
-        for item_name, item_info in item_lookup.items():
-            if item_info.equipment:
-                item_info.equipment = item_lookup[item_info.equipment]
-        return item_lookup
+        if isinstance(item_info, str):
+            item_info = yaml.safe_load(item_info)
+        for item_name in item_info:
+            item_info[item_name] = ItemInfo(name=item_name, **item_info[item_name])
+        for item_name, single_item_info in item_info.items():
+            if single_item_info.equipment:
+                single_item_info.equipment = item_info[single_item_info.equipment]
+        return item_info
 
     def perform_action(self, action: Action):
         """Performs an action of a player in the environment. Maps different types of action inputs to the
diff --git a/cooperative_cuisine/reinforcement_learning/config/additional_configs/additional_config_base.yaml b/cooperative_cuisine/reinforcement_learning/config/additional_configs/additional_config_base.yaml
index 67320454899cb01519bb2f3afe578c436ca0f3c4..6cde1c2df4c909f98a1c3b433804cf261aa7bf6b 100644
--- a/cooperative_cuisine/reinforcement_learning/config/additional_configs/additional_config_base.yaml
+++ b/cooperative_cuisine/reinforcement_learning/config/additional_configs/additional_config_base.yaml
@@ -1,4 +1,21 @@
-defaults:
-  - order_generator: random_order_generator
+order_generator: "random_orders.yaml" # Filename of the orders config that should be used.
+# The state converter given under state_converter needs to implement the abstract StateToObservationConverter class.
-state_converter: "base_converter_onehot"
\ No newline at end of file
+state_converter:
+  _target_: "cooperative_cuisine.reinforcement_learning.obs_converter.base_converter.BaseStateConverter"
+log_path: "logs/reinforcement_learning"
+checkpoint_path: "rl_agent_checkpoints"
+render_mode: "rgb_array"
+project_name: "overcooked"
+debug_mode: False
+vec_env: True
+sync_tensorboard: True # auto-upload sb3's tensorboard metrics
+monitor_gym: True
+video_save_path: "logs/reinforcement_learning/videos/"
+record_video_trigger: 200_000
+video_length: 300
+save_freq: 50_000
+save_path_callback: "logs"
+name_prefix_callback: "rl_model"
+save_replay_buffer: True
+save_vecnormalize: True
+progress_bar: True
\ No newline at end of file
diff --git a/cooperative_cuisine/reinforcement_learning/config/additional_configs/order_generator/determinstic_order_generator.yaml b/cooperative_cuisine/reinforcement_learning/config/additional_configs/order_generator/determinstic_order_generator.yaml
deleted file mode 100644
index 5dc39916f4b970a64de4d62e6b89f566965d99ce..0000000000000000000000000000000000000000
--- a/cooperative_cuisine/reinforcement_learning/config/additional_configs/order_generator/determinstic_order_generator.yaml
+++ /dev/null
@@ -1 +0,0 @@
-order_generator_type: "deterministic"
\ No newline at end of file
diff --git a/cooperative_cuisine/reinforcement_learning/config/additional_configs/order_generator/random_order_generator.yaml b/cooperative_cuisine/reinforcement_learning/config/additional_configs/order_generator/random_order_generator.yaml
deleted file mode 100644
index 89bc3d2e3c5fbc3e6e3b228f1f211644382cfe75..0000000000000000000000000000000000000000
--- a/cooperative_cuisine/reinforcement_learning/config/additional_configs/order_generator/random_order_generator.yaml
+++ /dev/null
@@ -1 +0,0 @@
-order_generator_type: "random"
\ No newline at end of file
diff --git a/cooperative_cuisine/reinforcement_learning/config/model/A2C.yaml b/cooperative_cuisine/reinforcement_learning/config/model/A2C.yaml
index 7e01b0fd6956523ad006c512bb9020ec72ff9fe2..d58419f728820e26851ffc14ade25dcc5f0e5173 100644
--- a/cooperative_cuisine/reinforcement_learning/config/model/A2C.yaml
+++ b/cooperative_cuisine/reinforcement_learning/config/model/A2C.yaml
@@ -1,6 +1,9 @@
 env_id: "overcooked"
 policy: "MlpPolicy"
-model_type: "A2C"
+model_name: "A2C"
+model_type:
+  _partial_: true
+  _target_: stable_baselines3.A2C
 total_timesteps: 3_000_000 # Hendric says more like 300_000_000 steps
 number_envs_parallel: 64
 learning_rate: 0.0007
diff --git a/cooperative_cuisine/reinforcement_learning/config/model/DQN.yaml b/cooperative_cuisine/reinforcement_learning/config/model/DQN.yaml
index 9fb00ed1d0ef51ed156b4f3d77c51327e6f95fb8..f7e2d4d4e088002050da5e9e2b04bae9ded911cf 100644
--- a/cooperative_cuisine/reinforcement_learning/config/model/DQN.yaml
+++ b/cooperative_cuisine/reinforcement_learning/config/model/DQN.yaml
@@ -1,6 +1,9 @@
 env_id: "overcooked"
 policy: "MlpPolicy"
-model_type: "DQN"
+model_name: "DQN"
+model_type:
+  _partial_: true
+  _target_: stable_baselines3.DQN
 total_timesteps: 3_000_000 # Hendric says more like 300_000_000 steps
 number_envs_parallel: 64
 learning_rate: 0.0001
diff --git a/cooperative_cuisine/reinforcement_learning/config/model/PPO.yaml b/cooperative_cuisine/reinforcement_learning/config/model/PPO.yaml
index f253155af0dd51b0f660dfb39e7f55a3053694c7..05a284489f7a9abab95c52c1ef1f96cb505eecf3 100644
--- a/cooperative_cuisine/reinforcement_learning/config/model/PPO.yaml
+++ b/cooperative_cuisine/reinforcement_learning/config/model/PPO.yaml
@@ -1,6 +1,9 @@
 env_id: "overcooked"
 policy: "MlpPolicy"
-model_type: "PPO"
+model_name: "PPO"
+model_type:
+  _partial_: true
+  _target_: stable_baselines3.PPO
 total_timesteps: 3_000_000 # Hendric says more like 300_000_000 steps
 number_envs_parallel: 64
 learning_rate: 0.0003
diff --git a/cooperative_cuisine/reinforcement_learning/gym_env.py b/cooperative_cuisine/reinforcement_learning/gym_env.py
index 1ca2828b74f2ec73c93cf020b432bac02462beb6..3fbaeab9bf9f130146c0cfcc6c89648ce5ea753b 100644
--- a/cooperative_cuisine/reinforcement_learning/gym_env.py
+++ b/cooperative_cuisine/reinforcement_learning/gym_env.py
@@ -20,6 +20,7 @@ from cooperative_cuisine.environment import (
     Environment,
 )
 from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer
+from hydra.utils import instantiate
 
 
 class SimpleActionSpace(Enum):
@@ -104,6 +105,11 @@ def shuffle_counters(env):
 
 
 class StateToObservationConverter:
+    '''
+    Abstract interface for converting the current environment state into
+    an observation for the RL agent. Implementations are selected via the
+    state_converter entry of the additional configs.
+    '''
     @abstractmethod
     def setup(self, env):
         ...
@@ -113,13 +119,6 @@ class StateToObservationConverter:
         ...
 
 
-def get_converter(converter_name):
-    module_path = f"cooperative_cuisine.reinforcement_learning.obs_converter.{converter_name}"
-    module = importlib.import_module(module_path)
-    converter_class = getattr(module, "StateConverter")
-    return converter_class()
-
-
 class EnvGymWrapper(Env):
     """Should enable this:
     observation, reward, terminated, truncated, info = env.step(action)
@@ -135,12 +134,10 @@ class EnvGymWrapper(Env):
         self.full_vector_state = True
         config_env = OmegaConf.to_container(config.environment, resolve=True)
        config_item_info = OmegaConf.to_container(config.item_info, resolve=True)
-        order_generator = config.additional_configs.order_generator.order_generator_type
-        order_file = order_generator + "_orders.yaml"
-        custom_config_path = ROOT_DIR / "reinforcement_learning" / "config" / order_file
+        order_generator = config.additional_configs.order_generator
+        custom_config_path = ROOT_DIR / "reinforcement_learning" / "config" / order_generator
         with open(custom_config_path, "r") as file:
-            custom_classes = file.read()
-            custom_classes = yaml.load(custom_classes, Loader=yaml.Loader)
+            custom_classes = yaml.load(file, Loader=yaml.Loader)
         for key, value in config_env['hook_callbacks'].items():
             value['callback_class'] = custom_classes['callback_class']
         config_env["orders"]["order_gen_class"] = custom_classes['order_gen_class']
@@ -174,16 +171,12 @@ class EnvGymWrapper(Env):
         self.action_space = spaces.Discrete(len(self.action_space_map))
         self.seen_items = []
-        self.converter = get_converter(config.additional_configs.state_converter)
+        self.converter = instantiate(config.additional_configs.state_converter)
         self.converter.setup(self.env)
-
-        try:
+        if hasattr(self.converter, "onehot"):
             self.onehot_state = self.converter.onehot
-        except AttributeError:
-            if 'onehot' in config.additional_configs.state_converter.lower():
-                self.onehot_state = True
-            else:
-                self.onehot_state = False
+        else:
+            self.onehot_state = 'onehot' in config.additional_configs.state_converter["_target_"].lower()
 
         dummy_obs = self.get_observation()
         min_obs_val = -1 if not self.use_rgb_obs else 0
         max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
diff --git a/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter.py b/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter.py
index 0837969929c328df3def7da02e49ad5d92ffc5fd..703a7d7a27a4cede1ee3b488fdd75a7ddc5808c3 100644
--- a/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter.py
+++ b/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter.py
@@ -7,7 +7,7 @@ from cooperative_cuisine.items import CookingEquipment
 from cooperative_cuisine.reinforcement_learning.gym_env import StateToObservationConverter
 
 
-class StateConverter(StateToObservationConverter):
+class BaseStateConverter(StateToObservationConverter):
     def __init__(self):
         self.onehot = False
         self.counter_list = [
diff --git a/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter_onehot.py b/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter_onehot.py
index bbbe5a73ff6e20e57aeb394fc87d9c099d7d2a63..d3a7d877db6cc4786c948fc522b156dba19eb00b 100644
--- a/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter_onehot.py
+++ b/cooperative_cuisine/reinforcement_learning/obs_converter/base_converter_onehot.py
@@ -7,7 +7,7 @@ from cooperative_cuisine.items import CookingEquipment
 from cooperative_cuisine.reinforcement_learning.gym_env import StateToObservationConverter
 
 
-class StateConverter(StateToObservationConverter):
+class BaseStateConverterOnehot(StateToObservationConverter):
     def __init__(self):
         self.onehot = True
         self.grid_height = None
diff --git a/cooperative_cuisine/reinforcement_learning/run_single_agent.py b/cooperative_cuisine/reinforcement_learning/run_single_agent.py
index c25e7265f1f79ac370c70ed0edbdc07d08744ab0..d69068eb69bc025673d96ab666a65d109cd73230 100644
--- a/cooperative_cuisine/reinforcement_learning/run_single_agent.py
+++ b/cooperative_cuisine/reinforcement_learning/run_single_agent.py
@@ -1,19 +1,20 @@
 import time
 
 import cv2
-from stable_baselines3 import DQN
+from stable_baselines3 import DQN, A2C, PPO
 
 from gym_env import EnvGymWrapper
 
 import hydra
 from omegaconf import DictConfig, OmegaConf
-
-
-
+from hydra.utils import instantiate
 
 
 @hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
 def main(cfg: DictConfig):
+    additional_config = OmegaConf.to_container(cfg.additional_configs, resolve=True)
     model_save_path = "logs/reinforcement_learning/rl_agent_checkpoints/overcooked_DQN.zip"
-    model_class = DQN
+    model_save_path = additional_config["log_path"] + "/" + additional_config["checkpoint_path"] + "/" + additional_config[
+        "project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"]
+    # model_type is a _partial_ target, so instantiate() returns a functools.partial; .func is the underlying SB3 class
+    model_class = instantiate(cfg.model.model_type).func
 
     model = model_class.load(model_save_path)
 
     env = EnvGymWrapper(cfg)
diff --git a/cooperative_cuisine/reinforcement_learning/train_single_agent.py b/cooperative_cuisine/reinforcement_learning/train_single_agent.py
index dd48d5c1481f6d152ef34137da44c9d57bd8bec8..a06f65213170ce87df2f432fa327cbedf52f4411 100644
--- a/cooperative_cuisine/reinforcement_learning/train_single_agent.py
+++ b/cooperative_cuisine/reinforcement_learning/train_single_agent.py
@@ -12,48 +12,49 @@ from wandb.integration.sb3 import WandbCallback
 
 from gym_env import EnvGymWrapper
 
 import hydra
+from hydra.utils import instantiate
 
 
 @hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
 def main(cfg: DictConfig):
-    rl_logs = Path("logs/reinforcement_learning")
+    additional_configs = OmegaConf.to_container(cfg.additional_configs, resolve=True)
+    rl_logs = Path(additional_configs["log_path"])
     rl_logs.mkdir(exist_ok=True)
Path("logs/reinforcement_learning/rl_agent_checkpoints") + rl_agent_checkpoints = rl_logs / Path(additional_configs["checkpoint_path"]) rl_agent_checkpoints.mkdir(exist_ok=True) config = OmegaConf.to_container(cfg.model, resolve=True) - debug = False - vec_env = True - models = {"A2C": A2C, "DQN": DQN, "PPO": PPO} + debug = additional_configs["debug_mode"] + vec_env = additional_configs["vec_env"] number_envs_parallel = config["number_envs_parallel"] - model_class = models[config["model_type"]] + model_class = instantiate(cfg.model.model_type) if vec_env: env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel) else: env = EnvGymWrapper(cfg) - env.render_mode = "rgb_array" + env.render_mode = additional_configs["render_mode"] if not debug: # also upload the environment config to W&B and all stable baselines3 hyperparams run = wandb.init( - project="overcooked", + project=additional_configs["project_name"], config=config, - sync_tensorboard=True, # auto-upload sb3's tensorboard metrics - monitor_gym=True, - dir="logs/reinforcement_learning" + sync_tensorboard=additional_configs["sync_tensorboard"], # auto-upload sb3's tensorboard metrics + monitor_gym=additional_configs["monitor_gym"], + dir= additional_configs["log_path"] # save_code=True, # optional ) env = VecVideoRecorder( env, - f"logs/reinforcement_learning/videos/{run.id}", - record_video_trigger=lambda x: x % 200_000 == 0, - video_length=300, + additional_configs["video_save_path"] + run.id, + record_video_trigger=lambda x: x % additional_configs["record_video_trigger"] == 0, + video_length= additional_configs["video_length"], ) - model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}" - + model_save_name = additional_configs["project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"] + model_save_path = rl_agent_checkpoints / model_save_name filtered_config = {k: v for k, v in config.items() if - k not in ["env_id", "policy_type", "model_type", "total_timesteps", "number_envs_parallel"] and v != 'None'} + k not in ["env_id", "policy_type", "model_name", "model_type", "total_timesteps", "number_envs_parallel"] and v != 'None'} model = model_class( env=env, **filtered_config @@ -62,18 +63,19 @@ def main(cfg: DictConfig): model.learn( total_timesteps=config["total_timesteps"], log_interval=1, - progress_bar=True, + progress_bar=additional_configs["progress_bar"], ) else: + checkpoint_callback = CheckpointCallback( - save_freq=50_000, - save_path="logs", - name_prefix="rl_model", - save_replay_buffer=True, - save_vecnormalize=True, + save_freq=additional_configs["save_freq"], + save_path=additional_configs["save_path_callback"], + name_prefix=additional_configs["name_prefix_callback"], + save_replay_buffer=additional_configs["save_replay_buffer"], + save_vecnormalize=additional_configs["save_vecnormalize"], ) wandb_callback = WandbCallback( - model_save_path=f"logs/reinforcement_learning/models/{run.id}", + model_save_path=additional_configs["video_save_path"] + run.id, verbose=0, ) @@ -82,7 +84,7 @@ def main(cfg: DictConfig): total_timesteps=config["total_timesteps"], callback=callback, log_interval=1, - progress_bar=True, + progress_bar=additional_configs["progress_bar"], ) run.finish() model.save(model_save_path)