import json
import os
from typing import Any, Callable, Dict, List, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.random_search.cgan import load_params

SIZE_DEFAULT = 12
plt.rc("font", family="Roboto")  # default font family
plt.rc("font", weight="normal")  # default font weight
plt.rc("font", size=SIZE_DEFAULT)  # default text size
plt.rc("axes", titlesize=18)  # font size of the axes title


class ColorList:
    """
    A fixed sequence of colors whose index access wraps around.

    Indexing with ``i`` returns ``colors[i % len(colors)]``, so any integer
    index is valid as long as at least one color was given.
    """

    def __init__(self, *colors):
        self.colors = colors

    def __len__(self):
        return len(self.colors)

    def __getitem__(self, i):
        # wrap around instead of raising IndexError for large indices
        return self.colors[i % len(self)]


color_lists = {
    "default": ColorList(
        "#1f78b4",
        "#33a02c",
        "#e31a1c",
        "#ff7f00",
        "#cab2d6",
        "#a6cee3",
        "#b2df8a",
        "#fb9a99",
        "#fdbf6f",
        "#6a3d9a",
    ),
    "printer_friendly": ColorList(
        "#1b9e77",
        "#d95f02",
        "#7570b3",
        "#e7298a",
        "#66a61e",
        "#e6ab02",
    ),
}

# https://colorbrewer2.org/#type=qualitative&scheme=Dark2&n=5
short_color_list = ColorList(
    "#1b9e77",
    "#d95f02",
    "#7570b3",
    "#e7298a",
    "#66a61e",
)

# https://colorbrewer2.org/#type=qualitative&scheme=Set3&n=10
COLORS = [
    "#8dd3c7",
    "#fb8072",
    "#80b1d3",
    "#fdb462",
    "#b3de69",
    "#fccde5",
    "#d9d9d9",
    "#bc80bd",
    "#ffffb3",
    "#bebada",
]


def jitter(data: np.ndarray, amount: float = 0.1) -> np.ndarray:
    """
    Jitter all values in an array.

    This is useful to scatter values which are all displayed for the same x
    value. The amount should be chosen in relation to the values in the data.
    For example, if the smallest change in x is 1 the amount should be lower
    than this.

    Parameters
    ----------
    data: np.ndarray
        the data which is jittered
    amount: float
        the width of the interval from which the random offsets are drawn;
        each offset lies in [-amount / 2, amount / 2)

    Returns
    -------
    np.ndarray
        a new array of the same shape as `data` with the offsets added
    """
    return data + (np.random.rand(*data.shape) - 0.5) * amount


def load_data(
    dir_random_search: str,
    file_measures: str = "measures.csv",
    file_params: str = "params.json",
) -> Dict[int, Dict[str, Any]]:
    """
    Load the measures and parameters of all runs of a random search.

    Every entry of `dir_random_search` that is a directory and not a symbolic
    link is treated as a run. Entries are visited in sorted (lexicographic)
    name order and the directory name is converted with `int` to obtain the
    run number.

    Parameters
    ----------
    dir_random_search: str
        directory containing one sub-directory per run
    file_measures: str
        name of the CSV file with the measures inside each run directory
    file_params: str
        name of the parameter file inside each run directory, read with
        `load_params`

    Returns
    -------
    Dict[int, Dict[str, Any]]
        maps the run number to a dict with the keys "measures" (DataFrame),
        "params" (result of `load_params`) and "dir" (run directory name)

    Raises
    ------
    ValueError
        if a run directory name cannot be converted to an integer
    """
    data = {}
    for dir_run in sorted(os.listdir(dir_random_search)):
        path_run = os.path.join(dir_random_search, dir_run)
        if not os.path.isdir(path_run) or os.path.islink(path_run):
            continue

        measures = pd.read_csv(os.path.join(path_run, file_measures))
        params = load_params(os.path.join(path_run, file_params))
        data[int(dir_run)] = {"measures": measures, "params": params, "dir": dir_run}
    return data


def remove_outliers(
    data: Dict[int, Dict[str, Any]], file_outliers: str = "outliers.csv"
) -> Dict[int, Dict[str, Any]]:
    """
    Drop all runs which are marked as outliers in a CSV file.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        run data as returned by `load_data`
    file_outliers: str
        CSV file with a column "run" and a boolean column "outlier"

    Returns
    -------
    Dict[int, Dict[str, Any]]
        a new dict containing only the runs not flagged as outlier
    """
    outlier_runs = pd.read_csv(file_outliers)
    outlier_runs = outlier_runs[outlier_runs["outlier"]]
    # a set gives constant-time membership tests in the filter below
    outlier_runs = set(outlier_runs["run"])
    return {run: values for run, values in data.items() if run not in outlier_runs}


