Skip to content
Snippets Groups Projects
util.py 6.16 KiB
Newer Older
  • Learn to ignore specific revisions
  • import json
    import os
    from typing import Any, Callable, Dict, List, Union, Tuple
    
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    
    from mu_map.random_search.cgan import load_params
    
    
    # Global matplotlib text defaults, applied as a side effect of importing this module.
    SIZE_DEFAULT = 12
    # font face, font weight and base text size are all keys of the "font" rc group
    plt.rc("font", family="Roboto", weight="normal", size=SIZE_DEFAULT)
    # font size of the axes titles
    plt.rc("axes", titlesize=18)
    
    class ColorList:
        """
        Fixed sequence of colors with wrap-around indexing.

        Indexing never runs out of range: the index is taken modulo the
        number of colors, so `colors[len(colors)]` is the first color again.
        """

        def __init__(self, *colors):
            # the positional arguments are kept as a tuple
            self.colors = colors

        def __len__(self):
            """Return the number of stored colors."""
            return len(self.colors)

        def __getitem__(self, i):
            """Return the color at position `i` modulo the number of colors."""
            wrapped = i % len(self)
            return self.colors[wrapped]
    
    # Named color palettes. Each palette is a ColorList, so indexing wraps
    # around when more colors are requested than the palette contains.
    color_lists = {
        # NOTE(review): these look like the ColorBrewer "Paired" colors in a
        # different order - confirm the intended source.
        "default": ColorList(
            "#1f78b4",
            "#33a02c",
            "#e31a1c",
            "#ff7f00",
            "#cab2d6",
            "#a6cee3",
            "#b2df8a",
            "#fb9a99",
            "#fdbf6f",
            "#6a3d9a",
        ),
        # the five colors of short_color_list followed by "#e6ab02"
        "printer_friendly": ColorList(
            "#1b9e77",
            "#d95f02",
            "#7570b3",
            "#e7298a",
            "#66a61e",
            "#e6ab02",
        )
    }
    
    # Five-color qualitative palette with wrap-around indexing (ColorList).
    # Source: https://colorbrewer2.org/#type=qualitative&scheme=Dark2&n=5
    short_color_list = ColorList(
        "#1b9e77",
        "#d95f02",
        "#7570b3",
        "#e7298a",
        "#66a61e",
    )
    
    # Ten-color qualitative palette as a plain list (no wrap-around indexing,
    # unlike the ColorList palettes in this module).
    # Source: https://colorbrewer2.org/#type=qualitative&scheme=Set3&n=10
    COLORS = [
        "#8dd3c7",
        "#fb8072",
        "#80b1d3",
        "#fdb462",
        "#b3de69",
        "#fccde5",
        "#d9d9d9",
        "#bc80bd",
        "#ffffb3",
        "#bebada",
    ]
    
    
    def jitter(data: np.ndarray, amount: float = 0.1) -> np.ndarray:
        """
        Add a small random offset to every value of an array.

        This spreads out points that would otherwise be drawn on top of
        each other because they share the same x value. Choose the amount
        relative to the spacing of the data: if neighboring x values are
        1 apart, the amount should stay below 1.

        Parameters
        ----------
        data: np.ndarray
            the values to jitter; the input array is not modified
        amount: float
            width of the jitter interval; each offset is drawn uniformly
            from [-amount / 2, amount / 2)

        Returns
        -------
        np.ndarray
            a new array of the same shape as data with the offsets added
        """
        # uniform samples in [0, 1), one per element of data
        samples = np.random.rand(*data.shape)
        # center around zero, then scale to the requested interval width
        offsets = (samples - 0.5) * amount
        return data + offsets
    
    
    def load_data(
        dir_random_search: str,
        file_measures: str = "measures.csv",
        file_params: str = "params.json",
    ) -> Dict[int, Dict[str, Any]]:
        """
        Load the measures and parameters of every run in a directory.

        Every real sub-directory of `dir_random_search` is treated as one
        run; plain files and symbolic links are skipped. The name of a run
        directory is converted with `int` and used as the key of the run.

        Parameters
        ----------
        dir_random_search: str
            directory which contains one sub-directory per run
        file_measures: str
            name of the CSV file with the measures inside a run directory
        file_params: str
            name of the parameter file inside a run directory, read with
            `load_params`

        Returns
        -------
        Dict[int, Dict[str, Any]]
            maps the run number to a dict with the keys "measures" (the
            DataFrame read from the CSV file), "params" (the result of
            `load_params`) and "dir" (the name of the run directory)
        """
        data = {}
        # sorted by name, i.e. as strings and not numerically
        for name in sorted(os.listdir(dir_random_search)):
            path_run = os.path.join(dir_random_search, name)
            # only real directories are runs; links to directories are ignored
            if not os.path.isdir(path_run) or os.path.islink(path_run):
                continue

            measures = pd.read_csv(os.path.join(path_run, file_measures))
            params = load_params(os.path.join(path_run, file_params))
            data[int(name)] = {"measures": measures, "params": params, "dir": name}
        return data
    
    
    def remove_outliers(
        data: Dict[int, Dict[str, Any]], file_outliers: str = "outliers.csv"
    ) -> Dict[int, Dict[str, Any]]:
        """
        Remove all runs which are marked as outliers in a CSV file.

        The CSV file needs a column "run" and a column "outlier". The
        "outlier" column is used as a boolean mask: rows for which it is
        true name the runs to remove.

        Parameters
        ----------
        data: Dict[int, Dict[str, Any]]
            runs by run number; this dict is not modified
        file_outliers: str
            path of the CSV file which lists the outliers

        Returns
        -------
        Dict[int, Dict[str, Any]]
            a new dict with all runs of data which are not outliers, in the
            same order as in data
        """
        outliers = pd.read_csv(file_outliers)
        flagged = outliers[outliers["outlier"]]
        # a set keeps the membership test below constant-time per run
        outlier_runs = set(flagged["run"])
        return {run: entry for run, entry in data.items() if run not in outlier_runs}
    
    
    def filter_by_params(
        data: Dict[int, Dict[str, Any]],
        value: Union[Any, Tuple[Any]],
        fields: Union[str, List[str]],
    ):
        """
        Select the runs whose parameters have the given values.

        Parameters
        ----------
        data: Dict[int, Dict[str, Any]]
            runs by run number; every run needs a "params" dict
        value: Union[Any, Tuple[Any]]
            the expected value, or a tuple with one expected value per field;
            anything which is not exactly a tuple counts as a single value
        fields: Union[str, List[str]]
            the parameter name, or a list of parameter names, to compare;
            anything which is not exactly a list counts as a single name

        Returns
        -------
        dict
            a new dict with the runs whose parameter values, taken in the
            order of fields, equal the expected values
        """
        expected = value if type(value) is tuple else (value,)
        names = fields if type(fields) is list else [fields]

        selected = {}
        for run, entry in data.items():
            actual = tuple(entry["params"][name] for name in names)
            if actual == expected:
                selected[run] = entry
        return selected
    
    
    class TablePrinter:
        """
        Print a table, given as a dict of columns, as right-aligned text.

        Attributes
        ----------
        vert, hori, t_up, t_down, t_right, t_left, top_left, top_right,
        bottom_right, bottom_left, cross: str
            the glyphs the border of the table is drawn with
        formatter: dict
            maps a header name or a value type to a format string which is
            applied to the values of a column; a header name takes
            precedence over a type
        color_formatter: dict
            maps a header name to a callable which receives the padded text
            of a cell and returns the text that is printed
        """

        def __init__(self):
            # NOTE(review): every border glyph is an empty string here, so no
            # border is drawn by default. Presumably box-drawing characters
            # are intended - confirm against the original file.
            # Despite its name, `vert` is the glyph which is repeated to draw
            # the horizontal rules, while `hori` separates the columns.
            self.vert = ""
            self.hori = ""

            self.t_up = ""
            self.t_down = ""
            self.t_right = ""
            self.t_left = ""

            self.top_left = ""
            self.top_right = ""
            self.bottom_right = ""
            self.bottom_left = ""

            self.cross = ""

            self.formatter = {
                float: "{:.5f}",
                np.float64: "{:.5f}",
            }
            self.color_formatter = {}

        def print(self, table: Dict[str, List[Any]]):
            """
            Print the table to the standard output.

            The output consists of a top rule, the header line, a middle
            rule, one line per row and a bottom rule. Columns without rows
            and tables without columns are printed as an empty frame.

            Parameters
            ----------
            table: Dict[str, List[Any]]
                maps the header of a column to the values of the column; the
                number of rows is taken from the first column

            Raises
            ------
            IndexError
                if a column has fewer values than the first column
            """
            headers = list(table)

            cells = {
                header: [self.format(value, header) for value in column]
                for header, column in table.items()
            }
            # max over a single list also works for a column without values
            widths = {
                header: max([len(header), *map(len, column)])
                for header, column in cells.items()
            }
            rule_widths = [widths[header] for header in headers]

            print(self._rule(self.top_left, self.t_down, self.top_right, rule_widths))
            print(self._row(f"{header:>{widths[header]}}" for header in headers))
            print(self._rule(self.t_left, self.cross, self.t_right, rule_widths))

            n_rows = len(cells[headers[0]]) if headers else 0
            for i in range(n_rows):
                # pad first, then color, so that the color formatter cannot
                # influence the alignment
                print(
                    self._row(
                        self.color(f"{cells[header][i]:>{widths[header]}}", header)
                        for header in headers
                    )
                )

            print(self._rule(self.bottom_left, self.t_up, self.bottom_right, rule_widths))

        def _rule(self, left: str, joint: str, right: str, widths: List[int]) -> str:
            """Build a horizontal rule with the given corner and joint glyphs."""
            segments = (self.vert * width for width in widths)
            body = f"{self.vert}{joint}{self.vert}".join(segments)
            return left + self.vert + body + self.vert + right

        def _row(self, texts) -> str:
            """Build a table line from the already padded texts of its cells."""
            body = f" {self.hori} ".join(texts)
            return self.hori + " " + body + " " + self.hori

        def format(self, value: Any, header: str):
            """
            Convert a value to the text of its cell.

            A format string registered for the header is preferred over one
            registered for the exact type of the value. Without a matching
            format string, `str` is used.
            """
            if header in self.formatter:
                return self.formatter[header].format(value)

            if type(value) in self.formatter:
                return self.formatter[type(value)].format(value)

            return str(value)

        def color(self, value_str: str, header: str):
            """
            Apply the color formatter of a column, if there is one.

            Returns the text unchanged if no callable is registered for the
            header in `color_formatter`.
            """
            if header in self.color_formatter:
                return self.color_formatter[header](value_str)
            return value_str
    
    if __name__ == "__main__":
        # Script entry point: load all runs of the random search directory
        # which is given on the command line.
        import argparse

        parser = argparse.ArgumentParser()
        # positional argument: the directory which is passed on to load_data
        parser.add_argument("random_search_dir", type=str)
        args = parser.parse_args()

        # the loaded runs are not used any further in the visible part of the file
        data = load_data(args.random_search_dir)