Skip to content
Snippets Groups Projects
train_single_agent.py 3.28 KiB
Newer Older
  • Learn to ignore specific revisions
  • from pathlib import Path
    
    import wandb
    from omegaconf import DictConfig, OmegaConf
    from stable_baselines3 import A2C
    from stable_baselines3 import DQN
    from stable_baselines3 import PPO
    from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import VecVideoRecorder
    from wandb.integration.sb3 import WandbCallback
    
    from gym_env import EnvGymWrapper
    import hydra
    
    
    @hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
    def main(cfg: DictConfig):
        """Train a single RL agent (A2C / DQN / PPO) on the Overcooked gym env.

        The Hydra config supplies the environment setup (``cfg``) and the model
        hyperparameters (``cfg.model``: ``model_type``, ``policy_type``,
        ``number_envs_parallel``, ``total_timesteps``). Logs, videos, checkpoints
        and W&B artifacts all live under ``logs/reinforcement_learning``.
        """
        # parents=True so this also works on a fresh checkout where "logs/"
        # does not exist yet (plain exist_ok=True would raise FileNotFoundError).
        rl_logs = Path("logs/reinforcement_learning")
        rl_logs.mkdir(parents=True, exist_ok=True)
        rl_agent_checkpoints = rl_logs / "rl_agent_checkpoints"
        rl_agent_checkpoints.mkdir(parents=True, exist_ok=True)

        config = OmegaConf.to_container(cfg.model, resolve=True)

        # Hard-coded run switches; candidates to be lifted into the Hydra config.
        debug = False
        do_training = True
        vec_env = True

        models = {"A2C": A2C, "DQN": DQN, "PPO": PPO}
        number_envs_parallel = config["number_envs_parallel"]
        model_class = models[config["model_type"]]

        if vec_env:
            env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
        else:
            env = EnvGymWrapper(cfg)

        # Needed so VecVideoRecorder can grab frames — assumes the wrapped env
        # supports "rgb_array" rendering (TODO confirm on EnvGymWrapper).
        env.render_mode = "rgb_array"

        if not debug:
            # also upload the environment config to W&B and all stable baselines3 hyperparams
            run = wandb.init(
                project="overcooked",
                config=config,
                sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
                monitor_gym=True,
                dir="logs/reinforcement_learning",
                # save_code=True,  # optional
            )

            # Record a 300-frame clip every 200k steps, keyed by the W&B run id.
            env = VecVideoRecorder(
                env,
                f"logs/reinforcement_learning/videos/{run.id}",
                record_video_trigger=lambda x: x % 200_000 == 0,
                video_length=300,
            )

        model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"

        if do_training:
            model = model_class(
                config["policy_type"],
                env,
                verbose=1,
                # fixed run dir "runs/0" (was an f-string formatting the literal 0)
                tensorboard_log="logs/reinforcement_learning/runs/0",
                device="cpu",
                # n_steps=2048,
                # n_epochs=10,
            )

            # Maybe Hydra instantiate here to avoid hard coding the possible classes
            if debug:
                model.learn(
                    total_timesteps=config["total_timesteps"],
                    log_interval=1,
                    progress_bar=True,
                )
            else:
                # NOTE(review): periodic checkpoints go to "logs", not the
                # rl_agent_checkpoints dir created above — confirm that split
                # (periodic vs. final save) is intended.
                checkpoint_callback = CheckpointCallback(
                    save_freq=50_000,
                    save_path="logs",
                    name_prefix="rl_model",
                    save_replay_buffer=True,
                    save_vecnormalize=True,
                )
                wandb_callback = WandbCallback(
                    model_save_path=f"logs/reinforcement_learning/models/{run.id}",
                    verbose=0,
                )

                callback = CallbackList([checkpoint_callback, wandb_callback])
                model.learn(
                    total_timesteps=config["total_timesteps"],
                    callback=callback,
                    log_interval=1,
                    progress_bar=True,
                )
                run.finish()
            # Final model weights go to the dedicated checkpoints dir.
            model.save(model_save_path)

            del model
        print("LEARNING DONE.")
    
    
    # Script entry point: Hydra parses CLI overrides and injects the config.
    if __name__ == "__main__":
        main()