Skip to content
Snippets Groups Projects
Commit 9f856ab7 authored by Christoph Kowalski's avatar Christoph Kowalski
Browse files

Updated logging to wandb

parent b9f10399
No related branches found
No related tags found
2 merge requests!110V1.2.0 changes,!109SB3 RL with Hydra
Pipeline #60813 passed
......@@ -4,13 +4,13 @@ state_converter:
log_path: "logs/reinforcement_learning"
checkpoint_path: "rl_agent_checkpoints"
render_mode: "rgb_array"
project_name: "overcooked"
project_name: "overcooked_rl"
debug_mode: False
vec_env: True
sync_tensorboard: True # auto-upload sb3's tensorboard metrics
monitor_gym: True
video_save_path: "logs/reinforcement_learning/videos/"
record_video_trigger: 200_000
record_video_trigger: 20_000
video_length: 300
save_freq: 50_000
save_path_callback: "logs"
......
......@@ -28,10 +28,12 @@ def main(cfg: DictConfig):
rl_agent_checkpoints: Path = rl_logs / Path(additional_configs["checkpoint_path"])
rl_agent_checkpoints.mkdir(exist_ok=True)
config: dict[str, Any] = OmegaConf.to_container(cfg.model, resolve=True)
env_info: dict[str, Any] = OmegaConf.to_container(cfg.environment, resolve=True)
debug: bool = additional_configs["debug_mode"]
vec_env = additional_configs["vec_env"]
number_envs_parallel = config["number_envs_parallel"]
model_class = instantiate(cfg.model.model_type)
data_to_log=dict(config, **env_info)
if vec_env:
env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment