Skip to content
Snippets Groups Projects
util.py 9.12 KiB
Newer Older
"""
Utility functions and classes for the random search evaluation.
"""
import json
import os
from typing import Any, Callable, Dict, List, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.random_search.cgan import load_params


# NOTE(review): the value of SIZE_DEFAULT is a reconstruction - this file
# referenced it without defining it anywhere. Confirm the intended default
# text size.
SIZE_DEFAULT = 16  # default font size of plot text
SIZE_TITLE = 22  # font size of axes titles

plt.rc("font", family="Roboto")  # default font family
plt.rc("font", weight="normal")  # default font weight
plt.rc("font", size=SIZE_DEFAULT)  # default text size
plt.rc("axes", titlesize=SIZE_TITLE)  # font size of axes titles


class ColorList:
    """
    Class wrapping a list of colors so that indexing wraps around.

    Indexing with any integer ``i`` returns ``colors[i % len(colors)]``, so a
    short list of colors can be indexed with an arbitrary counter (e.g. the
    number of the plotted line).

    Because indexing a non-empty ColorList never goes out of range, iterating
    over it (``for c in color_list``) cycles through the colors forever. Bound
    such loops, e.g. with ``zip``.
    """

    def __init__(self, *colors: str):
        """
        Parameters
        ----------
        *colors: str
            the colors, e.g., hex strings such as "#1f78b4"
        """
        self.colors = colors

    def __len__(self) -> int:
        return len(self.colors)

    def __getitem__(self, i: int) -> str:
        """
        Get the color at index `i`, wrapping around the end of the list.

        Raises
        ------
        IndexError
            if the list contains no colors
        """
        if len(self) == 0:
            # Without this check the modulo below raises ZeroDivisionError.
            # An IndexError also lets iteration over an empty list stop cleanly.
            raise IndexError("ColorList is empty")
        return self.colors[i % len(self)]

    def __repr__(self) -> str:
        return f"{type(self).__name__}({', '.join(map(repr, self.colors))})"


"""
Two examples of color lists defined with the help of:
https://colorbrewer2.org
"""
# Color lists by name. The hex values appear to be ColorBrewer's qualitative
# "Paired" scheme (reordered) for "default" and its "Dark2" scheme for
# "printer_friendly".
color_lists = {
    "default": ColorList(
        "#1f78b4",
        "#33a02c",
        "#e31a1c",
        "#ff7f00",
        "#cab2d6",
        "#a6cee3",
        "#b2df8a",
        "#fb9a99",
        "#fdbf6f",
        "#6a3d9a",
    ),
    "printer_friendly": ColorList(
        "#1b9e77",
        "#d95f02",
        "#7570b3",
        "#e7298a",
        "#66a61e",
        "#e6ab02",
    ),
}


def jitter(
    data: np.ndarray,
    amount: float = 0.1,
    rng: Union[np.random.Generator, None] = None,
) -> np.ndarray:
    """
    Jitter the values in an array.

    Uniform noise from the half-open interval [-amount / 2, amount / 2) is
    added to every value. This is useful to scatter values which are all
    displayed for the same x value. The amount should be chosen in relation to
    the values in the data. For example, if the smallest change in x is 1 the
    amount should be lower than this.

    Parameters
    ----------
    data: np.ndarray
        the data which is jittered (any array-like, e.g. a list, is accepted)
    amount: float
        width of the jitter interval; each value is shifted by at most
        ``amount / 2`` in either direction
    rng: np.random.Generator, optional
        random generator used to draw the noise, e.g. for reproducible plots.
        If omitted, the global NumPy random state (``np.random.rand``) is used.

    Returns
    -------
    np.ndarray
        the jittered values, with the same shape as `data`
    """
    data = np.asarray(data)
    if rng is None:
        noise = np.random.rand(*data.shape)
    else:
        noise = rng.random(data.shape)
    # shift the noise from [0, 1) to [-0.5, 0.5) before scaling
    return data + (noise - 0.5) * amount


