Skip to content
Snippets Groups Projects
Commit ce3e444f authored by Christoph Kowalski's avatar Christoph Kowalski
Browse files

Inference uses the same configs as the training and instantiates the model...

Inference uses the same configs as the training and instantiates the model load function from the configs.
parent ffd5586f
No related branches found
No related tags found
3 merge requests!110V1.2.0 changes,!102Fixed caching of recipe layouts. Ids were used in hash, which are generated...,!98Resolve "Restructure Reinforcement Learning files"
Pipeline #57770 passed
order_generator: "random_orders.yaml" order_generator: "random_orders.yaml"
# Here the filename of the converter should be given. The converter class needs to be called StateConverter and implement the abstract StateToObservationConverter class # Here the filename of the converter should be given. The converter class needs to be called StateConverter and implement the abstract StateToObservationConverter class
state_converter: state_converter:
_target_: "cooperative_cuisine.reinforcement_learning.obs_converter.base_converter.BaseStateConverter" _target_: "cooperative_cuisine.reinforcement_learning.obs_converter.base_converter_onehot.BaseStateConverterOnehot"
log_path: "logs/reinforcement_learning" log_path: "logs/reinforcement_learning"
checkpoint_path: "rl_agent_checkpoints" checkpoint_path: "rl_agent_checkpoints"
render_mode: "rgb_array" render_mode: "rgb_array"
......
...@@ -4,6 +4,9 @@ model_name: "A2C" ...@@ -4,6 +4,9 @@ model_name: "A2C"
model_type: model_type:
_partial_: true _partial_: true
_target_: stable_baselines3.A2C _target_: stable_baselines3.A2C
model_type_inference:
_partial_: true
_target_: stable_baselines3.A2C.load
total_timesteps: 3_000_000 # hendric sagt eher so 300_000_000 schritte total_timesteps: 3_000_000 # hendric sagt eher so 300_000_000 schritte
number_envs_parallel: 64 number_envs_parallel: 64
learning_rate: 0.0007 learning_rate: 0.0007
......
...@@ -4,6 +4,9 @@ model_name: "DQN" ...@@ -4,6 +4,9 @@ model_name: "DQN"
model_type: model_type:
_partial_: true _partial_: true
_target_: stable_baselines3.DQN _target_: stable_baselines3.DQN
model_type_inference:
_partial_: true
_target_: stable_baselines3.DQN.load
total_timesteps: 3_000_000 # hendric sagt eher so 300_000_000 schritte total_timesteps: 3_000_000 # hendric sagt eher so 300_000_000 schritte
number_envs_parallel: 64 number_envs_parallel: 64
learning_rate: 0.0001 learning_rate: 0.0001
......
...@@ -4,6 +4,9 @@ model_name: "PPO" ...@@ -4,6 +4,9 @@ model_name: "PPO"
model_type: model_type:
_partial_: true _partial_: true
_target_: stable_baselines3.PPO _target_: stable_baselines3.PPO
model_type_inference:
_partial_: true
_target_: stable_baselines3.PPO.load
total_timesteps: 3_000_000 # hendric sagt eher so 300_000_000 schritte total_timesteps: 3_000_000 # hendric sagt eher so 300_000_000 schritte
number_envs_parallel: 64 number_envs_parallel: 64
learning_rate: 0.0003 learning_rate: 0.0003
......
import importlib
import json import json
import random import random
from abc import abstractmethod from abc import abstractmethod
...@@ -11,6 +10,7 @@ import cv2 ...@@ -11,6 +10,7 @@ import cv2
import numpy as np import numpy as np
import yaml import yaml
from gymnasium import spaces, Env from gymnasium import spaces, Env
from hydra.utils import instantiate
from omegaconf import OmegaConf from omegaconf import OmegaConf
from cooperative_cuisine import ROOT_DIR from cooperative_cuisine import ROOT_DIR
...@@ -20,7 +20,6 @@ from cooperative_cuisine.environment import ( ...@@ -20,7 +20,6 @@ from cooperative_cuisine.environment import (
Environment, Environment,
) )
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer
from hydra.utils import instantiate
class SimpleActionSpace(Enum): class SimpleActionSpace(Enum):
......
...@@ -6,22 +6,25 @@ from stable_baselines3 import DQN, A2C, PPO ...@@ -6,22 +6,25 @@ from stable_baselines3 import DQN, A2C, PPO
from gym_env import EnvGymWrapper from gym_env import EnvGymWrapper
import hydra import hydra
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate from hydra.utils import instantiate, call
@hydra.main(version_base="1.3", config_path="config", config_name="rl_config") @hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig): def main(cfg: DictConfig):
additional_config = OmegaConf.to_container(cfg.additional_configs, resolve=True) additional_config = OmegaConf.to_container(cfg.additional_configs, resolve=True)
model_save_path = "logs/reinforcement_learning/rl_agent_checkpoints/overcooked_DQN.zip" model_save_path = additional_config["log_path"] + "/" + additional_config["checkpoint_path"] + "/" + \
model_save_path = additional_config["log_path"] + "/" + additional_config["checkpoint_path"] +"/"+ additional_config[ additional_config["project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)[
"project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"] "model_name"]
model_class = instantiate(cfg.model.model_type) model_class = call(cfg.model.model_type_inference)
model = model_class.load(model_save_path) model = model_class(model_save_path)
env = EnvGymWrapper(cfg) env = EnvGymWrapper(cfg)
# check_env(env) #check_env(env)
obs, info = env.reset() obs, info = env.reset()
print(obs)
while True: while True:
action, _states = model.predict(obs, deterministic=False) action, _states = model.predict(obs, deterministic=False)
print(action)
obs, reward, terminated, truncated, info = env.step(int(action)) obs, reward, terminated, truncated, info = env.step(int(action))
print(reward) print(reward)
rgb_img = env.render() rgb_img = env.render()
......
...@@ -54,7 +54,7 @@ def main(cfg: DictConfig): ...@@ -54,7 +54,7 @@ def main(cfg: DictConfig):
model_save_name = additional_configs["project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"] model_save_name = additional_configs["project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"]
model_save_path = rl_agent_checkpoints / model_save_name model_save_path = rl_agent_checkpoints / model_save_name
filtered_config = {k: v for k, v in config.items() if filtered_config = {k: v for k, v in config.items() if
k not in ["env_id", "policy_type", "model_name", "model_type", "total_timesteps", "number_envs_parallel"] and v != 'None'} k not in ["env_id", "policy_type", "model_name", "model_type", "model_type_inference" ,"total_timesteps", "number_envs_parallel"] and v != 'None'}
model = model_class( model = model_class(
env=env, env=env,
**filtered_config **filtered_config
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment