# Training entry point: trains a stable-baselines3 agent configured through hydra.
from pathlib import Path
from typing import Any
import wandb
from omegaconf import DictConfig, OmegaConf
from stable_baselines3 import A2C
from stable_baselines3 import DQN
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback
from gym_env import EnvGymWrapper
import hydra
from hydra.utils import instantiate
@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig):
    """Train an agent from scratch and save the model to the configured path.

    All settings are managed with hydra:

    * ``cfg.additional_configs`` supplies the logging, video, checkpoint and
      W&B options read below.
    * ``cfg.model`` supplies the model settings; every entry that is not
      consumed by this script itself is forwarded to the model constructor.
    * ``cfg.model.model_type`` is instantiated to obtain the model class.

    In debug mode the agent is trained without W&B, video recording or
    checkpoint callbacks.

    Args:
        cfg: The hydra configuration composed from ``config/rl_config``.
    """
    additional_configs: dict[str, Any] = OmegaConf.to_container(
        cfg.additional_configs, resolve=True
    )

    # parents=True: a nested log/checkpoint path must also work when none of
    # the intermediate directories exist yet.
    rl_logs: Path = Path(additional_configs["log_path"])
    rl_logs.mkdir(parents=True, exist_ok=True)
    rl_agent_checkpoints: Path = rl_logs / Path(additional_configs["checkpoint_path"])
    rl_agent_checkpoints.mkdir(parents=True, exist_ok=True)

    config: dict[str, Any] = OmegaConf.to_container(cfg.model, resolve=True)
    debug: bool = additional_configs["debug_mode"]
    vec_env = additional_configs["vec_env"]
    number_envs_parallel = config["number_envs_parallel"]
    model_class = instantiate(cfg.model.model_type)

    if vec_env:
        env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
    else:
        env = EnvGymWrapper(cfg)
    env.render_mode = additional_configs["render_mode"]

    if not debug:
        # Upload the model config (the stable baselines3 hyperparams) to W&B.
        run = wandb.init(
            project=additional_configs["project_name"],
            config=config,
            sync_tensorboard=additional_configs["sync_tensorboard"],  # auto-upload sb3's tensorboard metrics
            monitor_gym=additional_configs["monitor_gym"],
            dir=additional_configs["log_path"],
            # save_code=True,  # optional
        )

        if not vec_env:
            # VecVideoRecorder only wraps vectorized environments, so a single
            # environment has to be vectorized before it can be recorded.
            from stable_baselines3.common.vec_env import DummyVecEnv

            single_env = env
            env = DummyVecEnv([lambda: single_env])

        video_interval = additional_configs["record_video_trigger"]
        env = VecVideoRecorder(
            env,
            additional_configs["video_save_path"] + run.id,
            record_video_trigger=lambda step: step % video_interval == 0,
            video_length=additional_configs["video_length"],
        )

    model_save_name = f'{additional_configs["project_name"]}_{config["model_name"]}'
    model_save_path = rl_agent_checkpoints / model_save_name

    # These entries configure this script rather than the model, so they must
    # not reach the model constructor. Values given as the string 'None' are
    # dropped as well.
    script_only_keys = {
        "env_id",
        "policy_type",
        "model_name",
        "model_type",
        "model_type_inference",
        "total_timesteps",
        "number_envs_parallel",
    }
    filtered_config = {
        key: value
        for key, value in config.items()
        if key not in script_only_keys and value != "None"
    }
    model = model_class(env=env, **filtered_config)

    if debug:
        model.learn(
            total_timesteps=config["total_timesteps"],
            log_interval=1,
            progress_bar=additional_configs["progress_bar"],
        )
    else:
        checkpoint_callback = CheckpointCallback(
            save_freq=additional_configs["save_freq"],
            save_path=additional_configs["save_path_callback"],
            name_prefix=additional_configs["name_prefix_callback"],
            save_replay_buffer=additional_configs["save_replay_buffer"],
            save_vecnormalize=additional_configs["save_vecnormalize"],
        )
        # The W&B callback stores the model next to the recorded videos.
        wandb_callback = WandbCallback(
            model_save_path=additional_configs["video_save_path"] + run.id,
            verbose=0,
        )
        callback = CallbackList([checkpoint_callback, wandb_callback])
        model.learn(
            total_timesteps=config["total_timesteps"],
            callback=callback,
            log_interval=1,
            progress_bar=additional_configs["progress_bar"],
        )
        run.finish()

    model.save(model_save_path)
    del model
# Run the training only when this file is executed as a script; ``main`` is
# called without arguments because the hydra decorator supplies ``cfg``.
if __name__ == "__main__":
    main()