import json
import random
from abc import ABC, abstractmethod
import uuid
from copy import deepcopy
from datetime import timedelta
from enum import Enum
from pathlib import Path

import numpy as np
import yaml
from gymnasium import spaces, Env
from hydra.utils import instantiate
from omegaconf import OmegaConf

from cooperative_cuisine import ROOT_DIR
from cooperative_cuisine.action import ActionType, InterActionData, Action
from cooperative_cuisine.counters import Counter
from cooperative_cuisine.environment import Environment
from cooperative_cuisine.pygame_2d_vis.drawing import Visualizer


class SimpleActionSpace(Enum):
    """Enumeration of actions for simple action spaces for an RL agent."""

    Up = "Up"
    Down = "Down"
    Left = "Left"
    Right = "Right"
    Interact = "Interact"
    Put = "Put"
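

# With dict(enumerate(SimpleActionSpace)) (see EnvGymWrapper.__init__), the
# discrete action indices map as:
#   0 = Up, 1 = Down, 2 = Left, 3 = Right, 4 = Interact, 5 = Put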


def get_env_action(player_id, simple_action, duration):
    """Translates a SimpleActionSpace member into a concrete environment Action.

    Args:
        player_id: id of the player performing the action.
        simple_action: an action in the form of a SimpleActionSpace member.
        duration: for how long the action should be conducted.

    Returns:
        A concrete Action that the environment can perform.
    """

    match simple_action:
        case SimpleActionSpace.Up:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([0, -1]),
                duration,
            )

        case SimpleActionSpace.Left:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([-1, 0]),
                duration,
            )
        case SimpleActionSpace.Down:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([0, 1]),
                duration,
            )
        case SimpleActionSpace.Right:
            return Action(
                player_id,
                ActionType.MOVEMENT,
                np.array([1, 0]),
                duration,
            )
        case SimpleActionSpace.Put:
            return Action(
                player_id,
                ActionType.PICK_UP_DROP,
                InterActionData.START,
                duration,
            )
        case SimpleActionSpace.Interact:
            return Action(
                player_id,
                ActionType.INTERACT,
                InterActionData.START,
                duration,
            )
        case _:
            raise ValueError(f"Unknown simple action: {simple_action}")
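
# Example usage (a sketch; "player-0" is a hypothetical player id): the
# discrete action Up becomes a movement Action with a unit direction vector:
#
#     get_env_action("player-0", SimpleActionSpace.Up, 1)
#     # -> Action("player-0", ActionType.MOVEMENT, np.array([0, -1]), 1)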


# A single module-level Visualizer is shared by all wrapper instances and is
# configured once for one player and a fixed grid size.
with open(ROOT_DIR / "pygame_2d_vis" / "visualization.yaml", "r") as file:
    visualization_config = yaml.safe_load(file)

visualizer: Visualizer = Visualizer(config=visualization_config)
visualizer.create_player_colors(1)
visualizer.set_grid_size(40)


def shuffle_counters(env):
    """Randomly permutes the positions of all special (non-base) counters in an environment.

    Plain `Counter` instances keep their positions; all other counter types
    (e.g., dispensers, stoves) swap positions among themselves.

    Args:
        env: the environment object whose counters are shuffled.
    """
    special_counters = []
    base_counters = []
    for counter in env.counters:
        if counter.__class__ != Counter:
            special_counters.append(counter)
        else:
            base_counters.append(counter)
    new_counter_pos = [c.pos for c in special_counters]
    random.shuffle(new_counter_pos)
    for counter, new_pos in zip(special_counters, new_counter_pos):
        counter.pos = new_pos
    special_counters.extend(base_counters)
    env.overwrite_counters(special_counters)


class StateToObservationConverter(ABC):
    """Abstract interface for classes that convert an environment into an RL state representation."""

    @abstractmethod
    def setup(self, env, item_info):
        """Prepares the converter for a concrete environment and its item info config."""
        ...

    @abstractmethod
    def convert_state_to_observation(self, state) -> np.ndarray:
        """Converts the current environment state into an observation array."""
        ...
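

# A minimal, hypothetical example of this interface (not used by the wrapper
# below): it encodes only the player positions. The attribute access
# (env.players, player.pos) matches how this module uses the Environment; a
# real converter would also encode counters, held items, and order state.
class PlayerPositionConverter(StateToObservationConverter):
    """Toy converter: the observation is the flattened array of player positions."""

    def setup(self, env, item_info):
        # No precomputation needed for this toy encoding.
        self.num_players = len(env.players)

    def convert_state_to_observation(self, state) -> np.ndarray:
        positions = [player.pos for player in state.players.values()]
        return np.asarray(positions).flatten()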


class EnvGymWrapper(Env):
    """Gymnasium wrapper around the cooperative_cuisine Environment.

    Enables the standard gymnasium interface:
    observation, reward, terminated, truncated, info = env.step(action)
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 10}

    def __init__(self, config):
        """
        Initializes all necessary variables.

        Args:
            config: the rl and environment configuration provided by hydra.
        """
        super().__init__()

        self.randomize_counter_placement = False
        self.use_rgb_obs = False  # if False, a simple vectorized state is used
        self.full_vector_state = True
        config_env = OmegaConf.to_container(config.environment, resolve=True)
        config_item_info = OmegaConf.to_container(config.item_info, resolve=True)
        # Instantiate the hook callbacks and the order generator from their
        # Hydra specifications.
        for callback_name in config_env["hook_callbacks"]:
            config_env["hook_callbacks"][callback_name]["callback_class"] = instantiate(
                config_env["hook_callbacks"][callback_name]["callback_class"]
            )
        config_env["orders"]["order_gen_class"] = instantiate(
            config_env["orders"]["order_generator"]
        )
        self.config_env = config_env
        self.config_item_info = config_item_info
        layout_file = config_env["layout_name"]
        layout_path: Path = ROOT_DIR / layout_file
        with open(layout_path, "r") as file:
            self.layout = file.read()
        self.env: Environment = Environment(
            env_config=deepcopy(config_env),
            layout_config=self.layout,
            item_info=deepcopy(config_item_info),
            as_files=False,
            yaml_already_loaded=True,
            env_name=uuid.uuid4().hex,
        )

        if self.randomize_counter_placement:
            shuffle_counters(self.env)

        self.player_name = str(0)
        self.env.add_player(self.player_name)
        self.player_id = list(self.env.players.keys())[0]

        self.visualizer = visualizer

        # Maps discrete gymnasium actions (0..5) to SimpleActionSpace members.
        self.action_space_map = dict(enumerate(SimpleActionSpace))
        # Duration of one gym step in environment seconds, and how many
        # internal environment steps are simulated per gym step.
        self.global_step_time = 1
        self.in_between_steps = 1

        self.action_space = spaces.Discrete(len(self.action_space_map))

        self.seen_items = []
        self.converter = instantiate(config.additional_configs.state_converter)
        # The converter setup could also use the item info config in order to
        # get all the possible items.
        self.converter.setup(self.env, self.config_item_info)
        # Determine whether the converter emits one-hot encoded observations,
        # either via an explicit attribute or from its configured class path.
        if hasattr(self.converter, "onehot"):
            self.onehot_state = self.converter.onehot
        else:
            self.onehot_state = (
                "onehot" in str(config.additional_configs.state_converter).lower()
            )
        dummy_obs = self.get_observation()
        # RGB observations are uint8 images in [0, 255]; vector observations
        # are integers in [-1, 1] for one-hot encodings and [-1, 20] otherwise.
        min_obs_val = 0 if self.use_rgb_obs else -1
        max_obs_val = 255 if self.use_rgb_obs else 1 if self.onehot_state else 20
        self.observation_space = spaces.Box(
            low=min_obs_val,
            high=max_obs_val,
            shape=dummy_obs.shape,
            dtype=np.uint8 if self.use_rgb_obs else int,
        )

        self.last_obs = dummy_obs
        self.step_counter = 0
        self.prev_score = 0

    def step(self, action):
        """
        Takes one step in the environment and returns the observation, the reward,
        whether the episode terminated or was truncated, and additional info.
        """
        # Work-around to support a no-op action (outside the regular action
        # space), which is necessary for play_gym.py.
        if action == 8:
            observation = self.get_observation()
            reward = self.env.score - self.prev_score
            terminated = self.env.game_ended
            truncated = self.env.game_ended
            info = {}
            return observation, reward, terminated, truncated, info
        simple_action = self.action_space_map[action]
        env_action = get_env_action(
            self.player_id, simple_action, self.global_step_time
        )
        self.env.perform_action(env_action)

        # Advance the environment in smaller increments so that interactions
        # progress smoothly within one gym step.
        for _ in range(self.in_between_steps):
            self.env.step(
                timedelta(seconds=self.global_step_time / self.in_between_steps)
            )

        observation = self.get_observation()

        reward = self.env.score - self.prev_score
        self.prev_score = self.env.score

        # Debug output when a larger reward is received, e.g. for a served meal.
        if reward > 0.6:
            print("- - - - - - - - - - - - - - - - SCORED", reward)

        terminated = self.env.game_ended
        truncated = self.env.game_ended
        info = {}

        return observation, reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """
        Resets the environment according to the configs.
        """
        super().reset(seed=seed)
        # Drop the cached rendering surface of the previous environment instance.
        if self.env.env_name in visualizer.surface_cache_dict:
            del visualizer.surface_cache_dict[self.env.env_name]
        self.env: Environment = Environment(
            env_config=deepcopy(self.config_env),
            layout_config=self.layout,
            item_info=deepcopy(self.config_item_info),
            as_files=False,
            env_name=uuid.uuid4().hex,
            yaml_already_loaded=True,
        )

        if self.randomize_counter_placement:
            shuffle_counters(self.env)

        self.player_name = str(0)
        self.env.add_player(self.player_name)
        self.player_id = list(self.env.players.keys())[0]

        info = {}
        obs = self.get_observation()

        self.prev_score = 0

        return obs, info

    def get_observation(self):
        """Returns the current observation, either as an RGB image or as a vector state."""
        if self.use_rgb_obs:
            obs = self.get_env_img()
        else:
            obs = self.get_vector_state()
        return obs

    def render(self):
        """Returns an RGB image of the current state."""
        return self.get_env_img()

    def close(self):
        pass

    def get_env_img(self):
        """Renders the current state as an RGB image array."""
        state = self.env.get_json_state(player_id=self.player_id)
        json_dict = json.loads(state)
        observation = self.visualizer.get_state_image(
            state=json_dict, env_id_ref=self.env.env_name
        ).astype(np.uint8)
        return observation

    def get_vector_state(self):
        """Converts the current state into a vector observation via the configured converter."""
        obs = self.converter.convert_state_to_observation(self.env)
        return obs

    def sample_random_action(self):
        """Samples a random action from the action space."""
        act = self.action_space.sample()
        return act
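

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the wrapper itself. It assumes a
    # Hydra config providing the groups used above (environment, item_info,
    # additional_configs); the config_path and config_name here are
    # hypothetical placeholders.
    from hydra import compose, initialize

    with initialize(version_base=None, config_path="config"):
        cfg = compose(config_name="rl_config")
    env = EnvGymWrapper(cfg)
    obs, info = env.reset()
    for _ in range(10):
        obs, reward, terminated, truncated, info = env.step(
            env.sample_random_action()
        )
        if terminated or truncated:
            break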