from pathlib import Path

import hydra
import wandb
from omegaconf import DictConfig, OmegaConf
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

from gym_env import EnvGymWrapper
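
# Expected shape of config/rl_config.yaml, inferred from the keys this script
# reads below; the concrete values are illustrative assumptions, not the
# project's actual defaults:
#
#   model:
#     model_type: PPO            # one of "A2C", "DQN", "PPO"
#     policy_type: MlpPolicy     # forwarded to the SB3 model constructor
#     total_timesteps: 1000000
#     number_envs_parallel: 4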

@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig):
    # create the log and checkpoint directories up front (parents=True so the
    # top-level "logs" directory is created on a fresh checkout)
    rl_logs = Path("logs/reinforcement_learning")
    rl_logs.mkdir(parents=True, exist_ok=True)
    rl_agent_checkpoints = rl_logs / "rl_agent_checkpoints"
    rl_agent_checkpoints.mkdir(parents=True, exist_ok=True)

    config = OmegaConf.to_container(cfg.model, resolve=True)

    debug = False
    do_training = True
    vec_env = True

    # map the config string onto the SB3 algorithm class
    models = {"A2C": A2C, "DQN": DQN, "PPO": PPO}
    number_envs_parallel = config["number_envs_parallel"]
    model_class = models[config["model_type"]]

    if vec_env:
        env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
    else:
        env = EnvGymWrapper(cfg)
    env.render_mode = "rgb_array"

    if not debug:
        # also upload the environment config and all stable-baselines3
        # hyperparameters to W&B
        run = wandb.init(
            project="overcooked",
            config=config,
            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
            monitor_gym=True,
            dir="logs/reinforcement_learning",
            # save_code=True,  # optional
        )
        env = VecVideoRecorder(
            env,
            f"logs/reinforcement_learning/videos/{run.id}",
            record_video_trigger=lambda x: x % 200_000 == 0,
            video_length=300,
        )

    model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"

    if do_training:
        # maybe use Hydra instantiate here to avoid hard-coding the possible classes
        model = model_class(
            config["policy_type"],
            env,
            verbose=1,
            tensorboard_log=f"logs/reinforcement_learning/runs/{0}",
            device="cpu",
            # n_steps=2048,
            # n_epochs=10,
        )

        if debug:
            model.learn(
                total_timesteps=config["total_timesteps"],
                log_interval=1,
                progress_bar=True,
            )
        else:
            checkpoint_callback = CheckpointCallback(
                save_freq=50_000,
                save_path="logs",
                name_prefix="rl_model",
                save_replay_buffer=True,
                save_vecnormalize=True,
            )
            wandb_callback = WandbCallback(
                model_save_path=f"logs/reinforcement_learning/models/{run.id}",
                verbose=0,
            )
            callback = CallbackList([checkpoint_callback, wandb_callback])
            model.learn(
                total_timesteps=config["total_timesteps"],
                callback=callback,
                log_interval=1,
                progress_bar=True,
            )
            run.finish()

        model.save(model_save_path)
        del model
    print("LEARNING DONE.")


if __name__ == "__main__":
    main()
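
# Loading a saved agent afterwards is a standard SB3 round-trip; a minimal
# sketch, assuming the PPO model type and the save path used above:
#
#   from stable_baselines3 import PPO
#
#   model = PPO.load("logs/reinforcement_learning/rl_agent_checkpoints/overcooked_PPO")
#   action, _states = model.predict(obs, deterministic=True)  # obs from the env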