import json
import os
from typing import Any, Callable, Dict, List, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.random_search.cgan import load_params


SIZE_DEFAULT = 12

# Global matplotlib styling, applied as a side effect of importing this module.
# If the "Roboto" font is not installed, matplotlib falls back to another font.
plt.rc("font", family="Roboto", weight="normal", size=SIZE_DEFAULT)
plt.rc("axes", titlesize=18)  # font size of axes titles

class ColorList:
    """
    A palette of colors whose indexing wraps around.

    ``color_list[i]`` returns ``colors[i % len(colors)]``, so an arbitrary
    number of plot series can be colored from a short palette.

    Parameters
    ----------
    *colors
        the colors of the palette (e.g. hex strings such as ``"#1f78b4"``)

    Notes
    -----
    Plain iteration (``for c in color_list``) cycles through the colors
    endlessly, because Python iterates via ``__getitem__`` until an
    ``IndexError``. This is handy with ``zip``. Do not call ``list()`` on a
    non-empty ColorList; use ``.colors`` to get each color exactly once.
    """

    def __init__(self, *colors):
        self.colors = colors

    def __len__(self):
        return len(self.colors)

    def __getitem__(self, i):
        if not self.colors:
            # i % 0 would raise a confusing ZeroDivisionError; IndexError also
            # lets iteration over an empty ColorList terminate normally
            raise IndexError("ColorList is empty")
        return self.colors[i % len(self)]

    def __contains__(self, color):
        # Without this, `in` falls back to the wrap-around __getitem__ and
        # never terminates for colors that are not in the palette.
        return color in self.colors

    def __repr__(self):
        return f"{type(self).__name__}({', '.join(map(repr, self.colors))})"

# Named palettes (indexing a ColorList wraps around):
# - "default": ten colors, apparently ColorBrewer's "Paired" scheme in a custom order
# - "printer_friendly": six colors, apparently ColorBrewer's "Dark2" scheme
color_lists = dict(
    default=ColorList(
        "#1f78b4",
        "#33a02c",
        "#e31a1c",
        "#ff7f00",
        "#cab2d6",
        "#a6cee3",
        "#b2df8a",
        "#fb9a99",
        "#fdbf6f",
        "#6a3d9a",
    ),
    printer_friendly=ColorList(
        "#1b9e77",
        "#d95f02",
        "#7570b3",
        "#e7298a",
        "#66a61e",
        "#e6ab02",
    ),
)

# https://colorbrewer2.org/#type=qualitative&scheme=Dark2&n=5
# Five-color palette; indexing wraps around once all colors are used.
short_color_list = ColorList(*"#1b9e77 #d95f02 #7570b3 #e7298a #66a61e".split())

# https://colorbrewer2.org/#type=qualitative&scheme=Set3&n=10
# Plain list of ten colors (no wrap-around, unlike ColorList).
COLORS = """
    #8dd3c7 #fb8072 #80b1d3 #fdb462 #b3de69
    #fccde5 #d9d9d9 #bc80bd #ffffb3 #bebada
""".split()


def jitter(
    data: np.ndarray,
    amount: float = 0.1,
    rng: Union[np.random.Generator, None] = None,
) -> np.ndarray:
    """
    Jitter all values in an array by adding uniform random noise.

    This is useful to spread out values that are all displayed at the
    same x value. ``amount`` is the total width of the noise interval, so
    every value is shifted by at most ``amount / 2`` in either direction.
    Choose it relative to the spacing of the data: if the smallest change
    in x is 1, ``amount`` should be lower than 1 so neighboring groups do
    not overlap.

    Parameters
    ----------
    data: np.ndarray
        the data which is jittered; objects without a ``shape`` attribute
        (lists, scalars, ...) are converted with ``np.asarray`` first
    amount: float
        width of the uniform noise interval ``[-amount / 2, amount / 2)``
    rng: np.random.Generator, optional
        generator used to draw the noise, for reproducible jitter; if None,
        the global NumPy random state (``np.random.rand``) is used
    Returns
    -------
    np.ndarray
        ``data`` plus noise, with the same shape as ``data``
    """
    if not hasattr(data, "shape"):
        # arrays and pandas objects keep their type; everything else is
        # converted so that ``.shape`` is available
        data = np.asarray(data)

    if rng is None:
        noise = np.random.rand(*data.shape)
    else:
        noise = rng.random(data.shape)
    return data + (noise - 0.5) * amount


