Skip to content
Snippets Groups Projects

Resolve "gym env"

Merged Fabian Heinrich requested to merge 86-gym-env into main
Compare and Show latest version
1 file
+ 27
9
Compare changes
  • Side-by-side
  • Inline
import json
import random
import time
from copy import copy, deepcopy
from copy import deepcopy
from datetime import timedelta
from enum import Enum
from pathlib import Path
@@ -16,6 +17,7 @@ from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from overcooked_simulator import ROOT_DIR
from overcooked_simulator.counters import Counter
from overcooked_simulator.gui_2d_vis.drawing import Visualizer
from overcooked_simulator.overcooked_environment import (
Environment,
@@ -117,7 +119,6 @@ def get_env_action(player_id, simple_action, duration):
print("FAIL", simple_action)
environment_config_path = ROOT_DIR / "game_content" / "environment_config_rl.yaml"
layout_path: Path = ROOT_DIR / "game_content" / "layouts" / "rl.layout"
item_info_path = ROOT_DIR / "game_content" / "item_info_rl.yaml"
@@ -131,11 +132,12 @@ with open(ROOT_DIR / "gui_2d_vis" / "visualization_rl.yaml", "r") as file:
visualization_config = yaml.safe_load(file)
vanilla_env: Environment = Environment(
env_config=environment_config,
layout_config=layout,
item_info=item_info,
as_files=False
)
env_config=environment_config,
layout_config=layout,
item_info=item_info,
as_files=False,
)
class EnvGymWrapper(Env):
"""Should enable this:
@@ -150,6 +152,15 @@ class EnvGymWrapper(Env):
self.gridsize = 20
self.env = deepcopy(vanilla_env)
sample_counter = []
for counter in self.env.counters:
if counter.__class__ != Counter:
sample_counter.append(counter)
new_counter_pos = [c.pos for c in sample_counter]
random.shuffle(new_counter_pos)
for counter, new_pos in zip(sample_counter, new_counter_pos):
counter.pos = new_pos
self.env.vector_state_generation = self.env.setup_vectorization()
self.visualizer: Visualizer = Visualizer(config=visualization_config)
self.player_name = str(0)
@@ -195,7 +206,6 @@ class EnvGymWrapper(Env):
# observation = self.get_env_img(self.gridsize)
observation = self.get_vector_state()
reward = self.env.score - self.prev_score
self.prev_score = self.env.score
terminated = self.env.game_ended
@@ -214,6 +224,15 @@ class EnvGymWrapper(Env):
# )
self.env = deepcopy(vanilla_env)
sample_counter = []
for counter in self.env.counters:
if counter.__class__ != Counter:
sample_counter.append(counter)
new_counter_pos = [c.pos for c in sample_counter]
random.shuffle(new_counter_pos)
for counter, new_pos in zip(sample_counter, new_counter_pos):
counter.pos = new_pos
self.env.vector_state_generation = self.env.setup_vectorization()
self.player_name = str(0)
self.env.add_player(self.player_name)
@@ -256,7 +275,6 @@ class EnvGymWrapper(Env):
# return np.random.randint(len(self.action_space_map))
def main():
rl_agent_checkpoints = Path("./rl_agent_checkpoints")
rl_agent_checkpoints.mkdir(exist_ok=True)
Loading