from pathlib import Path

import hydra
import wandb
from omegaconf import DictConfig, OmegaConf
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

from gym_env import EnvGymWrapper


@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig):
    # Create the log directories up front; parents=True so "logs" itself
    # does not need to exist yet.
    rl_logs = Path("logs/reinforcement_learning")
    rl_logs.mkdir(parents=True, exist_ok=True)
    rl_agent_checkpoints = rl_logs / "rl_agent_checkpoints"
    rl_agent_checkpoints.mkdir(parents=True, exist_ok=True)

    config = OmegaConf.to_container(cfg.model, resolve=True)

    debug = False
    vec_env = True

    models = {"A2C": A2C, "DQN": DQN, "PPO": PPO}
    number_envs_parallel = config["number_envs_parallel"]
    model_class = models[config["model_type"]]

    if vec_env:
        env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
    else:
        env = EnvGymWrapper(cfg)

    # VecVideoRecorder requires an rgb_array render mode on the wrapped env.
    env.render_mode = "rgb_array"

    if not debug:
        # Upload the environment config and all stable-baselines3
        # hyperparameters to W&B alongside the training metrics.
        run = wandb.init(
            project="overcooked",
            config=config,
            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
            monitor_gym=True,
            dir="logs/reinforcement_learning",
            # save_code=True,  # optional
        )

        env = VecVideoRecorder(
            env,
            f"logs/reinforcement_learning/videos/{run.id}",
            record_video_trigger=lambda x: x % 200_000 == 0,
            video_length=300,
        )

    model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"

    # Keep only the keys that the SB3 model constructor should receive;
    # the remaining keys are consumed elsewhere in this script. Values set
    # to the string "None" in the config are treated as unset.
    filtered_config = {
        k: v
        for k, v in config.items()
        if k
        not in [
            "env_id",
            "policy_type",
            "model_type",
            "total_timesteps",
            "number_envs_parallel",
        ]
        and v != "None"
    }

    # policy_type is filtered out above and passed explicitly as the policy.
    model = model_class(config["policy_type"], env=env, **filtered_config)

    if debug:
        model.learn(
            total_timesteps=config["total_timesteps"],
            log_interval=1,
            progress_bar=True,
        )
    else:
        checkpoint_callback = CheckpointCallback(
            save_freq=50_000,  # measured in calls to env.step(), i.e. per-env steps
            save_path="logs",
            name_prefix="rl_model",
            save_replay_buffer=True,
            save_vecnormalize=True,
        )
        wandb_callback = WandbCallback(
            model_save_path=f"logs/reinforcement_learning/models/{run.id}",
            verbose=0,
        )
        callback = CallbackList([checkpoint_callback, wandb_callback])
        model.learn(
            total_timesteps=config["total_timesteps"],
            callback=callback,
            log_interval=1,
            progress_bar=True,
        )
        run.finish()

    model.save(model_save_path)
    del model


if __name__ == "__main__":
    main()
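
# A minimal sketch of the `model` section that config/rl_config.yaml is
# assumed to provide (hypothetical values; the exact hyperparameter keys
# depend on the chosen algorithm, and everything not filtered out in main()
# is forwarded verbatim to the stable-baselines3 constructor):
#
#   model:
#     model_type: PPO              # one of A2C, DQN, PPO
#     policy_type: MlpPolicy       # passed as the SB3 `policy` argument
#     total_timesteps: 1000000     # consumed by model.learn()
#     number_envs_parallel: 4      # consumed by make_vec_env()
#     learning_rate: 3.0e-4        # example SB3 hyperparameter, forwarded as-is
#     tensorboard_log: logs/reinforcement_learning  # needed for sync_tensorboard
#
# Any of these can be overridden on the command line via Hydra, e.g.:
#   python train.py model.model_type=DQN model.total_timesteps=500000
# (train.py is a placeholder for whatever this file is named in the repo.)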