def load_data(
    dir_random_search: str,
    file_measures: str = "measures.csv",
    file_params: str = "params.json",
) -> Dict[int, Dict[str, Any]]:
    """
    Load the measures and parameters of all runs of a random search.

    Each run is expected in its own sub-directory of ``dir_random_search``
    whose name is the integer run id. Symbolic links, regular files and
    directories whose names are not integers are skipped.

    Parameters
    ----------
    dir_random_search: str
        directory containing one sub-directory per run
    file_measures: str
        name of the CSV file with the measures inside each run directory
    file_params: str
        name of the parameter file inside each run directory
    Returns
    -------
    Dict[int, Dict[str, Any]]
        maps the run id to a dict with the keys ``"measures"`` (DataFrame
        read from ``file_measures``), ``"params"`` (result of ``load_params``
        on ``file_params``) and ``"dir"`` (name of the run directory, not
        the full path); runs are inserted in ascending order of their id
    """
    runs = []
    for name in os.listdir(dir_random_search):
        path = os.path.join(dir_random_search, name)
        # symlinked directories are skipped (presumably aliases of other runs)
        if not os.path.isdir(path) or os.path.islink(path):
            continue
        try:
            run_id = int(name)
        except ValueError:
            continue  # not a run directory
        runs.append((run_id, name))

    data = {}
    # numeric order, so that e.g. run "2" comes before run "10"
    for run_id, name in sorted(runs):
        run_dir = os.path.join(dir_random_search, name)
        measures = pd.read_csv(os.path.join(run_dir, file_measures))
        params = load_params(os.path.join(run_dir, file_params))
        data[run_id] = {"measures": measures, "params": params, "dir": name}
    return data


def remove_outliers(
    data: Dict[int, Dict[str, Any]], file_outliers: str = "outliers.csv"
) -> Dict[int, Dict[str, Any]]:
    """
    Return a copy of ``data`` without the runs marked as outliers.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        run data keyed by run id
    file_outliers: str
        CSV file with a ``run`` column of run ids and a boolean ``outlier``
        column; runs whose ``outlier`` value is true are removed
    Returns
    -------
    Dict[int, Dict[str, Any]]
        a new dict with the remaining runs in their original order;
        ``data`` itself is not modified
    """
    outliers = pd.read_csv(file_outliers)
    # a set gives O(1) membership checks instead of scanning a list per run
    outlier_runs = set(outliers[outliers["outlier"]]["run"].tolist())
    return {run: entry for run, entry in data.items() if run not in outlier_runs}


def filter_by_params(
    data: Dict[int, Dict[str, Any]],
    value: Union[Any, Tuple[Any]],
    fields: Union[str, List[str]],
) -> Dict[int, Dict[str, Any]]:
    """
    Select the runs whose parameters equal the given value(s).

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        run data keyed by run id; every entry needs a ``"params"`` mapping
        that contains all names in ``fields``
    value: Any or tuple
        value(s) the parameter(s) must equal. Anything that is not exactly a
        ``tuple`` (including lists) is wrapped as a single value. With several
        fields, pass a tuple with one value per field, in the same order.
    fields: str or list of str
        name(s) of the parameter(s) to compare; a tuple of names is accepted
        as well
    Returns
    -------
    Dict[int, Dict[str, Any]]
        a new dict containing only the matching runs, in their original order
    """
    if type(value) is not tuple:
        value = (value,)

    if isinstance(fields, (list, tuple)):
        fields = list(fields)
    else:
        fields = [fields]

    return {
        run: entry
        for run, entry in data.items()
        if tuple(entry["params"][field] for field in fields) == value
    }


