Commit ce3e444f authored by Christoph Kowalski

Inference uses the same configs as the training and instantiates the model load function from the configs.
parent ffd5586f
3 merge requests: !110 (V1.2.0 changes), !102 (Fixed caching of recipe layouts. Ids were used in hash, which are generated...), !98 (Resolve "Restructure Reinforcement Learning files")
Pipeline #57770 passed
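
The diff below wires inference to the same Hydra model configs that training uses: each model config gains a model_type_inference entry whose _target_ points at the matching stable_baselines3 load function, and the inference script resolves it with hydra.utils.call. A minimal sketch of that mechanism, assuming hydra-core and stable-baselines3 are installed (the config values are copied from the diff; the commented-out checkpoint path is only an example):

from hydra.utils import call
from omegaconf import OmegaConf

# Mirrors the model_type_inference block added in the configs below.
cfg = OmegaConf.create({
    "model_type_inference": {
        "_partial_": True,
        "_target_": "stable_baselines3.A2C.load",
    }
})

# Because of _partial_: true, call() does not invoke the target; it returns
# functools.partial(A2C.load), i.e. the load function itself.
load_fn = call(cfg.model_type_inference)

# Calling the partial with a checkpoint path is then equivalent to A2C.load(path):
# model = load_fn("logs/reinforcement_learning/rl_agent_checkpoints/overcooked_A2C")

This keeps training and inference driven by a single model config, which is what the commit message describes.
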
order_generator: "random_orders.yaml"
# Here the filename of the converter should be given. The converter class needs to be called StateConverter and implement the abstract StateToObservationConverter class
state_converter:
-  _target_: "cooperative_cuisine.reinforcement_learning.obs_converter.base_converter.BaseStateConverter"
+  _target_: "cooperative_cuisine.reinforcement_learning.obs_converter.base_converter_onehot.BaseStateConverterOnehot"
log_path: "logs/reinforcement_learning"
checkpoint_path: "rl_agent_checkpoints"
render_mode: "rgb_array"
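
The state_converter comment above fixes the plug-in contract: whatever class the _target_ names must implement the abstract StateToObservationConverter interface. A hypothetical sketch of such a drop-in converter (the method name, signature and vector size are assumptions for illustration, not taken from the repository):

import numpy as np


class StateConverter:
    # In the repository this would subclass the abstract StateToObservationConverter;
    # the method below is an assumed stand-in for whatever that interface requires.

    def convert(self, state: dict) -> np.ndarray:  # assumed method name
        # Encode counters, items and player positions into one numeric vector;
        # a one-hot layout would mirror BaseStateConverterOnehot.
        return np.zeros(16, dtype=np.float32)
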
@@ -4,6 +4,9 @@ model_name: "A2C"
model_type:
  _partial_: true
  _target_: stable_baselines3.A2C
+model_type_inference:
+  _partial_: true
+  _target_: stable_baselines3.A2C.load
total_timesteps: 3_000_000 # hendric says more like 300_000_000 steps
number_envs_parallel: 64
learning_rate: 0.0007
@@ -4,6 +4,9 @@ model_name: "DQN"
model_type:
  _partial_: true
  _target_: stable_baselines3.DQN
+model_type_inference:
+  _partial_: true
+  _target_: stable_baselines3.DQN.load
total_timesteps: 3_000_000 # hendric says more like 300_000_000 steps
number_envs_parallel: 64
learning_rate: 0.0001
@@ -4,6 +4,9 @@ model_name: "PPO"
model_type:
  _partial_: true
  _target_: stable_baselines3.PPO
+model_type_inference:
+  _partial_: true
+  _target_: stable_baselines3.PPO.load
total_timesteps: 3_000_000 # hendric says more like 300_000_000 steps
number_envs_parallel: 64
learning_rate: 0.0003
import importlib
import json
import random
from abc import abstractmethod
@@ -11,6 +10,7 @@ import cv2
import numpy as np
import yaml
from gymnasium import spaces, Env
+from hydra.utils import instantiate
from omegaconf import OmegaConf
from cooperative_cuisine import ROOT_DIR
@@ -20,7 +20,6 @@ from cooperative_cuisine.environment (
Environment,
)
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer
-from hydra.utils import instantiate
class SimpleActionSpace(Enum):
@@ -6,22 +6,25 @@ from stable_baselines3 import DQN, A2C, PPO
from gym_env import EnvGymWrapper
import hydra
from omegaconf import DictConfig, OmegaConf
-from hydra.utils import instantiate
+from hydra.utils import instantiate, call
@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig):
    additional_config = OmegaConf.to_container(cfg.additional_configs, resolve=True)
-    model_save_path = "logs/reinforcement_learning/rl_agent_checkpoints/overcooked_DQN.zip"
-    model_save_path = additional_config["log_path"] + "/" + additional_config["checkpoint_path"] +"/"+ additional_config[
-        "project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"]
-    model_class = instantiate(cfg.model.model_type)
-    model = model_class.load(model_save_path)
+    model_save_path = additional_config["log_path"] + "/" + additional_config["checkpoint_path"] + "/" + \
+                      additional_config["project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)[
+                          "model_name"]
+    model_class = call(cfg.model.model_type_inference)
+    model = model_class(model_save_path)
    env = EnvGymWrapper(cfg)
-    # check_env(env)
+    #check_env(env)
    obs, info = env.reset()
    print(obs)
    while True:
        action, _states = model.predict(obs, deterministic=False)
        print(action)
        obs, reward, terminated, truncated, info = env.step(int(action))
        print(reward)
        rgb_img = env.render()
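
Since render_mode is set to rgb_array in the config above, env.render() in this loop returns a plain RGB numpy array rather than opening a window. One possible way to display those frames with OpenCV, which the wrapper module already imports (this helper is not part of the repository; the window name is arbitrary):

import cv2
import numpy as np


def show_frame(rgb_img: np.ndarray) -> None:
    # OpenCV expects BGR channel order, so convert the RGB frame before showing it.
    cv2.imshow("cooperative_cuisine", cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR))
    cv2.waitKey(1)  # short wait so the window actually refreshes
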
@@ -54,7 +54,7 @@ def main(cfg: DictConfig):
    model_save_name = additional_configs["project_name"] + "_" + OmegaConf.to_container(cfg.model, resolve=True)["model_name"]
    model_save_path = rl_agent_checkpoints / model_save_name
    filtered_config = {k: v for k, v in config.items() if
-                       k not in ["env_id", "policy_type", "model_name", "model_type", "total_timesteps", "number_envs_parallel"] and v != 'None'}
+                       k not in ["env_id", "policy_type", "model_name", "model_type", "model_type_inference", "total_timesteps", "number_envs_parallel"] and v != 'None'}
    model = model_class(
        env=env,
        **filtered_config
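
The new model_type_inference key also has to join the exclusion list here, because every key left in filtered_config is forwarded as a constructor kwarg to the stable-baselines3 model, which would reject Hydra wiring entries. A plain-dict sketch of that filtering (the policy_type and env_id values are illustrative assumptions; the numbers mirror the PPO config above):

config = {
    "model_name": "PPO",
    "model_type": {"_partial_": True, "_target_": "stable_baselines3.PPO"},
    "model_type_inference": {"_partial_": True, "_target_": "stable_baselines3.PPO.load"},
    "total_timesteps": 3_000_000,
    "number_envs_parallel": 64,
    "learning_rate": 0.0003,
    "policy_type": "MlpPolicy",   # assumed value
    "env_id": "overcooked",       # assumed value
}

excluded = {"env_id", "policy_type", "model_name", "model_type",
            "model_type_inference", "total_timesteps", "number_envs_parallel"}

# Only genuine hyperparameters survive and get passed to the model constructor.
filtered_config = {k: v for k, v in config.items() if k not in excluded and v != "None"}
print(filtered_config)  # {'learning_rate': 0.0003}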