# train_single_agent.py
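"""Train a single RL agent (A2C, DQN, or PPO) on the Overcooked gym wrapper.

Hyperparameters come from the Hydra config (config/rl_config), metrics and
rollout videos are tracked with Weights & Biases, and checkpoints are written
under logs/reinforcement_learning.
"""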
from pathlib import Path

import hydra
import wandb
from omegaconf import DictConfig, OmegaConf
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

from gym_env import EnvGymWrapper
    
    
@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig):
    # ensure the log and checkpoint directories exist (parents=True so "logs"
    # itself is created on a fresh clone)
    rl_logs = Path("logs/reinforcement_learning")
    rl_logs.mkdir(parents=True, exist_ok=True)
    rl_agent_checkpoints = rl_logs / "rl_agent_checkpoints"
    rl_agent_checkpoints.mkdir(parents=True, exist_ok=True)

    config = OmegaConf.to_container(cfg.model, resolve=True)

    debug = False  # True: train without W&B logging, callbacks, or video capture
    vec_env = True  # True: run several environment copies in parallel

    models = {"A2C": A2C, "DQN": DQN, "PPO": PPO}
    number_envs_parallel = config["number_envs_parallel"]
    model_class = models[config["model_type"]]

    if vec_env:
        # make_vec_env runs n_envs copies of the wrapped env behind one vectorized interface
        env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
    else:
        env = EnvGymWrapper(cfg)
    
    env.render_mode = "rgb_array"  # VecVideoRecorder needs rgb_array frames

    if not debug:
        # upload the environment config and all Stable-Baselines3 hyperparameters to W&B
        run = wandb.init(
            project="overcooked",
            config=config,
            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
            monitor_gym=True,
            dir="logs/reinforcement_learning",
            # save_code=True,  # optional
        )
    
        # record a 300-frame clip every 200k steps into the run's video folder
        env = VecVideoRecorder(
            env,
            f"logs/reinforcement_learning/videos/{run.id}",
            record_video_trigger=lambda x: x % 200_000 == 0,
            video_length=300,
        )
    
    model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"

    # drop the keys consumed by this script rather than by the SB3 constructor,
    # and skip values the YAML config left as the string "None"
    excluded_keys = ["env_id", "policy_type", "model_type", "total_timesteps", "number_envs_parallel"]
    filtered_config = {
        k: v for k, v in config.items() if k not in excluded_keys and v != "None"
    }
    model = model_class(
        config["policy_type"],  # e.g. "MlpPolicy"
        env=env,
        **filtered_config,
    )
    if debug:
        model.learn(
            total_timesteps=config["total_timesteps"],
            log_interval=1,
            progress_bar=True,
        )
    else:
        # periodic local checkpoints, including replay buffer and VecNormalize stats
        checkpoint_callback = CheckpointCallback(
            save_freq=50_000,
            save_path="logs",
            name_prefix="rl_model",
            save_replay_buffer=True,
            save_vecnormalize=True,
        )
        # mirror model files into the W&B run directory
        wandb_callback = WandbCallback(
            model_save_path=f"logs/reinforcement_learning/models/{run.id}",
            verbose=0,
        )
        callback = CallbackList([checkpoint_callback, wandb_callback])
        model.learn(
            total_timesteps=config["total_timesteps"],
            callback=callback,
            log_interval=1,
            progress_bar=True,
        )
        run.finish()
    model.save(model_save_path)


if __name__ == "__main__":
    main()
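
# Example invocation (Hydra overrides; the exact keys live in config/rl_config):
#   python train_single_agent.py model.model_type=PPO model.total_timesteps=1_000_000
#
# A trained agent can later be reloaded with the matching SB3 class, e.g.:
#   model = PPO.load("logs/reinforcement_learning/rl_agent_checkpoints/overcooked_PPO")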