import json
import random
import time
import uuid
from collections import deque
from datetime import timedelta
from enum import Enum
from pathlib import Path

import cv2
import numpy as np
import wandb
import yaml
from gymnasium import spaces, Env
from stable_baselines3 import A2C
from stable_baselines3 import DQN
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecVideoRecorder
from wandb.integration.sb3 import WandbCallback

from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.counters import Counter, CookingCounter, Dispenser
from cooperative_cuisine.environment import (
    Environment,
)
from cooperative_cuisine.items import CookingEquipment
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer


class SimpleActionSpace(Enum):
    """Enumeration of actions for simple action spaces for an RL agent."""

    Up = "Up"
    Down = "Down"
    Left = "Left"
    Right = "Right"
    Interact = "Interact"
    Put = "Put"


def get_env_action(player_id, simple_action, duration):
    """Translate a SimpleActionSpace member into an engine-level Action."""
    match simple_action:
        case SimpleActionSpace.Up:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([0, -1]),
                duration,
            )
        case SimpleActionSpace.Left:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([-1, 0]),
                duration,
            )
        case SimpleActionSpace.Down:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([0, 1]),
                duration,
            )
        case SimpleActionSpace.Right:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([1, 0]),
                duration,
            )
        case SimpleActionSpace.Put:
            return Action(
                player_id,
                ActionType.PICK_UP_DROP,
                InterActionData.START,
                duration,
            )
        case SimpleActionSpace.Interact:
            return Action(
                player_id,
                ActionType.INTERACT,
                InterActionData.START,
                duration,
            )
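
# Usage sketch (illustrative, mirroring EnvGymWrapper.step below): a policy
# emits a discrete index, which is mapped to an enum member and then to an
# engine-level Action:
#
#   simple_action = list(SimpleActionSpace)[2]  # SimpleActionSpace.Left
#   action = get_env_action("0", simple_action, duration=1)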


environment_config_path = (
    ROOT_DIR / "reinforcement_learning" / "environment_config_rl.yaml"
)
layout_path: Path = ROOT_DIR / "reinforcement_learning" / "rl_small.layout"
item_info_path = ROOT_DIR / "reinforcement_learning" / "item_info_rl.yaml"
item_info = item_info_path.read_text()
layout = layout_path.read_text()
environment_config = environment_config_path.read_text()
with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r") as file:
    visualization_config = yaml.safe_load(file)

visualizer: Visualizer = Visualizer(config=visualization_config)
visualizer.create_player_colors(1)
visualizer.set_grid_size(40)
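# Note: this module-level Visualizer is shared by every EnvGymWrapper instance;
# its per-environment surface cache is keyed by env_name, which is why reset()
# evicts the old entry before building a fresh Environment.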


def shuffle_counters(env):
    """Randomly permute the positions of all special (non-plain) counters in place."""
    movable_counters = []
    static_counters = []
    for counter in env.counters:
        if counter.__class__ != Counter:
            movable_counters.append(counter)
        else:
            static_counters.append(counter)
    new_counter_pos = [c.pos for c in movable_counters]
    random.shuffle(new_counter_pos)
    for counter, new_pos in zip(movable_counters, new_counter_pos):
        counter.pos = new_pos
    env.overwrite_counters(movable_counters + static_counters)
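
# Usage sketch (hypothetical): randomize counter placement on a fresh
# environment before players are added, as EnvGymWrapper.__init__ and reset()
# do when `randomize_counter_placement` is enabled:
#
#   env = Environment(env_config=environment_config, layout_config=layout,
#                     item_info=item_info, as_files=False,
#                     env_name=uuid.uuid4().hex)
#   shuffle_counters(env)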


class EnvGymWrapper(Env):
    """Gymnasium wrapper around the cooperative_cuisine Environment.

    Exposes the standard Gymnasium API:
    observation, reward, terminated, truncated, info = env.step(action)
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

    def __init__(self):
        super().__init__()

        self.randomize_counter_placement = False
        self.use_rgb_obs = True  # if False, uses the simple vectorized state
        self.full_vector_state = True
        self.onehot_state = True

        self.env: Environment = Environment(
            env_config=environment_config,
            layout_config=layout,
            item_info=item_info,
            as_files=False,
            env_name=uuid.uuid4().hex,
        )

        if self.randomize_counter_placement:
            shuffle_counters(self.env)

        self.player_name = str(0)
        self.env.add_player(self.player_name)
        self.player_id = list(self.env.players.keys())[0]

        self.visualizer = visualizer

        self.action_space_map = dict(enumerate(SimpleActionSpace))
        self.global_step_time = 1
        self.in_between_steps = 1

        self.action_space = spaces.Discrete(len(self.action_space_map))

        dummy_obs = self.get_observation()
        min_obs_val = 0 if self.use_rgb_obs else -1
        max_obs_val = 255 if self.use_rgb_obs else (1 if self.onehot_state else 20)
        self.observation_space = spaces.Box(
            low=min_obs_val,
            high=max_obs_val,
            shape=dummy_obs.shape,
            dtype=np.uint8 if self.use_rgb_obs else np.int64,
        )

        self.last_obs = dummy_obs

        self.step_counter = 0
        self.prev_score = 0

    def vectorize_item(self, item, item_list):
        """One-hot encode an item (or the front of an item deque) against item_list."""
        item_one_hot = np.zeros(len(item_list))
        if item is None:
            item_name = "None"
        elif isinstance(item, deque):
            if len(item) > 0:
                item_name = item[0].name
            else:
                item_name = "None"
        else:
            item_name = item.name

        if isinstance(item, CookingEquipment):
            if item.name == "Pot":
                if len(item.content_list) > 0:
                    if item.content_list[0].name == "TomatoSoup":
                        item_name = "PotDone"
                    elif len(item.content_list) == 1:
                        item_name = "PotOne"
                    elif len(item.content_list) == 2:
                        item_name = "PotTwo"
                    elif len(item.content_list) == 3:
                        item_name = "PotThree"
            if "Plate" in item.name:
                content_list = [i.name for i in item.content_list]
                match content_list:
                    case ["TomatoSoup"]:
                        item_name = "PlateTomatoSoup"
                    case ["ChoppedTomato"]:
                        item_name = "PlateChoppedTomato"
                    case ["ChoppedLettuce"]:
                        item_name = "PlateChoppedLettuce"
                    case []:
                        item_name = "Plate"
                    case ["ChoppedLettuce", "ChoppedTomato"]:
                        item_name = "PlateSalad"
                    case _:
                        raise ValueError(f"Unexpected plate content: {item}")
        assert item_name in item_list, f"Unknown item {item_name}."
        item_idx = item_list.index(item_name)
        item_one_hot[item_idx] = 1

        return item_one_hot, item_idx

    @staticmethod
    def vectorize_counter(counter, counter_list):
        """One-hot encode a counter against counter_list."""
        counter_name = (
            counter.name
            if isinstance(counter, CookingCounter)
            else (
                repr(counter)
                if isinstance(counter, Dispenser)
                else counter.__class__.__name__
            )
        )
        if counter_name == "Dispenser":
            counter_name = f"{counter.occupied_by.name}Dispenser"
        assert counter_name in counter_list, f"Unknown counter {counter}"

        counter_oh_idx = counter_list.index(counter_name)
        counter_one_hot = np.zeros(len(counter_list), dtype=int)
        counter_one_hot[counter_oh_idx] = 1
        return counter_one_hot, counter_oh_idx

    def get_vectorized_state_simple(self, player, onehot=True):
        """Build a flat vector observation for `player`: per-cell counter
        encodings, per-cell item encodings, the player's position and facing
        direction, and the player's held item."""
        counter_list = [
            "Empty",
            "Counter",
            "PlateDispenser",
            "TomatoDispenser",
            "ServingWindow",
            "PlateReturn",
            "Trashcan",
            "Stove",
            "CuttingBoard",
            "LettuceDispenser",
        ]

        item_list = [
            "None",
            "Pot",
            "PotOne",
            "PotTwo",
            "PotThree",
            "PotDone",
            "Tomato",
            "ChoppedTomato",
            "Plate",
            "PlateTomatoSoup",
            "PlateSalad",
            "Lettuce",
            "PlateChoppedTomato",
            "PlateChoppedLettuce",
            "ChoppedLettuce",
        ]

        grid_width = int(self.env.kitchen_width)
        grid_height = int(self.env.kitchen_height)

        grid_base_array = np.zeros((grid_width, grid_height), dtype=int)
        grid_idxs = [(x, y) for x in range(grid_width) for y in range(grid_height)]

        if onehot:
            item_one_hot_length = len(item_list)
            counter_items = np.zeros(
                (grid_width, grid_height, item_one_hot_length), dtype=int
            )
            counter_one_hot_length = len(counter_list)
            counters = np.zeros(
                (grid_width, grid_height, counter_one_hot_length), dtype=int
            )
        else:
            counter_items = np.zeros((grid_width, grid_height), dtype=int)
            counters = np.zeros((grid_width, grid_height), dtype=int)

        for counter in self.env.counters:
            x, y = np.floor(counter.pos).astype(int)
            grid_idx = (int(x), int(y))

            counter_one_hot, counter_oh_idx = self.vectorize_counter(
                counter, counter_list
            )
            grid_base_array[grid_idx] = counter_oh_idx
            grid_idxs.remove(grid_idx)

            counter_item_one_hot, counter_item_oh_idx = self.vectorize_item(
                counter.occupied_by, item_list
            )
            # Index with a tuple so a single cell is addressed; indexing with a
            # 1-D integer array would select whole rows instead.
            counter_items[grid_idx] = (
                counter_item_one_hot if onehot else counter_item_oh_idx
            )
            counters[grid_idx] = counter_one_hot if onehot else counter_oh_idx

        for free_idx in grid_idxs:
            grid_base_array[free_idx[0], free_idx[1]] = counter_list.index("Empty")

        player_pos = self.env.players[player].pos.astype(int)
        player_dir = self.env.players[player].facing_direction.astype(int)
        player_data = np.concatenate((player_pos, player_dir), axis=0)

        player_item_one_hot, player_item_idx = self.vectorize_item(
            self.env.players[player].holding, item_list
        )
        player_item = player_item_one_hot if onehot else [player_item_idx]

        final = np.concatenate(
            (
                counters.flatten(),
                counter_items.flatten(),
                player_data.flatten(),
                player_item,
            ),
            axis=0,
        )
        return final
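
    # Layout of the vector above (onehot=True): width*height counter one-hots,
    # then width*height item one-hots, then the player's (x, y) position and
    # facing direction, then the held-item one-hot. With onehot=False the
    # per-cell one-hots collapse to single class indices.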

    def step(self, action):
        simple_action = self.action_space_map[action]
        env_action = get_env_action(
            self.player_id, simple_action, self.global_step_time
        )
        self.env.perform_action(env_action)

        # Advance the engine in sub-steps so time-based processes (cooking,
        # cutting) progress between agent decisions.
        for _ in range(self.in_between_steps):
            self.env.step(
                timedelta(seconds=self.global_step_time / self.in_between_steps)
            )

        observation = self.get_observation()

        reward = self.env.score - self.prev_score
        self.prev_score = self.env.score

        if reward > 0.6:
            print("- - - - - - - - - - - - - - - - SCORED", reward)

        terminated = self.env.game_ended
        truncated = self.env.game_ended
        info = {}

        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)
        # Evict the shared visualizer's cache entry for the old environment
        # instance before replacing it.
        self.visualizer.surface_cache_dict.pop(self.env.env_name, None)
        self.env: Environment = Environment(
            env_config=environment_config,
            layout_config=layout,
            item_info=item_info,
            as_files=False,
            env_name=uuid.uuid4().hex,
        )

        if self.randomize_counter_placement:
            shuffle_counters(self.env)

        self.player_name = str(0)
        self.env.add_player(self.player_name)
        self.player_id = list(self.env.players.keys())[0]

        info = {}
        obs = self.get_observation()

        self.prev_score = 0

        return obs, info

    def get_observation(self):
        if self.use_rgb_obs:
            obs = self.get_env_img()
        else:
            obs = self.get_vector_state()
        return obs

    def render(self):
        return self.get_env_img()

    def close(self):
        pass

    def get_env_img(self):
        state = self.env.get_json_state(player_id=self.player_id)
        json_dict = json.loads(state)
        observation = self.visualizer.get_state_image(
            state=json_dict, env_id_ref=self.env.env_name
        ).astype(np.uint8)
        return observation

    def get_vector_state(self):
        return self.get_vectorized_state_simple(self.player_name, self.onehot_state)

    def sample_random_action(self):
        return self.action_space.sample()
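

# A minimal smoke-test sketch for EnvGymWrapper (illustrative; not called
# anywhere in this module): roll the environment forward with random actions.
def _demo_random_rollout(num_steps=10):
    """Step the wrapper with uniformly random actions as a quick sanity check."""
    env = EnvGymWrapper()
    obs, info = env.reset()
    for _ in range(num_steps):
        obs, reward, terminated, truncated, info = env.step(
            env.sample_random_action()
        )
        if terminated or truncated:
            obs, info = env.reset()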


def main():
    rl_agent_checkpoints = Path("rl_agent_checkpoints")
    rl_agent_checkpoints.mkdir(exist_ok=True)

    config = {
        "policy_type": "MlpPolicy",
        "total_timesteps": 3_000_000,  # hendric sagt eher so 300_000_000 schritte
        "env_id": "overcooked",
        "number_envs_parallel": 4,
    }

    debug = True
    do_training = True
    vec_env = True
    number_envs_parallel = config["number_envs_parallel"]

    model_classes = [A2C, DQN, PPO]
    model_class = model_classes[1]  # DQN

    if vec_env:
        env = make_vec_env(EnvGymWrapper, n_envs=number_envs_parallel)
    else:
        env = EnvGymWrapper()

    env.render_mode = "rgb_array"

    if not debug:
        run = wandb.init(
            project="overcooked",
            config=config,
            sync_tensorboard=True,  # auto-upload sb3's tensorboard metrics
            monitor_gym=True,
            # save_code=True,  # optional
        )

        env = VecVideoRecorder(
            env,
            f"videos/{run.id}",
            record_video_trigger=lambda x: x % 200_000 == 0,
            video_length=300,
        )

    model_save_path = rl_agent_checkpoints / f"overcooked_{model_class.__name__}"

    if do_training:
        model = model_class(
            config["policy_type"],
            env,
            verbose=1,
            tensorboard_log="runs/0",
            device="cpu",
            # n_steps=2048,
            # n_epochs=10,
        )
        if debug:
            model.learn(
                total_timesteps=config["total_timesteps"],
                log_interval=1,
                progress_bar=True,
            )
        else:
            checkpoint_callback = CheckpointCallback(
                save_freq=50_000,
                save_path="logs",
                name_prefix="rl_model",
                save_replay_buffer=True,
                save_vecnormalize=True,
            )
            wandb_callback = WandbCallback(
                model_save_path=f"models/{run.id}",
                verbose=0,
            )

            callback = CallbackList([checkpoint_callback, wandb_callback])
            model.learn(
                total_timesteps=config["total_timesteps"],
                callback=callback,
                log_interval=1,
                progress_bar=True,
            )
            run.finish()
        model.save(model_save_path)

        del model
    print("LEARNING DONE.")

    model = model_class.load(model_save_path)
    env = EnvGymWrapper()

    check_env(env)
    obs, info = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=False)
        obs, reward, terminated, truncated, info = env.step(int(action))
        print(reward)
        rgb_img = env.render()
        cv2.imshow("env", rgb_img)
        # waitKey(1) refreshes the window without blocking on a keypress, so
        # playback runs at the intended render_fps.
        cv2.waitKey(1)
        if terminated or truncated:
            obs, info = env.reset()
        time.sleep(1 / env.metadata["render_fps"])


if __name__ == "__main__":
    main()