"""
Utility functions and classes for the random search evaluation.
"""
import json
import os
from typing import Any, Callable, Dict, Iterable, Iterator, List, Union, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from mu_map.random_search.cgan import load_params

SIZE_DEFAULT = 16
plt.rc("font", family="Roboto")  # default font family
plt.rc("font", weight="normal")  # default font weight
plt.rc("font", size=SIZE_DEFAULT)  # default text size
plt.rc("axes", titlesize=22)  # font size of axes titles


class ColorList:
    """
    Sequence of colors whose indexing wraps around.

    Indexing with `i` returns `colors[i % len(colors)]`, so every integer
    index (including indices beyond the length) yields a color.
    Iterating visits each color exactly once.
    """

    def __init__(self, *colors: str):
        self.colors = colors

    def __len__(self) -> int:
        return len(self.colors)

    def __getitem__(self, i: int) -> str:
        if not self.colors:
            # without this check the modulo below raises ZeroDivisionError
            raise IndexError("cannot index an empty ColorList")
        return self.colors[i % len(self.colors)]

    def __iter__(self) -> Iterator[str]:
        # Without an explicit __iter__, Python iterates by calling
        # __getitem__(0), __getitem__(1), ... until an IndexError is raised.
        # Because indexing wraps around this never terminates, so e.g.
        # `list(color_list)` or `"#000000" in color_list` would hang.
        return iter(self.colors)

    def __repr__(self) -> str:
        return f"{type(self).__name__}({', '.join(map(repr, self.colors))})"


# Two color lists created with the help of https://colorbrewer2.org
color_lists = {
    "default": ColorList(
        "#1f78b4",
        "#33a02c",
        "#e31a1c",
        "#ff7f00",
        "#cab2d6",
        "#a6cee3",
        "#b2df8a",
        "#fb9a99",
        "#fdbf6f",
        "#6a3d9a",
    ),
    "printer_friendly": ColorList(
        "#1b9e77",
        "#d95f02",
        "#7570b3",
        "#e7298a",
        "#66a61e",
        "#e6ab02",
    ),
}


def jitter(data: np.ndarray, amount: float = 0.1) -> np.ndarray:
    """
    Jitter the values in an array.

    This is useful to scatter values which are all displayed for the same
    x value. The amount should be chosen in relation to the values in the
    data. For example, if the smallest change in x is 1, the amount should
    be lower than this.

    Parameters
    ----------
    data: np.ndarray
        the data which is jittered (any array-like with a shape works,
        e.g. a list or a pandas Series; a Series stays a Series)
    amount: float
        width of the jitter: uniform noise from [-amount/2, amount/2)
        is added to each value

    Returns
    -------
    np.ndarray
    """
    # np.shape also works for lists and scalars, unlike the `.shape` attribute
    noise = (np.random.rand(*np.shape(data)) - 0.5) * amount
    return data + noise


def load_data(
    dir_random_search: str,
    file_measures: str = "measures.csv",
    file_params: str = "params.json",
) -> Dict[int, Dict[str, Any]]:
    """
    Load results of a random search procedure into a dict.

    The dict maps an iteration id to its directory, the measures CSV file
    and the params JSON file.

    Parameters
    ----------
    dir_random_search: str
        the directory of the random search procedure
        it is expected that each iteration is stored in a separate numbered directory
    file_measures: str, optional
        filename of the measures CSV file in each directory
    file_params: str, optional
        filename of the params JSON file in each directory

    Returns
    -------
    Dict[int, Dict[str, Any]]
        a dict mapping the iteration number to "measures" (a DataFrame),
        "params" (as returned by `load_params`), and "dir" (the directory
        name relative to `dir_random_search`), ordered by iteration number
    """
    dirs_run = [
        name
        for name in os.listdir(dir_random_search)
        # isdecimal (unlike isdigit) only accepts characters int() can parse
        if name.isdecimal() and os.path.isdir(os.path.join(dir_random_search, name))
    ]
    # sort numerically so that "10" follows "9" instead of "1"; the name is a
    # tie-breaker to keep the order deterministic for names like "01" and "1"
    dirs_run.sort(key=lambda name: (int(name), name))

    data = {}
    for dir_run in dirs_run:
        path_run = os.path.join(dir_random_search, dir_run)
        measures = pd.read_csv(os.path.join(path_run, file_measures))
        params = load_params(os.path.join(path_run, file_params))
        data[int(dir_run)] = {"measures": measures, "params": params, "dir": dir_run}
    return data


def remove_outliers(
    data: Dict[int, Dict[str, Any]], file_outliers: str = "outliers.csv"
) -> Dict[int, Dict[str, Any]]:
    """
    Remove outlier iterations from data loaded with `load_data`.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        data loaded by `load_data`
    file_outliers: str
        CSV file defining outliers (as generated by `mu_map.random_search.eval.label_outliers`)
        with a boolean column "outlier" and a column "run" holding iteration numbers

    Returns
    -------
    Dict[int, Dict[str, Any]]
        filtered data dict
    """
    outliers = pd.read_csv(file_outliers)
    # a set gives O(1) membership tests instead of scanning a list per run
    outlier_runs = set(outliers.loc[outliers["outlier"], "run"])
    return {run: entry for run, entry in data.items() if run not in outlier_runs}