class TablePrinter:
    """
    Print a table, given as a dict of columns, to stdout with box-drawing
    characters.

    The dict keys are the column headers; the values are the column
    entries. All columns must have the same number of entries. Cells are
    right-aligned, and each column is as wide as its longest text.

    Two dicts customize the output after construction:

    - ``formatter`` maps a column header *or* a value type to a
      ``str.format`` pattern that turns a value into text. A header entry
      takes precedence over a type entry. Types are matched exactly, not
      via subclasses. By default ``float`` and ``np.float64`` values are
      printed with five decimals.
    - ``color_formatter`` maps a column header to a callable. It receives the
      already padded cell text and returns the text to print (e.g. wrapped in
      terminal color codes). Padding happens before this call, so invisible
      characters added here do not break the alignment. Header cells are not
      passed through it.
    """

    def __init__(self):
        # NOTE: despite their names, ``vert`` is the horizontal line piece and
        # ``hori`` the vertical column separator; the attribute names are kept
        # so that existing customizations keep working.
        self.vert = "─"
        self.hori = "│"

        # junction pieces where rules meet borders / column separators
        self.t_up = "┴"
        self.t_down = "┬"
        self.t_right = "┤"
        self.t_left = "├"

        self.top_left = "┌"
        self.top_right = "┐"
        self.bottom_right = "┘"
        self.bottom_left = "└"

        self.cross = "┼"

        # value -> text patterns, keyed by column header or by value type
        self.formatter = {
            float: "{:.5f}",
            np.float64: "{:.5f}",
        }
        # column header -> callable(padded_text) -> text to print
        self.color_formatter = {}

    def print(self, table: Dict[str, List[Any]]):
        """
        Print ``table`` as a box-drawn table.

        Parameters
        ----------
        table: Dict[str, List[Any]]
            maps each column header to the values of that column; columns
            are printed in the dict's order. An empty dict prints nothing.
            Columns may be empty, in which case only the header is printed.

        Raises
        ------
        ValueError
            if the columns do not all have the same number of entries
        """
        headers = list(table.keys())
        if not headers:
            return

        # Convert values to their display text first, so widths are
        # measured on exactly what gets printed.
        cells = {
            header: [self.format(value, header) for value in column]
            for header, column in table.items()
        }

        n_rows = len(cells[headers[0]])
        if any(len(column) != n_rows for column in cells.values()):
            lengths = {header: len(column) for header, column in cells.items()}
            raise ValueError(f"all columns must have the same length, got {lengths}")

        widths = [
            max(len(str(header)), max(map(len, cells[header]), default=0))
            for header in headers
        ]

        print(self._rule(self.top_left, self.t_down, self.top_right, widths))
        print(self._row([f"{str(header):>{width}}" for header, width in zip(headers, widths)]))
        print(self._rule(self.t_left, self.cross, self.t_right, widths))

        for i in range(n_rows):
            row = [
                self.color(f"{cells[header][i]:>{width}}", header)
                for header, width in zip(headers, widths)
            ]
            print(self._row(row))

        print(self._rule(self.bottom_left, self.t_up, self.bottom_right, widths))

    def _rule(self, left: str, junction: str, right: str, widths: List[int]) -> str:
        """Build a horizontal border line spanning columns of the given widths."""
        segments = (self.vert * width for width in widths)
        inner = f"{self.vert}{junction}{self.vert}".join(segments)
        return left + self.vert + inner + self.vert + right

    def _row(self, texts: List[str]) -> str:
        """Join already padded cell texts into one bordered table line."""
        return f"{self.hori} " + f" {self.hori} ".join(texts) + f" {self.hori}"

    def format(self, value: Any, header: str):
        """
        Convert ``value`` of column ``header`` to its display text.

        A pattern registered for the header wins over one registered for
        the exact type of the value; otherwise ``str(value)`` is used.
        """
        if header in self.formatter:
            return self.formatter[header].format(value)

        if type(value) in self.formatter:
            return self.formatter[type(value)].format(value)

        return str(value)

    def color(self, value_str: str, header: str):
        """Apply the column's ``color_formatter`` to ``value_str``, if one is set."""
        if header in self.color_formatter:
            return self.color_formatter[header](value_str)
        return value_str

if __name__ == "__main__":
    # Command-line entry point: load all runs of a random search directory.
    import argparse

    parser = argparse.ArgumentParser()
    # directory that contains one sub-directory per run (see load_data)
    parser.add_argument("random_search_dir", type=str)
    args = parser.parse_args()

    # NOTE(review): `data` is not used in the visible lines; presumably the
    # analysis continues past this excerpt — confirm.
    data = load_data(args.random_search_dir)