Commit d75c2a5a authored by fheinrich

WandB integration

parent 39266378
Merge request !52: Resolve "gym env"
Pipeline #45685 passed
@@ -5,7 +5,7 @@ plates:
   # range of seconds until the dirty plate arrives.
 game:
-  time_limit_seconds: 300
+  time_limit_seconds: 660
 meals:
   all: true
...
@@ -16,6 +16,8 @@ from overcooked_simulator.overcooked_environment import (
     ActionType,
     InterActionData,
 )
+import wandb
+from wandb.integration.sb3 import WandbCallback
 import gymnasium as gym
 import numpy as np
@@ -179,7 +181,7 @@ class EnvGymWrapper(Env):
         reward = -1
         if self.env.order_and_score.score > self.prev_score and self.env.score != 0:
             self.prev_score = self.env
-            reward = 100
+            reward = 200
         elif self.env.order_and_score.score < self.prev_score:
             self.prev_score = 0
             reward = 0
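A side note on this hunk: the unchanged context line `self.prev_score = self.env` stores the environment object itself in `prev_score`, so the `>` comparison on the next step would pit an integer score against an env instance; it presumably should store the score. A hedged sketch of what the update likely intends (assuming `order_and_score.score` is the cumulative game score; `current` is a name introduced here for illustration):

    # Sketch only; assumes order_and_score.score is the cumulative game score.
    reward = -1  # small per-step penalty to discourage idling
    current = self.env.order_and_score.score
    if current > self.prev_score and current != 0:
        self.prev_score = current  # store the score, not the env object
        reward = 200  # score bonus, raised from 100 to 200 in this commit
    elif current < self.prev_score:
        self.prev_score = 0
        reward = 0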
@@ -249,11 +251,29 @@ class EnvGymWrapper(Env):
 def main():
+    config = {
+        "policy_type": "CnnPolicy",
+        "total_timesteps": 100000,
+        "env_id": "overcooked",
+    }
+    run = wandb.init(
+        project="overcooked",
+        config=config,
+        sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
+        # monitor_gym=True,  # auto-upload the videos of agents playing the game
+        # save_code=True,  # optional
+    )
     env = EnvGymWrapper()
     check_env(env)
     from stable_baselines3.common.env_util import make_vec_env
-    vec_env = make_vec_env(EnvGymWrapper, n_envs=4)
+    vec_env = make_vec_env(EnvGymWrapper, n_envs=32)
     # print("start")
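With n_envs raised from 4 to 32, it is worth noting that make_vec_env builds a DummyVecEnv by default, which steps all copies sequentially in one process. Stable-Baselines3's make_vec_env accepts a vec_env_cls argument for subprocess-backed vectorization instead; whether that pays off here depends on how heavy a single env step is. A sketch of the alternative (not what the commit does):

    from stable_baselines3.common.env_util import make_vec_env
    from stable_baselines3.common.vec_env import SubprocVecEnv

    # One worker process per env copy instead of the sequential default.
    vec_env = make_vec_env(EnvGymWrapper, n_envs=32, vec_env_cls=SubprocVecEnv)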
@@ -271,17 +291,31 @@ def main():
     from stable_baselines3 import PPO
-    RL_CLASS = PPO
-    model = RL_CLASS("CnnPolicy", vec_env, verbose=1)
-    model.learn(total_timesteps=50000, log_interval=1, progress_bar=True)
+    # RL_CLASS = PPO
+    model = PPO(config["policy_type"], vec_env, verbose=1, tensorboard_log=f"runs/{run.id}")
+    # model = PPO("CnnPolicy", vec_env, verbose=1)
+    model.learn(
+        total_timesteps=config["total_timesteps"],
+        callback=WandbCallback(
+            model_save_path=f"models/{run.id}",
+            verbose=0,
+        ),
+        log_interval=1,
+        progress_bar=True
+    )
+    # model.learn(total_timesteps=100000, log_interval=1, progress_bar=True)
+    run.finish()
     model.save("oc")
     del model  # remove to demonstrate saving and loading
     print("LEARNING DONE.")
-    model = RL_CLASS.load("oc")
+    model = PPO.load("oc")
...
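Taken together, the pattern this commit wires in is: call wandb.init before building the model, point tensorboard_log at the run id so sync_tensorboard=True can mirror SB3's TensorBoard metrics, pass a WandbCallback to learn() so the model is checkpointed to the run, and close with run.finish(). A minimal self-contained sketch of the same pattern, using CartPole-v1 as a stand-in for the overcooked env (project name and hyperparameters here are illustrative):

    import wandb
    from wandb.integration.sb3 import WandbCallback
    from stable_baselines3 import PPO

    config = {"policy_type": "MlpPolicy", "total_timesteps": 10_000}
    run = wandb.init(project="overcooked", config=config, sync_tensorboard=True)

    # tensorboard_log keyed to the run id is what sync_tensorboard picks up.
    model = PPO(config["policy_type"], "CartPole-v1", verbose=1,
                tensorboard_log=f"runs/{run.id}")
    model.learn(
        total_timesteps=config["total_timesteps"],
        callback=WandbCallback(model_save_path=f"models/{run.id}", verbose=0),
    )
    run.finish()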