import time

import cv2
from stable_baselines3 import DQN

from gym_env import EnvGymWrapper

model_save_path = "logs/reinforcement_learning/rl_agent_checkpoints/overcooked_DQN.zip"
model_class = DQN

# Key codes (as returned by cv2.waitKey, masked to 8 bits) that stop the viewer: "q" and Esc.
_QUIT_KEYS = {ord("q"), 27}


def main(deterministic: bool = False) -> None:
    """Roll out the saved DQN policy in ``EnvGymWrapper`` and show each frame.

    Loads the checkpoint at ``model_save_path`` with ``model_class``. It then
    repeatedly predicts an action, steps the environment, prints the step
    reward and displays ``env.render()`` in an OpenCV window named "env". The
    environment is reset whenever an episode terminates or is truncated.
    Press ``q`` or Esc with the window focused to stop. The window is
    destroyed and the environment closed on exit.

    Args:
        deterministic: Forwarded to ``model.predict``. ``False`` (the default,
            matching the previous behaviour) lets the policy sample actions.
            ``True`` requests the deterministic action.
    """
    model = model_class.load(model_save_path)
    env = EnvGymWrapper()

    # Pace playback to the env's declared render rate. If the key is missing,
    # run as fast as possible instead of raising KeyError.
    fps = env.metadata.get("render_fps")
    frame_delay = 1 / fps if fps else 0.0

    try:
        obs, info = env.reset()
        while True:
            action, _states = model.predict(obs, deterministic=deterministic)
            obs, reward, terminated, truncated, info = env.step(int(action))
            print(reward)

            rgb_img = env.render()
            if rgb_img is not None:
                # cv2.imshow interprets 3-channel images as BGR. The frame is
                # assumed to be RGB (Gymnasium "rgb_array" convention, and the
                # variable name) -- TODO confirm against EnvGymWrapper.render.
                cv2.imshow("env", cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR))

            # waitKey(1) pumps GUI events without blocking. The previous
            # waitKey(0) waited for a keypress on every frame, which made the
            # fps-based sleep below meaningless.
            if (cv2.waitKey(1) & 0xFF) in _QUIT_KEYS:
                break

            if terminated or truncated:
                obs, info = env.reset()

            if frame_delay:
                time.sleep(frame_delay)
    finally:
        cv2.destroyAllWindows()
        env.close()


if __name__ == "__main__":
    main()