def load_data(
    dir_random_search: str,
    file_measures: str = "measures.csv",
    file_params: str = "params.json",
) -> Dict[int, Dict[str, Any]]:
    """
    Load results of a random search procedure into a dict.

    Every sub-directory of `dir_random_search` whose name consists only of
    decimal digits is treated as one iteration. For each iteration, the
    measures CSV file is read with `pandas.read_csv` and the params file with
    `load_params`. Iterations are inserted in ascending numeric order, so
    "2" comes before "10".

    Parameters
    ----------
    dir_random_search: str
        the directory of the random search procedure;
        it is expected that each iteration is stored in a separate numbered directory
    file_measures: str, optional
        filename of the measures CSV file in each directory
    file_params: str, optional
        filename of the params JSON file in each directory

    Returns
    -------
    Dict[int, Dict[str, Any]]
        a dict mapping the iteration number to a dict with the keys
        "measures" (the DataFrame read from the CSV file), "params" (the
        result of `load_params`), and "dir" (the name of the iteration
        directory relative to `dir_random_search`, not a full path)
    """
    run_names = [
        name
        for name in os.listdir(dir_random_search)
        # isdecimal (unlike isdigit) only accepts characters int() can parse
        if name.isdecimal() and os.path.isdir(os.path.join(dir_random_search, name))
    ]
    # Sort numerically; a plain string sort would put "10" before "2".
    # The name is the tie-breaker so names like "01" and "1" get a stable order.
    run_names.sort(key=lambda name: (int(name), name))

    data = {}
    for run_name in run_names:
        dir_run = os.path.join(dir_random_search, run_name)
        measures = pd.read_csv(os.path.join(dir_run, file_measures))
        params = load_params(os.path.join(dir_run, file_params))
        data[int(run_name)] = {"measures": measures, "params": params, "dir": run_name}
    return data


def remove_outliers(
    data: Dict[int, Dict[str, Any]], file_outliers: str = "outliers.csv"
) -> Dict[int, Dict[str, Any]]:
    """
    Remove outlier iterations from data loaded with `load_data`.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        data loaded by `load_data`; it is not modified
    file_outliers: str
        CSV file defining outliers (as generated by
        `mu_map.random_search.eval.label_outliers`); it must contain a "run"
        column with iteration numbers and a boolean "outlier" column

    Returns
    -------
    Dict[int, Dict[str, Any]]
        a new dict without the iterations marked as outliers; the order of the
        remaining entries is preserved
    """
    outliers = pd.read_csv(file_outliers)
    flagged = outliers[outliers["outlier"]]
    # a set gives O(1) membership tests instead of a linear scan per iteration
    outlier_runs = set(flagged["run"])
    return {run: entry for run, entry in data.items() if run not in outlier_runs}


def filter_by_params(
    data: Dict[int, Dict[str, Any]],
    value: Union[Any, Tuple[Any]],
    fields: Union[str, List[str]],
) -> Dict[int, Dict[str, Any]]:
    """
    Filter data loaded with `load_data` based on certain parameters.

    An iteration is kept if the values of its parameters named by `fields`
    equal `value`, i.e. ``(params[fields[0]], params[fields[1]], ...) == value``.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        data loaded by `load_data`
    value: Any or tuple of Any
        values for the parameters; a non-tuple value is wrapped into a
        1-tuple so that a single field can be matched directly
    fields: str or list of str
        field names to access the parameter dict to check if the values match;
        a tuple of field names is accepted as well

    Returns
    -------
    Dict[int, Dict[str, Any]]
        filtered data dict (the order of the remaining entries is preserved)
    """
    # exact type check on purpose: e.g. a namedtuple is matched as one value
    if type(value) is not tuple:
        value = (value,)

    if isinstance(fields, str):
        fields = [fields]
    elif isinstance(fields, tuple):
        # previously a tuple was wrapped and used as a single lookup key
        fields = list(fields)
    elif not isinstance(fields, list):
        fields = [fields]

    return {
        run: entry
        for run, entry in data.items()
        if tuple(entry["params"][field] for field in fields) == value
    }


