Commit ffc9571d authored by fheinrich

Larger layout, fixed prev score in reset

parent e9442fc1
1 merge request: !52 Resolve "gym env"
Pipeline #45806 passed
-#X##
-T__#
-U__P
-#CW#
+#X###
+T___#
+#___#
+U___P
+#C#W#
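The two grids above are the previous 4x4 kitchen layout and the enlarged 5x5 layout referenced in the commit message. Below is a minimal sketch of how such a character grid can be parsed into per-tile coordinates; the meaning of the individual characters (C, W, T, U, P, X, #) is defined elsewhere in the repository and is not assumed here.

from collections import defaultdict

def parse_layout(layout: str) -> dict[str, list[tuple[int, int]]]:
    """Collect the (row, col) positions of every tile character in a layout grid."""
    tiles = defaultdict(list)
    for row, line in enumerate(layout.splitlines()):
        for col, char in enumerate(line):
            tiles[char].append((row, col))
    return dict(tiles)

# The new, larger layout from this commit.
new_layout = "#X###\nT___#\n#___#\nU___P\n#C#W#"
print(parse_layout(new_layout))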
@@ -243,6 +243,8 @@ class EnvGymWrapper(Env):
         info = {}
         obs = self.get_observation()
+        self.prev_score = 0
         return obs, info

     def get_observation(self):
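The added self.prev_score = 0 resets the score tracking at the start of every episode. A plausible reading, consistent with the commit message "fixed prev score in reset", is that the step reward is computed as the score gained since the previous step, so a prev_score left over from the last episode would distort the first reward after reset(). A minimal sketch of that pattern follows; the class and method names are illustrative, not the actual EnvGymWrapper implementation.

class ScoreDeltaTracker:
    """Sketch: reward each step with the change in game score since the last step."""

    def reset_score_tracking(self):
        # Without this, the first reward after a reset would be computed
        # against the final score of the previous episode.
        self.prev_score = 0

    def score_delta_reward(self, current_score: float) -> float:
        reward = current_score - self.prev_score
        self.prev_score = current_score
        return reward

tracker = ScoreDeltaTracker()
tracker.reset_score_tracking()
print(tracker.score_delta_reward(5.0))  # 5.0
print(tracker.score_delta_reward(7.0))  # 2.0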
@@ -288,14 +290,14 @@ def main():
     config = {
         "policy_type": "MlpPolicy",
-        "total_timesteps": 500_000,  # hendric says more like 300_000_000 steps
+        "total_timesteps": 100_000,  # hendric says more like 300_000_000 steps
         "env_id": "overcooked",
     }

     debug = True
     do_training = True
     vec_env = True
-    number_envs_parallel = 64
+    number_envs_parallel = 8
     model_classes = [A2C, DQN, PPO]
     model_class = model_classes[2]
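The new values cut the run down to 100_000 timesteps and 8 parallel environments. Below is a sketch of how a config like this is commonly turned into a vectorized environment with stable-baselines3; "CartPole-v1" stands in for the project's registered "overcooked" id, and the use of SubprocVecEnv here is an assumption.

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

config = {
    "policy_type": "MlpPolicy",
    "total_timesteps": 100_000,
    "env_id": "CartPole-v1",  # stand-in for the project's "overcooked" env id
}
number_envs_parallel = 8

if __name__ == "__main__":
    # SubprocVecEnv runs each environment copy in its own process; the
    # __main__ guard is needed on platforms that spawn subprocesses.
    env = make_vec_env(
        config["env_id"],
        n_envs=number_envs_parallel,
        vec_env_cls=SubprocVecEnv,
    )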
@@ -307,11 +309,11 @@ def main():
     model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"

     if do_training:
         model = model_class(
             config["policy_type"], env, verbose=1, tensorboard_log=f"runs/{0}"
         )

         if debug:
             model.learn(
                 total_timesteps=config["total_timesteps"],
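The hunk above instantiates the selected algorithm (PPO, index 2 of [A2C, DQN, PPO]) with the MlpPolicy and a fixed TensorBoard directory, then trains directly when debug is set. A sketch of that pattern with a stand-in environment follows; the paths and the stand-in env are placeholders, and the fixed "runs/0" directory is kept as it appears in the diff.

from pathlib import Path
from stable_baselines3 import A2C, DQN, PPO
from stable_baselines3.common.env_util import make_vec_env

model_classes = [A2C, DQN, PPO]
model_class = model_classes[2]  # PPO

env = make_vec_env("CartPole-v1", n_envs=8)  # stand-in for the overcooked vec env
rl_agent_checkpoints = Path("rl_agent_checkpoints")
rl_agent_checkpoints.mkdir(exist_ok=True)
model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"

model = model_class(
    "MlpPolicy", env, verbose=1, tensorboard_log=f"runs/{0}"  # fixed "runs/0" dir, as in the diff
)
# Debug path: train without the wandb/checkpoint callbacks.
model.learn(total_timesteps=100_000)
model.save(model_save_path)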
@@ -335,9 +337,9 @@ def main():
                 save_vecnormalize=True,
             )
             wandb_callback = WandbCallback(
-                model_save_path=f"models/{run.id}",
-                verbose=0,
-            )
+                model_save_path=f"models/{run.id}",
+                verbose=0,
+            )
             callback = CallbackList([checkpoint_callback, wandb_callback])

             model.learn(
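The hunk above appears to only re-indent the WandbCallback arguments; the surrounding non-debug path combines a CheckpointCallback and the wandb integration into a CallbackList. A sketch of that wiring follows; the project name, save_freq, and the stand-in environment are placeholders.

import wandb
from wandb.integration.sb3 import WandbCallback
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback

run = wandb.init(
    project="overcooked-rl",  # placeholder project name
    config={"policy_type": "MlpPolicy", "total_timesteps": 100_000},
    sync_tensorboard=True,  # forward SB3's TensorBoard logs to wandb
)

checkpoint_callback = CheckpointCallback(
    save_freq=10_000,  # placeholder; counted in environment steps
    save_path=f"models/{run.id}",
    save_vecnormalize=True,
)
wandb_callback = WandbCallback(
    model_save_path=f"models/{run.id}",
    verbose=0,
)
callback = CallbackList([checkpoint_callback, wandb_callback])

model = PPO("MlpPolicy", "CartPole-v1", verbose=1, tensorboard_log=f"runs/{run.id}")
model.learn(total_timesteps=100_000, callback=callback)
run.finish()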
@@ -360,6 +362,7 @@ def main():
         time.sleep(1 / 10)
         action, _states = model.predict(obs, deterministic=False)
         obs, reward, terminated, truncated, info = env.step(int(action))
+        print(reward)
         env.render()
         if terminated or truncated:
             obs, info = env.reset()
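The added print(reward) simply surfaces the per-step reward while watching the trained agent. Below is a sketch of the full evaluation loop as it appears here, using a single (non-vectorized) gymnasium environment as a stand-in and a freshly constructed model in place of loading the saved checkpoint.

import time
import gymnasium as gym
from stable_baselines3 import PPO

env = gym.make("CartPole-v1", render_mode="human")  # stand-in for the overcooked env
model = PPO("MlpPolicy", env)  # normally loaded via model_class.load(model_save_path)

obs, info = env.reset()
while True:
    time.sleep(1 / 10)  # throttle to ~10 steps/s so the rendering is watchable
    action, _states = model.predict(obs, deterministic=False)
    obs, reward, terminated, truncated, info = env.step(int(action))
    print(reward)
    env.render()
    if terminated or truncated:
        obs, info = env.reset()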