# run_single_agent.py
import time
from pathlib import Path

import cv2
from stable_baselines3 import DQN, A2C, PPO

from gym_env import EnvGymWrapper
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate, call


# Key codes (as returned by ``cv2.waitKey(...) & 0xFF``) that end the rollout: "q" and Esc.
_QUIT_KEYS = (ord("q"), 27)


def _model_save_path(cfg: DictConfig) -> Path:
    """Return ``<log_path>/<checkpoint_path>/<project_name>_<model_name>`` from the config.

    All values are read with interpolations resolved, from ``cfg.additional_configs``
    (``log_path``, ``checkpoint_path``, ``project_name``) and ``cfg.model`` (``model_name``).
    """
    additional_config = OmegaConf.to_container(cfg.additional_configs, resolve=True)
    model_name = OmegaConf.to_container(cfg.model, resolve=True)["model_name"]
    return (
        Path(additional_config["log_path"])
        / additional_config["checkpoint_path"]
        / f"{additional_config['project_name']}_{model_name}"
    )


@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig):
    """Load the trained model and let the user watch it act, printing the rewards.

    The initial observation is printed once. Each step prints the sampled action and the
    reward, and shows the rendered frame in an OpenCV window named "env". The window waits
    for a key press before advancing. Pressing ``q`` or Esc ends the rollout.
    Episodes are reset automatically when they terminate or are truncated.
    """
    # ``model_type_inference`` is resolved through Hydra's ``call``; presumably it yields a
    # loader (e.g. an SB3 ``Algorithm.load``) that accepts the checkpoint path -- confirm
    # against the model config.
    model_class = call(cfg.model.model_type_inference)
    model = model_class(_model_save_path(cfg))
    env = EnvGymWrapper(cfg)
    # Pause between frames so playback is paced at the env's declared render rate.
    frame_delay = 1 / env.metadata["render_fps"]

    obs, info = env.reset()
    print(obs)
    try:
        while True:
            action, _states = model.predict(obs, deterministic=False)
            print(action)
            obs, reward, terminated, truncated, info = env.step(int(action))
            print(reward)
            rgb_img = env.render()
            # OpenCV displays arrays as BGR; convert so colours are not red/blue swapped.
            cv2.imshow("env", cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR))
            # waitKey(0) blocks until a key is pressed, so the rollout advances step by step.
            key = cv2.waitKey(0) & 0xFF
            if key in _QUIT_KEYS:
                break
            if terminated or truncated:
                obs, info = env.reset()
            time.sleep(frame_delay)
    finally:
        # Release the display window and the environment even on Ctrl+C or errors.
        cv2.destroyAllWindows()
        env.close()


if __name__ == "__main__":
    # Hydra composes the config (plus any CLI overrides) and passes it to main().
    main()