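# NOTE: the next four lines are web-page navigation residue that was copied
# along with the source; they are kept as comments so the module parses.
# Skip to content
# Snippets Groups Projects
# util.py 6.16 KiB
# Newer Older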
import json
import os
from typing import Any, Callable, Dict, List, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.random_search.cgan import load_params


# Default text size used for all figures created after this module is imported.
SIZE_DEFAULT = 12

# Global matplotlib defaults (applied as an import side effect):
# font family, font weight and default text size in a single call ...
plt.rc("font", family="Roboto", weight="normal", size=SIZE_DEFAULT)
# ... and a larger font size for axes titles.
plt.rc("axes", titlesize=18)

class ColorList:
    """
    A fixed collection of colors with wrap-around indexing.

    Indexing never runs past the end of the collection: the index
    is taken modulo the number of colors, so `colors[len(colors)]`
    is the first color again.

    Parameters
    ----------
    *colors
        the colors in the order in which they should be returned
    """

    def __init__(self, *colors):
        # the variadic arguments arrive as a tuple and are stored as such
        self.colors = colors

    def __len__(self):
        """
        Get the number of colors stored in this list.
        """
        return len(self.colors)

    def __getitem__(self, i):
        """
        Get the color at index `i`, wrapping around at the end of the list.
        """
        wrapped = i % len(self)
        return self.colors[wrapped]

# Named color palettes; the key is the palette name, the value the
# ColorList (which wraps around when indexed past its last color).
color_lists = dict(
    default=ColorList(
        "#1f78b4", "#33a02c", "#e31a1c", "#ff7f00", "#cab2d6",
        "#a6cee3", "#b2df8a", "#fb9a99", "#fdbf6f", "#6a3d9a",
    ),
    printer_friendly=ColorList(
        "#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e", "#e6ab02",
    ),
)

# Five-color palette, taken from
# https://colorbrewer2.org/#type=qualitative&scheme=Dark2&n=5
short_color_list = ColorList(
    "#1b9e77", "#d95f02", "#7570b3", "#e7298a", "#66a61e",
)

# Plain list of ten colors, taken from
# https://colorbrewer2.org/#type=qualitative&scheme=Set3&n=10
COLORS = [
    "#8dd3c7", "#fb8072", "#80b1d3", "#fdb462", "#b3de69",
    "#fccde5", "#d9d9d9", "#bc80bd", "#ffffb3", "#bebada",
]


def jitter(data: np.ndarray, amount: float = 0.1) -> np.ndarray:
    """
    Jitter all values in an array.

    This is useful to scatter values which are all displayed for
    the same x value. The amount should be chosen in relation to
    the values in the data. For example, if the smallest change
    in x is 1 the amount should be lower than this.

    Parameters
    ----------
    data: np.ndarray
        the data which is jittered; other array-likes (e.g. a list
        of numbers) are accepted as well
    amount: float
        the total width of the jitter interval: every value is shifted
        by an offset drawn uniformly from [-amount / 2, amount / 2)

    Returns
    -------
    np.ndarray
        the jittered data with the same shape as `data`
    """
    # np.shape works for arrays and plain sequences alike, whereas the
    # `.shape` attribute only exists on array types.
    noise = (np.random.rand(*np.shape(data)) - 0.5) * amount
    return data + noise


def load_data(
    dir_random_search: str,
    file_measures: str = "measures.csv",
    file_params: str = "params.json",
) -> Dict[int, Dict[str, Any]]:
    """
    Load the measures and parameters of all runs of a random search.

    Every real sub-directory (symbolic links are skipped) of the random
    search directory is treated as one run. The name of the directory is
    converted to an integer and used as the key of the run.

    Parameters
    ----------
    dir_random_search: str
        the directory containing one sub-directory per run
    file_measures: str
        name of the CSV file with the measures inside of a run directory
    file_params: str
        name of the parameter file inside of a run directory

    Returns
    -------
    Dict[int, Dict[str, Any]]
        maps the run number to a dict with the keys `measures` (the
        data frame read from the CSV file), `params` (the result of
        `load_params`) and `dir` (the name of the run directory)

    Raises
    ------
    ValueError
        if the name of a run directory cannot be converted to an integer
    """
    data = {}
    for name in sorted(os.listdir(dir_random_search)):
        path_run = os.path.join(dir_random_search, name)
        # only real directories are runs: skip plain files and symlinks
        if not os.path.isdir(path_run) or os.path.islink(path_run):
            continue

        measures = pd.read_csv(os.path.join(path_run, file_measures))
        params = load_params(os.path.join(path_run, file_params))
        data[int(name)] = {"measures": measures, "params": params, "dir": name}
    return data


def remove_outliers(
    data: Dict[int, Dict[str, Any]], file_outliers: str = "outliers.csv"
):
    """
    Drop all runs which are marked as outliers in a CSV file.

    The CSV file needs the columns `run` and `outlier`. The `outlier`
    column is used as a boolean mask to select the rows whose `run`
    value is removed from the data.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        the data of all runs, keyed by the run number
    file_outliers: str
        path of the CSV file listing the outliers

    Returns
    -------
    Dict[int, Dict[str, Any]]
        a new dict without the runs marked as outliers
    """
    table = pd.read_csv(file_outliers)
    flagged = table[table["outlier"]]
    flagged_runs = list(flagged["run"])
    return {run: entry for run, entry in data.items() if run not in flagged_runs}


def filter_by_params(
    data: Dict[int, Dict[str, Any]],
    value: Union[Any, Tuple[Any, ...]],
    fields: Union[str, List[str], Tuple[str, ...]],
) -> Dict[int, Dict[str, Any]]:
    """
    Keep only the runs whose parameters match the given value(s).

    For every run, the parameters named by `fields` are looked up in
    the `params` entry of the run and compared, in order, to `value`.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        the data of all runs; every entry needs a `params` mapping
    value: Union[Any, Tuple[Any, ...]]
        the expected value, or a tuple with one expected value per field
        (a tuple is always interpreted as one value per field)
    fields: Union[str, List[str], Tuple[str, ...]]
        the name of a single parameter or a list/tuple of parameter names

    Returns
    -------
    Dict[int, Dict[str, Any]]
        a new dict containing only the matching runs

    Raises
    ------
    KeyError
        if a run does not have one of the requested parameters
    """
    if not isinstance(value, tuple):
        value = (value,)

    # a single field name is wrapped, lists and tuples of names are used as is
    if not isinstance(fields, (list, tuple)):
        fields = [fields]

    return {
        run: entry
        for run, entry in data.items()
        if tuple(entry["params"][field] for field in fields) == value
    }


class TablePrinter():
    """
    Print a table, given as a mapping from header to column, to stdout.

    All cells are formatted to strings, right-aligned to the width of
    their column and framed with the glyphs stored on the instance.

    Attributes
    ----------
    formatter: dict
        format strings which are looked up by the header of a column
        first and by the exact type of a value second
    color_formatter: dict
        maps a header to a callable which receives the padded cell
        string and returns the string which is printed
    """

    def __init__(self):
        # NOTE(review): all frame glyphs below are empty strings, so the
        # table is printed without a visible frame. Presumably these were
        # box-drawing characters which got lost -- TODO confirm.

        # `vert` is the glyph repeated to draw the rules above and below
        # the header and below the rows; `hori` separates the columns
        self.vert = ""
        self.hori = ""

        # junctions of the bottom (`t_up`) and top (`t_down`) rule and the
        # right (`t_right`) and left (`t_left`) end of the middle rule
        self.t_up = ""
        self.t_down = ""
        self.t_right = ""
        self.t_left = ""

        # corners of the frame
        self.top_left = ""
        self.top_right = ""
        self.bottom_right = ""
        self.bottom_left = ""

        # junction of the middle rule
        self.cross = ""

        self.formatter = {
            float: "{:.5f}",
            np.float64: "{:.5f}",
        }
        self.color_formatter = {}

    def print(self, table: Dict[str, List[Any]]):
        """
        Print the table to stdout.

        The number of rows is taken from the first column. Tables
        without columns and columns without values are printed as
        an empty frame.

        Parameters
        ----------
        table: Dict[str, List[Any]]
            maps the header of each column to the values of the column
        """
        headers = list(table.keys())

        cells = {
            header: [self.format(value, header) for value in column]
            for header, column in table.items()
        }
        # `default=0` keeps the header width for columns without values
        widths = {
            header: max(len(header), max(map(len, column), default=0))
            for header, column in cells.items()
        }
        column_widths = [widths[header] for header in headers]

        print(self._rule(self.top_left, self.t_down, self.top_right, column_widths))
        print(self._row(f"{header:>{widths[header]}}" for header in headers))
        print(self._rule(self.t_left, self.cross, self.t_right, column_widths))

        n_rows = len(cells[headers[0]]) if headers else 0
        for i in range(n_rows):
            print(
                self._row(
                    self.color(f"{cells[header][i]:>{widths[header]}}", header)
                    for header in headers
                )
            )

        print(
            self._rule(
                self.bottom_left, self.t_up, self.bottom_right, column_widths
            )
        )

    def _rule(self, left: str, junction: str, right: str, widths: List[int]) -> str:
        """
        Build a rule which spans all columns and their one-glyph padding.
        """
        segments = (self.vert * width for width in widths)
        inner = f"{self.vert}{junction}{self.vert}".join(segments)
        return left + self.vert + inner + self.vert + right

    def _row(self, cells) -> str:
        """
        Build a line of already padded cells separated by the column glyph.
        """
        return self.hori + " " + f" {self.hori} ".join(cells) + " " + self.hori

    def format(self, value: Any, header: str):
        """
        Convert a value to a string.

        A format string registered for the header takes precedence over
        a format string registered for the exact type of the value. If
        neither exists, `str` is used.

        Parameters
        ----------
        value: Any
            the value to be formatted
        header: str
            the header of the column the value belongs to

        Returns
        -------
        str
        """
        if header in self.formatter:
            return self.formatter[header].format(value)

        if type(value) in self.formatter:
            return self.formatter[type(value)].format(value)

        return str(value)

    def color(self, value_str: str, header: str):
        """
        Apply the color formatter registered for the header, if there is one.

        Parameters
        ----------
        value_str: str
            the formatted and padded cell
        header: str
            the header of the column the cell belongs to

        Returns
        -------
        str
        """
        if header in self.color_formatter:
            return self.color_formatter[header](value_str)
        return value_str

if __name__ == "__main__":
    import argparse

    # Script entry point: takes the random search directory as the only
    # (positional) argument and loads all of its runs.
    cli = argparse.ArgumentParser()
    cli.add_argument("random_search_dir", type=str)
    random_search_dir = cli.parse_args().random_search_dir

    data = load_data(random_search_dir)