class TablePrinter:
    """
    Print a table using UTF-8 box-drawing symbols.

    Example output::

        ┌─────────┬────┐
        │       a │ bb │
        ├─────────┼────┤
        │       1 │  x │
        │ 2.50000 │  y │
        └─────────┴────┘
    """

    def __init__(self):
        """
        Create a new table printer.

        This function initializes border symbols.
        Important to its usage are the `formatter` and the `color_formatter`
        dicts. The `formatter` dict defines how columns or certain types should
        be formatted. E.g., floats are rounded to five digits per default.
        The `color_formatter` allows to format values per column, e.g., changing
        its color.
        """
        # NOTE: the names `vert` and `hori` do not match their use. `vert` is
        # repeated to draw the horizontal border lines, and `hori` separates
        # the cells of a row. The names are kept for backwards compatibility.
        self.vert = "─"
        self.hori = "│"

        # T-pieces: `t_down`/`t_up` join columns in the top/bottom border,
        # `t_left`/`t_right` start/end the line below the header
        self.t_up = "┴"
        self.t_down = "┬"
        self.t_right = "┤"
        self.t_left = "├"

        self.top_left = "┌"
        self.top_right = "┐"
        self.bottom_right = "┘"
        self.bottom_left = "└"

        # joins columns in the line below the header
        self.cross = "┼"

        # format strings keyed by column name (checked first) or by exact value type
        self.formatter: Dict[Union[type, str], str] = {
            float: "{:.5f}",
            np.float64: "{:.5f}",
        }
        # per-column callables applied to the already padded cell string
        self.color_formatter: Dict[str, Callable[[str], str]] = {}

    def print(self, table: Dict[str, List[Any]]):
        """
        Print a table.

        Each value is converted to a string with `format`. Each column is
        right-aligned to the width of its widest cell or header. `color` is
        applied to each padded body cell. The number of rows is taken from
        the first column. Nothing is printed for a table without columns.

        Parameters
        ----------
        table: Dict[str, List[Any]]
            table in column representation - a key is the column header and the list
            its values
        """
        headers = list(table.keys())
        if not headers:
            return

        columns = {
            header: [self.format(value, header) for value in column]
            for header, column in table.items()
        }
        # the list form keeps `max` valid for empty columns
        widths = {
            header: max([len(header), *map(len, column)])
            for header, column in columns.items()
        }

        print(
            self._border(self.top_left, self.t_down, self.top_right, headers, widths)
        )
        print(self._row([f"{header:>{widths[header]}}" for header in headers]))
        print(self._border(self.t_left, self.cross, self.t_right, headers, widths))

        for i in range(len(columns[headers[0]])):
            # color is applied after padding so that e.g. escape codes added
            # by a color formatter do not affect the column width
            cells = [
                self.color(f"{columns[header][i]:>{widths[header]}}", header)
                for header in headers
            ]
            print(self._row(cells))

        print(
            self._border(
                self.bottom_left, self.t_up, self.bottom_right, headers, widths
            )
        )

    def _border(
        self,
        left: str,
        join: str,
        right: str,
        headers: List[str],
        widths: Dict[str, int],
    ) -> str:
        """Build a horizontal border line from end symbols and a column junction."""
        inner = f"{self.vert}{join}{self.vert}".join(
            self.vert * widths[header] for header in headers
        )
        return left + self.vert + inner + self.vert + right

    def _row(self, cells: List[str]) -> str:
        """Join already padded cells into one table row framed by separators."""
        return self.hori + " " + f" {self.hori} ".join(cells) + " " + self.hori

    def format(self, value: Any, header: str) -> str:
        """
        Internal method to use the internal formatter dict.

        A format string registered for the column name takes precedence over
        one registered for the exact type of the value. Without either, the
        value is converted with `str`.

        Parameters
        ----------
        value: Any
            the value to format
        header: str
            the column name of the value

        Returns
        -------
        str
            the formatted value
        """
        if header in self.formatter:
            return self.formatter[header].format(value)

        if type(value) in self.formatter:
            return self.formatter[type(value)].format(value)

        return str(value)

    def color(self, value_str: str, header: str) -> str:
        """
        Internal method to use the internal color_formatter dict.

        Parameters
        ----------
        value_str: str
            the (already formatted and padded) value
        header: str
            the column name of the value

        Returns
        -------
        str
            the value passed through the column's color formatter, or
            unchanged if the column has none
        """
        if header in self.color_formatter:
            return self.color_formatter[header](value_str)
        return value_str