Commit ffd5586f authored by Christoph Kowalski

Incorporated the feedback to make more use of the advantages of Hydra

parent b0407648
3 merge requests: !110 V1.2.0 changes, !102 Fixed caching of recipe layouts. Ids were used in hash, which are generated..., !98 Resolve "Restructure Reinforcement Learning files"
Pipeline #57748 passed
This commit is part of merge request !98.
Showing 98 additions and 73 deletions
@@ -125,6 +125,7 @@ class Environment:
        as_files: bool = True,
        env_name: str = "cooperative_cuisine_1",
        seed: int = 56789223842348,
+       yaml_already_loaded: bool = False
    ):
        """Constructor of the Environment.
@@ -163,10 +164,13 @@
                layout_config = layout_file.read()
            with open(item_info, "r") as file:
                item_info = file.read()
+        if not yaml_already_loaded:
+            self.environment_config: EnvironmentConfig = yaml.load(
+                env_config, Loader=yaml.Loader
+            )
+        else:
+            self.environment_config = env_config
-        self.environment_config: EnvironmentConfig = yaml.load(
-            env_config, Loader=yaml.Loader
-        )
        """The config of the environment. All environment specific attributes is configured here."""
        self.environment_config["player_config"] = PlayerConfig(
            **(
@@ -248,7 +252,8 @@
            )
        )
-        self.recipe_validation.update_plate_config(plate_config, self.environment_config["layout_chars"], self.layout_config)
+        self.recipe_validation.update_plate_config(plate_config, self.environment_config["layout_chars"],
+                                                   self.layout_config)
        self.counter_factory: CounterFactory = CounterFactory(
            layout_chars_config=self.environment_config["layout_chars"],
@@ -382,17 +387,17 @@
        Utility method to pass a reference to the serving window."""
        return self.env_time
-    def load_item_info(self, item_info: str) -> dict[str, ItemInfo]:
+    def load_item_info(self, item_info: str | dict[str, ItemInfo]) -> dict[str, ItemInfo]:
        """Load `item_info.yml`, create ItemInfo classes and replace equipment strings with item infos."""
        self.hook(ITEM_INFO_CONFIG, item_info_config=item_info)
-        item_lookup = yaml.safe_load(item_info)
-        for item_name in item_lookup:
-            item_lookup[item_name] = ItemInfo(name=item_name, **item_lookup[item_name])
-        for item_name, item_info in item_lookup.items():
-            if item_info.equipment:
-                item_info.equipment = item_lookup[item_info.equipment]
-        return item_lookup
+        if isinstance(item_info, str):
+            item_info = yaml.safe_load(item_info)
+        for item_name in item_info:
+            item_info[item_name] = ItemInfo(name=item_name, **item_info[item_name])
+        for item_name, single_item_info in item_info.items():
+            if single_item_info.equipment:
+                single_item_info.equipment = item_info[single_item_info.equipment]
+        return item_info
    def perform_action(self, action: Action):
        """Performs an action of a player in the environment. Maps different types of action inputs to the
......
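
Note on the constructor changes above: with yaml_already_loaded=True the env_config argument is stored as-is instead of being run through yaml.load, and load_item_info now accepts either a YAML string or an already-built dict. A minimal usage sketch follows; the file names and the exact keyword arguments are illustrative, not taken from the repository.

import yaml
from cooperative_cuisine.environment import Environment

# Parse the configs once (e.g. via Hydra/OmegaConf) and hand the parsed objects straight to Environment.
with open("environment_config.yaml") as f:      # hypothetical path
    env_config = yaml.safe_load(f)               # dict instead of raw YAML text
with open("basic.layout") as f:                  # hypothetical path
    layout_config = f.read()
with open("item_info.yaml") as f:                # hypothetical path
    item_info = yaml.safe_load(f)                # dict is now accepted by load_item_info

env = Environment(
    env_config=env_config,
    layout_config=layout_config,
    item_info=item_info,
    as_files=False,               # contents are passed directly, not file paths
    yaml_already_loaded=True,     # skip the constructor's own yaml.load
)
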
-defaults:
-  - order_generator: random_order_generator
+order_generator: "random_orders.yaml"
-# Here the filename of the converter should be given. The converter class needs to be called StateConverter and implement the abstract StateToObservationConverter class
-state_converter: "base_converter_onehot"
\ No newline at end of file
+state_converter:
+  _target_: "cooperative_cuisine.reinforcement_learning.obs_converter.base_converter.BaseStateConverter"
+log_path: "logs/reinforcement_learning"
+checkpoint_path: "rl_agent_checkpoints"
+render_mode: "rgb_array"
+project_name: "overcooked"
+debug_mode: False
+vec_env: True
+sync_tensorboard: True # auto-upload sb3's tensorboard metrics
+monitor_gym: True
+video_save_path: "logs/reinforcement_learning/videos/"
+record_video_trigger: 200_000
+video_length: 300
+save_freq: 50_000
+save_path_callback: "logs"
+name_prefix_callback: "rl_model"
+save_replay_buffer: True
+save_vecnormalize: True
+progress_bar: True
\ No newline at end of file
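
The _target_ entry above is what hydra.utils.instantiate uses to build the converter, replacing the old filename-plus-StateConverter-class convention (see the removed get_converter helper further down). A small sketch of the mechanism, assuming the config block shown above; the env variable stands for an Environment instance as used in gym_env.py:

from hydra.utils import instantiate
from omegaconf import OmegaConf

state_converter_cfg = OmegaConf.create({
    "_target_": "cooperative_cuisine.reinforcement_learning.obs_converter."
                "base_converter.BaseStateConverter"
})
converter = instantiate(state_converter_cfg)  # imports the module and calls BaseStateConverter()
converter.setup(env)                          # env: an Environment instance (assumed, mirrors gym_env.py)
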
order_generator_type: "deterministic"
\ No newline at end of file
order_generator_type: "random"
\ No newline at end of file
env_id: "overcooked"
policy: "MlpPolicy"
model_type: "A2C"
model_name: "A2C"
model_type:
_partial_: true
_target_: stable_baselines3.A2C
total_timesteps: 3_000_000 # hendric sagt eher so 300_000_000 schritte
number_envs_parallel: 64
learning_rate: 0.0007
......
env_id: "overcooked"
policy: "MlpPolicy"
model_type: "DQN"
model_name: "DQN"
model_type:
_partial_: true
_target_: stable_baselines3.DQN
total_timesteps: 3_000_000 # hendric sagt eher so 300_000_000 schritte
number_envs_parallel: 64
learning_rate: 0.0001
......
env_id: "overcooked"
policy: "MlpPolicy"
model_type: "PPO"
model_name: "PPO"
model_type:
_partial_: true
_target_: stable_baselines3.PPO
total_timesteps: 3_000_000 # hendric sagt eher so 300_000_000 schritte
number_envs_parallel: 64
learning_rate: 0.0003
......
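
_partial_: true in the model configs above tells instantiate to return a functools.partial around the target class instead of constructing it immediately, so the environment and the remaining hyperparameters can be filled in later (as the training script does below). A hedged sketch with a stand-in environment, assuming a gymnasium-compatible stable-baselines3:

from hydra.utils import instantiate
from omegaconf import OmegaConf
import gymnasium as gym   # stand-in env only; the project wraps its own EnvGymWrapper

model_cfg = OmegaConf.create({"_partial_": True, "_target_": "stable_baselines3.PPO"})
model_class = instantiate(model_cfg)            # functools.partial(stable_baselines3.PPO)

env = gym.make("CartPole-v1")                   # placeholder for EnvGymWrapper(cfg)
model = model_class(policy="MlpPolicy", env=env, learning_rate=0.0003)
model.learn(total_timesteps=1_000)
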
@@ -20,6 +20,7 @@ from cooperative_cuisine.environment import (
    Environment,
)
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer
+from hydra.utils import instantiate
class SimpleActionSpace(Enum):
@@ -104,6 +105,11 @@ def shuffle_counters(env):
class StateToObservationConverter:
    '''
    '''
+    @abstractmethod
+    def setup(self, env):
+        ...
@@ -113,13 +119,6 @@ class StateToObservationConverter:
        ...
-def get_converter(converter_name):
-    module_path = f"cooperative_cuisine.reinforcement_learning.obs_converter.{converter_name}"
-    module = importlib.import_module(module_path)
-    converter_class = getattr(module, "StateConverter")
-    return converter_class()
class EnvGymWrapper(Env):
    """Should enable this:
    observation, reward, terminated, truncated, info = env.step(action)
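
Since EnvGymWrapper exposes the standard Gym step API quoted in the docstring above, a typical interaction loop looks like this (sketch only; cfg is a composed Hydra config as in the scripts below, and the gymnasium-style reset return is assumed):

env = EnvGymWrapper(cfg)
obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    action = env.action_space.sample()                       # random policy, for illustration
    obs, reward, terminated, truncated, info = env.step(action)
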
@@ -135,12 +134,10 @@
        self.full_vector_state = True
        config_env = OmegaConf.to_container(config.environment, resolve=True)
        config_item_info = OmegaConf.to_container(config.item_info, resolve=True)
-        order_generator = config.additional_configs.order_generator.order_generator_type
-        order_file = order_generator + "_orders.yaml"
-        custom_config_path = ROOT_DIR / "reinforcement_learning" / "config" / order_file
+        order_generator = config.additional_configs.order_generator
+        custom_config_path = ROOT_DIR / "reinforcement_learning" / "config" / order_generator
        with open(custom_config_path, "r") as file:
-            custom_classes = file.read()
-        custom_classes = yaml.load(custom_classes, Loader=yaml.Loader)
+            custom_classes = yaml.load(file, Loader=yaml.Loader)
        for key, value in config_env['hook_callbacks'].items():
            value['callback_class'] = custom_classes['callback_class']
        config_env["orders"]["order_gen_class"] = custom_classes['order_gen_class']
@@ -174,16 +171,12 @@
        self.action_space = spaces.Discrete(len(self.action_space_map))
        self.seen_items = []
-        self.converter = get_converter(config.additional_configs.state_converter)
+        self.converter = instantiate(config.additional_configs.state_converter)
+        self.converter.setup(self.env)
-        try:
+        if hasattr(self.converter, "onehot"):
            self.onehot_state = self.converter.onehot
-        except AttributeError:
-            if 'onehot' in config.additional_configs.state_converter.lower():
-                self.onehot_state = True
-            else:
-                self.onehot_state = False
+        else:
+            self.onehot_state = 'onehot' in config.additional_configs.state_converter.lower()
        dummy_obs = self.get_observation()
        min_obs_val = -1 if not self.use_rgb_obs else 0
        max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
......
@@ -7,7 +7,7 @@ from cooperative_cuisine.items import CookingEquipment
from cooperative_cuisine.reinforcement_learning.gym_env import StateToObservationConverter
-class StateConverter(StateToObservationConverter):
+class BaseStateConverter(StateToObservationConverter):
    def __init__(self):
        self.onehot = False
        self.counter_list = [
......
@@ -7,7 +7,7 @@ from cooperative_cuisine.items import CookingEquipment
from cooperative_cuisine.reinforcement_learning.gym_env import StateToObservationConverter
-class StateConverter(StateToObservationConverter):
+class BaseStateConverterOnehot(StateToObservationConverter):
    def __init__(self):
        self.onehot = True
        self.grid_height = None
......
import time
import cv2
-from stable_baselines3 import DQN
+from stable_baselines3 import DQN, A2C, PPO
from gym_env import EnvGymWrapper
import hydra
from omegaconf import DictConfig, OmegaConf
+from hydra.utils import instantiate
@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig):
+    additional_config = OmegaConf.to_container(cfg.additional_configs, resolve=True)
-    model_save_path = "logs/reinforcement_learning/rl_agent_checkpoints/overcooked_DQN.zip"
-    model_class = DQN
+    model_save_path = additional_config["log_path"] + "/" + additional_config["checkpoint_path"] + "/" + additional_config[
+        "project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"]
+    model_class = instantiate(cfg.model.model_type)
    model = model_class.load(model_save_path)
    env = EnvGymWrapper(cfg)
......
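
One detail about the loading path above: with _partial_: true, instantiate(cfg.model.model_type) returns a functools.partial object rather than the algorithm class itself, and partial objects do not forward attribute access, so the load classmethod lives on the wrapped class. A hedged sketch of reaching it from inside main(cfg); the checkpoint path is illustrative:

from hydra.utils import instantiate

model_class = instantiate(cfg.model.model_type)      # e.g. functools.partial(stable_baselines3.A2C)
algo = getattr(model_class, "func", model_class)     # unwrap the partial if one was returned
model = algo.load("logs/reinforcement_learning/rl_agent_checkpoints/overcooked_A2C")  # illustrative path
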
@@ -12,48 +12,49 @@ from wandb.integration.sb3 import WandbCallback
from gym_env import EnvGymWrapper
import hydra
+from hydra.utils import instantiate
@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig):
-    rl_logs = Path("logs/reinforcement_learning")
+    additional_configs = OmegaConf.to_container(cfg.additional_configs, resolve=True)
+    rl_logs = Path(additional_configs["log_path"])
    rl_logs.mkdir(exist_ok=True)
-    rl_agent_checkpoints = Path("logs/reinforcement_learning/rl_agent_checkpoints")
+    rl_agent_checkpoints = rl_logs / Path(additional_configs["checkpoint_path"])
    rl_agent_checkpoints.mkdir(exist_ok=True)
    config = OmegaConf.to_container(cfg.model, resolve=True)
-    debug = False
-    vec_env = True
-    models = {"A2C": A2C, "DQN": DQN, "PPO": PPO}
+    debug = additional_configs["debug_mode"]
+    vec_env = additional_configs["vec_env"]
    number_envs_parallel = config["number_envs_parallel"]
-    model_class = models[config["model_type"]]
+    model_class = instantiate(cfg.model.model_type)
    if vec_env:
        env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
    else:
        env = EnvGymWrapper(cfg)
-    env.render_mode = "rgb_array"
+    env.render_mode = additional_configs["render_mode"]
    if not debug:
        # also upload the environment config to W&B and all stable baselines3 hyperparams
        run = wandb.init(
-            project="overcooked",
+            project=additional_configs["project_name"],
            config=config,
-            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
-            monitor_gym=True,
-            dir="logs/reinforcement_learning"
+            sync_tensorboard=additional_configs["sync_tensorboard"],  # auto-upload sb3's tensorboard metrics
+            monitor_gym=additional_configs["monitor_gym"],
+            dir=additional_configs["log_path"]
            # save_code=True,  # optional
        )
        env = VecVideoRecorder(
            env,
-            f"logs/reinforcement_learning/videos/{run.id}",
-            record_video_trigger=lambda x: x % 200_000 == 0,
-            video_length=300,
+            additional_configs["video_save_path"] + run.id,
+            record_video_trigger=lambda x: x % additional_configs["record_video_trigger"] == 0,
+            video_length=additional_configs["video_length"],
        )
-    model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"
+    model_save_name = additional_configs["project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"]
+    model_save_path = rl_agent_checkpoints / model_save_name
    filtered_config = {k: v for k, v in config.items() if
-                       k not in ["env_id", "policy_type", "model_type", "total_timesteps", "number_envs_parallel"] and v != 'None'}
+                       k not in ["env_id", "policy_type", "model_name", "model_type", "total_timesteps", "number_envs_parallel"] and v != 'None'}
    model = model_class(
        env=env,
        **filtered_config
@@ -62,18 +63,19 @@ def main(cfg: DictConfig):
        model.learn(
            total_timesteps=config["total_timesteps"],
            log_interval=1,
-            progress_bar=True,
+            progress_bar=additional_configs["progress_bar"],
        )
    else:
        checkpoint_callback = CheckpointCallback(
-            save_freq=50_000,
-            save_path="logs",
-            name_prefix="rl_model",
-            save_replay_buffer=True,
-            save_vecnormalize=True,
+            save_freq=additional_configs["save_freq"],
+            save_path=additional_configs["save_path_callback"],
+            name_prefix=additional_configs["name_prefix_callback"],
+            save_replay_buffer=additional_configs["save_replay_buffer"],
+            save_vecnormalize=additional_configs["save_vecnormalize"],
        )
        wandb_callback = WandbCallback(
-            model_save_path=f"logs/reinforcement_learning/models/{run.id}",
+            model_save_path=additional_configs["video_save_path"] + run.id,
            verbose=0,
        )
@@ -82,7 +84,7 @@ def main(cfg: DictConfig):
            total_timesteps=config["total_timesteps"],
            callback=callback,
            log_interval=1,
-            progress_bar=True,
+            progress_bar=additional_configs["progress_bar"],
        )
        run.finish()
    model.save(model_save_path)
......
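
For selecting one of the model configs without the @hydra.main entry point, Hydra's compose API can be used. A short sketch; the "model=PPO" override assumes the config group and option names match the A2C/DQN/PPO files shown earlier:

from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(version_base="1.3", config_path="config"):
    cfg = compose(config_name="rl_config", overrides=["model=PPO"])  # group/option names assumed

model_class = instantiate(cfg.model.model_type)   # functools.partial(stable_baselines3.PPO)
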