from pathlib import Path
from typing import Any

import wandb
from omegaconf import DictConfig, OmegaConf
from stable_baselines3 import A2C
from stable_baselines3 import DQN
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

from gym_env import EnvGymWrapper

import hydra
from hydra.utils import instantiate

# Entries of ``cfg.model`` that steer this script rather than the model
# constructor; they are removed before the rest is passed on as keyword args.
_NON_MODEL_KEYS = frozenset(
    {
        "env_id",
        "policy_type",
        "model_name",
        "model_type",
        "model_type_inference",
        "total_timesteps",
        "number_envs_parallel",
    }
)


@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig) -> None:
    """Train an agent from scratch and save the model to the configured path.

    All settings come from the Hydra config:

    * ``cfg.additional_configs`` -- log/checkpoint paths, debug and vec-env
      switches, W&B options, video-recording and checkpoint-callback options.
    * ``cfg.model`` -- ``model_type`` (instantiated via Hydra to obtain the
      model class), ``model_name``, ``total_timesteps``,
      ``number_envs_parallel`` and the remaining entries, which are forwarded
      to the model constructor as keyword arguments.
    * ``cfg.environment`` -- merged with the model config and uploaded to W&B.

    When ``debug_mode`` is true, training runs without W&B, video recording
    or callbacks. Otherwise a W&B run is started, the environment is wrapped
    in a ``VecVideoRecorder`` and checkpoint/W&B callbacks are attached.

    Args:
        cfg: The composed Hydra configuration. It is also passed unchanged to
            ``EnvGymWrapper``.

    Raises:
        KeyError: If a required key is missing from the converted configs.
    """
    additional_configs: dict[str, Any] = OmegaConf.to_container(
        cfg.additional_configs, resolve=True
    )

    # parents=True lets nested log/checkpoint paths be created in one go.
    rl_logs: Path = Path(additional_configs["log_path"])
    rl_logs.mkdir(parents=True, exist_ok=True)
    rl_agent_checkpoints: Path = rl_logs / Path(additional_configs["checkpoint_path"])
    rl_agent_checkpoints.mkdir(parents=True, exist_ok=True)

    config: dict[str, Any] = OmegaConf.to_container(cfg.model, resolve=True)
    env_info: dict[str, Any] = OmegaConf.to_container(cfg.environment, resolve=True)

    debug: bool = additional_configs["debug_mode"]
    vec_env = additional_configs["vec_env"]
    number_envs_parallel = config["number_envs_parallel"]
    model_class = instantiate(cfg.model.model_type)
    data_to_log = {**config, **env_info}

    if vec_env:
        env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
    else:
        env = EnvGymWrapper(cfg)
        env.render_mode = additional_configs["render_mode"]

    model_save_name = additional_configs["project_name"] + "_" + config["model_name"]
    model_save_path = rl_agent_checkpoints / model_save_name

    # Entries holding the literal string 'None' are treated as "not set" and
    # dropped so the model constructor falls back to its own defaults.
    filtered_config = {
        key: value
        for key, value in config.items()
        if key not in _NON_MODEL_KEYS and value != 'None'
    }

    run = None
    callback = None
    try:
        if not debug:
            # also upload the environment config to W&B and all stable baselines3 hyperparams
            run = wandb.init(
                project=additional_configs["project_name"],
                config=data_to_log,
                sync_tensorboard=additional_configs["sync_tensorboard"],  # auto-upload sb3's tensorboard metrics
                monitor_gym=additional_configs["monitor_gym"],
                dir=additional_configs["log_path"],
                # save_code=True,  # optional
            )
            # Videos and the W&B model snapshot share one run-specific folder.
            run_output_dir = additional_configs["video_save_path"] + run.id

            if not vec_env:
                # VecVideoRecorder wraps a VecEnv, so a single environment
                # has to be vectorised before it can be recorded.
                single_env = env
                env = DummyVecEnv([lambda: single_env])

            video_interval = additional_configs["record_video_trigger"]
            env = VecVideoRecorder(
                env,
                run_output_dir,
                record_video_trigger=lambda step: step % video_interval == 0,
                video_length=additional_configs["video_length"],
            )

        model = model_class(env=env, **filtered_config)

        if run is not None:
            checkpoint_callback = CheckpointCallback(
                save_freq=additional_configs["save_freq"],
                save_path=additional_configs["save_path_callback"],
                name_prefix=additional_configs["name_prefix_callback"],
                save_replay_buffer=additional_configs["save_replay_buffer"],
                save_vecnormalize=additional_configs["save_vecnormalize"],
            )
            wandb_callback = WandbCallback(
                model_save_path=run_output_dir,
                verbose=0,
            )
            callback = CallbackList([checkpoint_callback, wandb_callback])

        # In debug mode ``callback`` stays None, which is also the default
        # of ``learn``.
        model.learn(
            total_timesteps=config["total_timesteps"],
            callback=callback,
            log_interval=1,
            progress_bar=additional_configs["progress_bar"],
        )
    except BaseException:
        # Close the W&B run as failed instead of leaving it open when setup
        # or training is interrupted, then let the error propagate.
        if run is not None:
            run.finish(exit_code=1)
        raise

    if run is not None:
        run.finish()

    model.save(model_save_path)


if __name__ == "__main__":
    main()