train_single_agent.py
    from pathlib import Path
    
    import wandb
    from omegaconf import DictConfig, OmegaConf
    from stable_baselines3 import A2C, DQN, PPO
    from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import VecVideoRecorder
    from wandb.integration.sb3 import WandbCallback
    
    import hydra
    from hydra.utils import instantiate

    from gym_env import EnvGymWrapper
    
    
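    # Trains a single RL agent on the wrapped gym environment with Stable-Baselines3,
    # using a Hydra config for all hyperparameters and optional W&B logging.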
    @hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
    def main(cfg: DictConfig):
    
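        # Resolve the run settings and create the log / checkpoint directories.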
        additional_configs = OmegaConf.to_container(cfg.additional_configs, resolve=True)
        rl_logs = Path(additional_configs["log_path"])
        rl_logs.mkdir(exist_ok=True)
        rl_agent_checkpoints = rl_logs / Path(additional_configs["checkpoint_path"])
        rl_agent_checkpoints.mkdir(exist_ok=True)

        config = OmegaConf.to_container(cfg.model, resolve=True)
        debug = additional_configs["debug_mode"]
        vec_env = additional_configs["vec_env"]
        number_envs_parallel = config["number_envs_parallel"]
    
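        # The SB3 algorithm class (e.g. PPO, DQN or A2C) is selected via the Hydra config.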
        model_class = instantiate(cfg.model.model_type)
    
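        # Train either on several environments in parallel or on a single environment instance.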
        if vec_env:
            env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
        else:
            env = EnvGymWrapper(cfg)
    
    
        env.render_mode = additional_configs["render_mode"]
    
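        # Outside debug mode, metrics and rollout videos are logged to Weights & Biases.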
        if not debug:
            # also upload the environment config to W&B and all stable baselines3 hyperparams
            run = wandb.init(
                project=additional_configs["project_name"],
                sync_tensorboard=additional_configs["sync_tensorboard"],  # auto-upload sb3's tensorboard metrics
                monitor_gym=additional_configs["monitor_gym"],
                dir=additional_configs["log_path"],
                # save_code=True,  # optional
            )
    
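            # Periodically record videos of the wrapped environment for the W&B run.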
            env = VecVideoRecorder(
                env,
                additional_configs["video_save_path"] + run.id,
                record_video_trigger=lambda x: x % additional_configs["record_video_trigger"] == 0,
                video_length=additional_configs["video_length"],
            )
        model_save_name = additional_configs["project_name"] + "_" + config["model_name"]
        model_save_path = rl_agent_checkpoints / model_save_name
    
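        # Keep only entries that are valid keyword arguments for the SB3 model constructor.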
        filtered_config = {k: v for k, v in config.items() if
                           k not in ["env_id", "policy_type", "model_name", "model_type", "model_type_inference", "total_timesteps", "number_envs_parallel"] and v != 'None'}
    
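        # Build the agent from the selected algorithm class and the filtered hyperparameters.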
        model = model_class(
            env=env,
            **filtered_config
        )
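
        # In debug mode, train without callbacks; otherwise attach checkpointing and W&B callbacks.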
        if debug:
            model.learn(
                total_timesteps=config["total_timesteps"],
                log_interval=1,
                progress_bar=additional_configs["progress_bar"],
            )
        else:
            checkpoint_callback = CheckpointCallback(
                save_freq=additional_configs["save_freq"],
                save_path=additional_configs["save_path_callback"],
                name_prefix=additional_configs["name_prefix_callback"],
                save_replay_buffer=additional_configs["save_replay_buffer"],
                save_vecnormalize=additional_configs["save_vecnormalize"],
            )
            wandb_callback = WandbCallback(
                model_save_path=additional_configs["video_save_path"] + run.id,
            )
            callback = CallbackList([checkpoint_callback, wandb_callback])
            model.learn(
                total_timesteps=config["total_timesteps"],
                callback=callback,
                log_interval=1,
                progress_bar=additional_configs["progress_bar"],
            )