def filter_by_params(
    data: Dict[int, Dict[str, Any]],
    value: Union[Any, Tuple[Any, ...]],
    fields: Union[str, List[str]],
) -> Dict[int, Dict[str, Any]]:
    """
    Filter data loaded with `load_data` based on certain parameters.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        data loaded by `load_data`
    value: Any or tuple of Any
        values for the parameters (a tuple is matched element-wise against `fields`)
    fields: str or list of str
        field names to access the parameter dict to check if the values match

    Returns
    -------
    Dict[int, Dict[str, Any]]
        filtered data dict
    """
    # exact type checks: only a plain tuple/list is treated as a collection
    if type(value) is not tuple:
        value = (value,)
    if type(fields) is not list:
        fields = [fields]
    return {
        run: entry
        for run, entry in data.items()
        if tuple(entry["params"][field] for field in fields) == value
    }


class TablePrinter:
    """
    Print a table using UTF-8 box-drawing symbols.
    """

    def __init__(self):
        """
        Create a new table printer.

        This function initializes border symbols.
        Important to its usage are the `formatter` and the `color_formatter` dicts.
        The `formatter` dict maps a column name or a value type to a format
        string, e.g., floats are formatted with five decimal places by default.
        A column entry takes precedence over a type entry.
        The `color_formatter` dict maps a column name to a function applied to
        each (already padded) cell of that column, e.g., to change its color.
        """
        # NOTE: `vert` holds the horizontal line and `hori` the vertical line;
        # the names are kept as-is for backward compatibility.
        self.vert = "─"
        self.hori = "│"
        self.t_up = "┴"
        self.t_down = "┬"
        self.t_right = "┤"
        self.t_left = "├"
        self.top_left = "┌"
        self.top_right = "┐"
        self.bottom_right = "┘"
        self.bottom_left = "└"
        self.cross = "┼"

        self.formatter: Dict[Union[type, str], str] = {
            float: "{:.5f}",
            np.float64: "{:.5f}",
        }
        self.color_formatter: Dict[str, Callable[[str], str]] = {}

    def print(self, table: Dict[str, List[Any]]):
        """
        Print a table.

        Parameters
        ----------
        table: Dict[str, List[Any]]
            table in column representation - a key is the column header and the list its values;
            the number of rows is taken from the first column
        """
        headers = list(table.keys())
        cells = {
            header: [self.format(value, header) for value in column]
            for header, column in table.items()
        }
        # max() receives a single list so that columns without rows work;
        # max(len(header), *[]) would raise a TypeError.
        widths = {
            header: max([len(header), *map(len, column)])
            for header, column in cells.items()
        }
        n_rows = len(cells[headers[0]]) if headers else 0

        print(self._rule(headers, widths, self.top_left, self.t_down, self.top_right))
        print(self._row(f"{header:>{widths[header]}}" for header in headers))
        print(self._rule(headers, widths, self.t_left, self.cross, self.t_right))
        for i in range(n_rows):
            print(
                self._row(
                    self.color(f"{cells[header][i]:>{widths[header]}}", header)
                    for header in headers
                )
            )
        print(
            self._rule(
                headers, widths, self.bottom_left, self.t_up, self.bottom_right
            )
        )

    def _rule(
        self,
        headers: List[str],
        widths: Dict[str, int],
        left: str,
        junction: str,
        right: str,
    ) -> str:
        """
        Internal method building a horizontal border line.

        Each column gets a run of line glyphs as wide as the column plus one
        glyph of padding on each side; `junction` separates columns.
        """
        inner = f"{self.vert}{junction}{self.vert}".join(
            self.vert * widths[header] for header in headers
        )
        return left + self.vert + inner + self.vert + right

    def _row(self, cells: Iterable[str]) -> str:
        """
        Internal method joining already padded cells into one table row.
        """
        return self.hori + " " + f" {self.hori} ".join(cells) + " " + self.hori

    def format(self, value: Any, header: str) -> str:
        """
        Internal method to use the internal formatter dict.

        Parameters
        ----------
        value: Any
            the value to format
        header: str
            the column name of the value

        Returns
        -------
        str
            the formatted value
        """
        if header in self.formatter:
            return self.formatter[header].format(value)

        if type(value) in self.formatter:
            return self.formatter[type(value)].format(value)

        return str(value)

    def color(self, value_str: str, header: str) -> str:
        """
        Internal method to use the internal color_formatter dict.

        Parameters
        ----------
        value_str: str
            the (already formatted and padded) value
        header: str
            the column name of the value

        Returns
        -------
        str
            the value passed through the column's color formatter, if any
        """
        if header in self.color_formatter:
            return self.color_formatter[header](value_str)
        return value_str