from pathlib import Path

import wandb
from omegaconf import DictConfig, OmegaConf
from stable_baselines3 import A2C
from stable_baselines3 import DQN
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

from gym_env import EnvGymWrapper
import hydra


@hydra.main(version_base="1.3", config_path="config", config_name="rl_config")
def main(cfg: DictConfig):
    """Train a Stable-Baselines3 agent on the Overcooked gym wrapper.

    The ``cfg.model`` section selects:
    - the algorithm (``model_type``: "A2C", "DQN" or "PPO");
    - the policy (``policy_type``);
    - the number of parallel environments (``number_envs_parallel``);
    - the training budget (``total_timesteps``).

    Unless ``debug`` is set, the run is tracked in Weights & Biases, rollout
    videos are recorded and periodic checkpoints are written. The final model
    is saved to ``logs/reinforcement_learning/rl_agent_checkpoints``.

    Args:
        cfg: Hydra configuration, supplied by the ``hydra.main`` decorator.

    Raises:
        ValueError: if ``model_type`` is not one of the supported algorithms.
    """
    rl_logs = Path("logs/reinforcement_learning")
    rl_agent_checkpoints = rl_logs / "rl_agent_checkpoints"
    # parents=True: also create ``logs/`` and ``logs/reinforcement_learning``
    # on a fresh checkout instead of raising FileNotFoundError.
    rl_agent_checkpoints.mkdir(parents=True, exist_ok=True)

    config = OmegaConf.to_container(cfg.model, resolve=True)
    # Hard-coded run switches.
    debug = False  # True: no W&B run, no video recording, no callbacks
    do_training = True
    vec_env = True

    # TODO: consider Hydra instantiate here to avoid hard-coding the classes.
    models = {"A2C": A2C, "DQN": DQN, "PPO": PPO}
    number_envs_parallel = config["number_envs_parallel"]
    model_type = config["model_type"]
    if model_type not in models:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {sorted(models)}"
        )
    model_class = models[model_type]

    if vec_env:
        env = make_vec_env(lambda: EnvGymWrapper(cfg), n_envs=number_envs_parallel)
        n_envs = number_envs_parallel
    else:
        env = EnvGymWrapper(cfg)
        n_envs = 1

    # NOTE(review): presumably needed so VecVideoRecorder can grab RGB frames
    # (recent SB3 versions check render_mode) — confirm against installed SB3.
    env.render_mode = "rgb_array"

    run = None
    try:
        if not debug:
            # Also upload the model config / SB3 hyperparameters to W&B.
            run = wandb.init(
                project="overcooked",
                config=config,
                sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
                monitor_gym=True,
                dir=str(rl_logs),
            )
            # The trigger receives the vec-env step counter, so with n_envs
            # parallel envs a video starts every 200_000 * n_envs timesteps.
            env = VecVideoRecorder(
                env,
                str(rl_logs / "videos" / run.id),
                record_video_trigger=lambda step: step % 200_000 == 0,
                video_length=300,
            )

        if do_training:
            # Group tensorboard logs per W&B run; there is no run id in debug mode.
            run_name = run.id if run is not None else "debug"
            model = model_class(
                config["policy_type"],
                env,
                verbose=1,
                tensorboard_log=str(rl_logs / "runs" / run_name),
                device="cpu",
            )

            if debug:
                model.learn(
                    total_timesteps=config["total_timesteps"],
                    log_interval=1,
                    progress_bar=True,
                )
            else:
                checkpoint_callback = CheckpointCallback(
                    # save_freq counts vec-env steps (one per n_envs timesteps);
                    # divide so checkpoints land every ~50_000 total timesteps.
                    save_freq=max(50_000 // n_envs, 1),
                    save_path=str(rl_agent_checkpoints),
                    name_prefix="rl_model",
                    save_replay_buffer=True,
                    save_vecnormalize=True,
                )
                wandb_callback = WandbCallback(
                    model_save_path=str(rl_logs / "models" / run.id),
                    verbose=0,
                )
                model.learn(
                    total_timesteps=config["total_timesteps"],
                    callback=CallbackList([checkpoint_callback, wandb_callback]),
                    log_interval=1,
                    progress_bar=True,
                )

            model.save(rl_agent_checkpoints / f"overcooked_{model_class.__name__}")
    finally:
        # Always release the envs (this also finalizes any in-progress video)
        # and close the W&B run, even if training raised.
        env.close()
        if run is not None:
            run.finish()

    print("LEARNING DONE.")


if __name__ == "__main__":
    # Called without arguments: the @hydra.main decorator on ``main`` builds
    # ``cfg`` from the "config"/"rl_config" Hydra config plus any command-line
    # overrides.
    main()