def filter_by_params(
    data: Dict[int, Dict[str, Any]],
    value: Union[Any, Tuple[Any, ...]],
    fields: Union[str, List[str]],
) -> Dict[int, Dict[str, Any]]:
    """
    Keep only the runs whose parameters equal the given value(s).

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        run data where each entry has a "params" mapping
    value: Union[Any, Tuple[Any, ...]]
        expected value, or a tuple with one expected value per field
    fields: Union[str, List[str]]
        parameter name, or list of parameter names, looked up in "params"

    Returns
    -------
    Dict[int, Dict[str, Any]]
        a new dict with all runs for which the parameters match
    """
    # isinstance also accepts subclasses such as namedtuples
    if not isinstance(value, tuple):
        value = (value,)
    if not isinstance(fields, list):
        fields = [fields]

    return {
        run: values
        for run, values in data.items()
        if tuple(values["params"][field] for field in fields) == value
    }


class TablePrinter:
    """
    Print a table given as a mapping of header to column with box-drawing
    characters.

    Values are converted to strings by `format` and can optionally be wrapped
    (e.g. in color escape codes) by `color`.
    """

    def __init__(self):
        # NOTE: the names are kept for compatibility - `vert` holds the
        # horizontal line character and `hori` the vertical one
        self.vert = "─"
        self.hori = "│"
        self.t_up = "┴"
        self.t_down = "┬"
        self.t_right = "┤"
        self.t_left = "├"
        self.top_left = "┌"
        self.top_right = "┐"
        self.bottom_right = "┘"
        self.bottom_left = "└"
        self.cross = "┼"

        # format strings looked up by header first and by value type second
        self.formatter = {
            float: "{:.5f}",
            np.float64: "{:.5f}",
        }
        # maps a header to a callable applied to the padded cell string
        self.color_formatter = {}

    def print(self, table: Dict[str, List[Any]]):
        """
        Print the table to stdout.

        Parameters
        ----------
        table: Dict[str, List[Any]]
            maps each header to its column; the number of rows is taken from
            the first column
        """
        headers = list(table)
        cells = {
            header: [self.format(value, header) for value in column]
            for header, column in table.items()
        }
        # max over a single list also works for columns without any rows
        widths = {
            header: max([len(header), *map(len, column)])
            for header, column in cells.items()
        }
        n_rows = len(cells[headers[0]]) if headers else 0

        print(self._rule(self.top_left, self.t_down, self.top_right, headers, widths))
        print(self._row([f"{header:>{widths[header]}}" for header in headers]))
        print(self._rule(self.t_left, self.cross, self.t_right, headers, widths))
        for i in range(n_rows):
            # pad first and color afterwards so that escape codes added by a
            # color formatter do not count towards the column width
            print(
                self._row(
                    self.color(f"{cells[header][i]:>{widths[header]}}", header)
                    for header in headers
                )
            )
        print(
            self._rule(self.bottom_left, self.t_up, self.bottom_right, headers, widths)
        )

    def _rule(self, left: str, joint: str, right: str, headers, widths) -> str:
        """Build a horizontal rule with the given corner and joint characters."""
        segments = (self.vert * widths[header] for header in headers)
        inner = f"{self.vert}{joint}{self.vert}".join(segments)
        return left + self.vert + inner + self.vert + right

    def _row(self, cells) -> str:
        """Join already padded cell strings into a single table row."""
        return self.hori + " " + f" {self.hori} ".join(cells) + " " + self.hori

    def format(self, value: Any, header: str):
        """
        Convert a value to its string representation.

        A formatter registered for the header takes precedence over one
        registered for the exact type of the value; without either, `str`
        is used.
        """
        if header in self.formatter:
            return self.formatter[header].format(value)
        if type(value) in self.formatter:
            return self.formatter[type(value)].format(value)
        return str(value)

    def color(self, value_str: str, header: str):
        """Apply the color formatter registered for the header, if any."""
        if header in self.color_formatter:
            return self.color_formatter[header](value_str)
        return value_str


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("random_search_dir", type=str)
    args = parser.parse_args()

    data = load_data(args.random_search_dir)