"""
Utility functions and classes for the random search evaluation.
"""
import json
import os
from typing import Any, Callable, Dict, List, Union, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mu_map.random_search.cgan import load_params
# Default size used for all text that has no more specific setting.
SIZE_DEFAULT = 12

# Global matplotlib defaults, applied once when this module is imported.
# The three "font" settings (family, weight, size) are set in a single call.
plt.rc("font", family="Roboto", weight="normal", size=SIZE_DEFAULT)
plt.rc("axes", titlesize=18)  # font size of the axes title
"""
Class wrapping a list so that indexing values wraps around.
"""
def __init__(self, *colors: str):
self.colors = colors
def __len__(self):
return len(self.colors)
def __getitem__(self, i):
return self.colors[i % len(self)]
"""
Two examples of color lists defined with the help of:
https://colorbrewer2.org
"""
# Named color palettes; each value is a ColorList, so indexing wraps around
# once all colors of a palette have been used.
color_lists = {
    "default": ColorList(
        "#1f78b4",
        "#33a02c",
        "#e31a1c",
        "#ff7f00",
        "#cab2d6",
        "#a6cee3",
        "#b2df8a",
        "#fb9a99",
        "#fdbf6f",
        "#6a3d9a",
    ),
    "printer_friendly": ColorList(
        "#1b9e77",
        "#d95f02",
        "#7570b3",
        "#e7298a",
        "#66a61e",
        "#e6ab02",
    ),
}
def jitter(data: np.ndarray, amount: float = 0.1) -> np.ndarray:
    """
    Add uniform random noise to the values of an array.

    This helps to spread out points in a scatter plot which would
    otherwise all be drawn at the same x value. Choose the amount
    relative to the spacing of the data, e.g., well below 1 if the
    smallest step in x is 1.

    Parameters
    ----------
    data: np.ndarray
        the values to which noise is added (the input is not modified)
    amount: float
        width of the noise interval; each value is shifted by a random
        offset drawn from [-amount / 2, amount / 2)

    Returns
    -------
    np.ndarray
        a new array with the same shape as `data`
    """
    # uniform samples in [0, 1), shifted so that they are centered around zero
    centered_noise = np.random.rand(*data.shape) - 0.5
    offsets = centered_noise * amount
    return data + offsets
def load_data(
    dir_random_search: str,
    file_measures: str = "measures.csv",
    file_params: str = "params.json",
) -> Dict[int, Dict[str, Any]]:
    """
    Collect the results of all random search iterations in a dict.

    Every sub-directory of `dir_random_search` is treated as one
    iteration. Plain files and symbolic links are ignored. The name of
    each sub-directory is converted to an int and used as the key, so
    a sub-directory with a non-numeric name raises a ValueError.

    Parameters
    ----------
    dir_random_search: str
        the directory of the random search procedure which contains
        one numbered directory per iteration
    file_measures: str, optional
        filename of the measures CSV file in each iteration directory
    file_params: str, optional
        filename of the params JSON file in each iteration directory

    Returns
    -------
    Dict[int, Dict[str, Any]]
        maps the iteration number to a dict with the keys "measures"
        (the CSV as a pandas DataFrame), "params" (the result of
        `load_params`) and "dir" (the name of the iteration directory)
    """
    results = {}
    # names are visited in lexicographic order, which also defines the
    # insertion order of the returned dict
    for name in sorted(os.listdir(dir_random_search)):
        path_run = os.path.join(dir_random_search, name)

        # only real directories are iterations - links to directories are skipped
        if not os.path.isdir(path_run) or os.path.islink(path_run):
            continue

        results[int(name)] = {
            "measures": pd.read_csv(os.path.join(path_run, file_measures)),
            "params": load_params(os.path.join(path_run, file_params)),
            "dir": name,
        }
    return results
def remove_outliers(
    data: Dict[int, Dict[str, Any]], file_outliers: str = "outliers.csv"
) -> Dict[int, Dict[str, Any]]:
    """
    Remove outlier iterations from data loaded with `load_data`.

    The outlier file is read as a CSV with the columns "run" (the
    iteration number) and "outlier" (a boolean flag). All iterations
    whose flag is set are dropped; the input dict is not modified.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        data loaded by `load_data`
    file_outliers: str
        CSV file defining outliers (as generated by `mu_map.random_search.eval.label_outliers`)

    Returns
    -------
    Dict[int, Dict[str, Any]]
        filtered data dict
    """
    outliers = pd.read_csv(file_outliers)
    flagged = outliers[outliers["outlier"]]
    # a set makes the membership test below constant time per iteration
    outlier_runs = set(flagged["run"].tolist())
    return {run: entry for run, entry in data.items() if run not in outlier_runs}
def filter_by_params(
    data: Dict[int, Dict[str, Any]],
    value: Union[Any, Tuple[Any]],
    fields: Union[str, List[str]],
) -> Dict[int, Dict[str, Any]]:
    """
    Keep only those iterations whose parameters have the given values.

    A single value and a single field name are wrapped automatically,
    so both `filter_by_params(data, 1, "a")` and
    `filter_by_params(data, (1, "x"), ["a", "b"])` are valid calls.
    A KeyError is raised if a field is missing in the params of an
    iteration.

    Parameters
    ----------
    data: Dict[int, Dict[str, Any]]
        data loaded by `load_data`
    value: Any or tuple of Any
        expected parameter values, one per field and in the same order
    fields: str or list of str
        names of the parameters which are compared against `value`

    Returns
    -------
    Dict[int, Dict[str, Any]]
        filtered data dict
    """
    # exact type checks: only a plain tuple/list is taken as a collection
    expected = value if type(value) is tuple else (value,)
    names = fields if type(fields) is list else [fields]

    selected = {}
    for run, entry in data.items():
        actual = tuple(entry["params"][name] for name in names)
        if actual == expected:
            selected[run] = entry
    return selected
class TablePrinter:
    """
    Print a table using UTF-8 box-drawing symbols.
    """

    def __init__(self):
        """
        Create a new table printer.

        This function initializes border symbols.

        Important to its usage are the `formatter` and the `color_formatter`
        dicts. The `formatter` dict defines how columns or certain types should
        be formatted. Its keys are either a column name or a type and its values
        are format strings. E.g., floats are formatted with five decimal places
        per default.

        The `color_formatter` maps a column name to a function which receives
        the already padded text of a cell and returns the text to print, e.g.,
        to change its color.
        """
        # NOTE: the attribute names are kept for backward compatibility even
        # though `vert` holds the horizontal and `hori` the vertical line symbol.
        self.vert = "─"
        self.hori = "│"
        self.t_up = "┴"
        self.t_down = "┬"
        self.t_right = "┤"
        self.t_left = "├"
        self.top_left = "┌"
        self.top_right = "┐"
        self.bottom_right = "┘"
        self.bottom_left = "└"
        self.cross = "┼"

        self.formatter: Dict[Union[type, str], str] = {
            float: "{:.5f}",
            np.float64: "{:.5f}",
        }
        self.color_formatter: Dict[str, Callable[[str], str]] = {}

    def _rule(self, left: str, joint: str, right: str, widths: List[int]) -> str:
        """Build one horizontal border line for columns of the given widths."""
        segments = f"{self.vert}{joint}{self.vert}".join(
            self.vert * width for width in widths
        )
        return left + self.vert + segments + self.vert + right

    def print(self, table: Dict[str, List[Any]]):
        """
        Print a table.

        The output consists of a top border, the column headers, a
        separator, one line per row and a bottom border. All cells are
        right-aligned to the widest entry of their column. The number
        of rows is taken from the first column.

        Parameters
        ----------
        table: Dict[str, List[Any]]
            table in column representation - a key is the column header and the list
            its values
        """
        headers = list(table.keys())

        # convert all values to strings first so that widths can be measured
        formatted = {
            header: [self.format(value, header) for value in column]
            for header, column in table.items()
        }
        # the list passed to max keeps this valid for columns without values
        widths = {
            header: max([len(header), *map(len, column)])
            for header, column in formatted.items()
        }
        rule_widths = [widths[header] for header in headers]

        print(self._rule(self.top_left, self.t_down, self.top_right, rule_widths))

        line_headers = f" {self.hori} ".join(
            f"{header:>{widths[header]}}" for header in headers
        )
        print(self.hori + " " + line_headers + " " + self.hori)

        print(self._rule(self.t_left, self.cross, self.t_right, rule_widths))

        n_rows = len(formatted[headers[0]]) if headers else 0
        for i in range(n_rows):
            # pad before coloring so that the color function receives cells of
            # the final column width
            cells = (
                self.color(f"{formatted[header][i]:>{widths[header]}}", header)
                for header in headers
            )
            line = f" {self.hori} ".join(cells)
            print(self.hori + " " + line + " " + self.hori)

        print(self._rule(self.bottom_left, self.t_up, self.bottom_right, rule_widths))

    def format(self, value: Any, header: str) -> str:
        """
        Internal method to use the internal formatter dict.

        A format string registered for the column name takes precedence
        over one registered for the exact type of the value. Without a
        matching entry the value is converted with `str`.

        Parameters
        ----------
        value: Any
            the value to format
        header: str
            the column name of the value

        Returns
        -------
        str
            the formatted value
        """
        if header in self.formatter:
            return self.formatter[header].format(value)
        if type(value) in self.formatter:
            return self.formatter[type(value)].format(value)
        return str(value)

    def color(self, value_str: str, header: str) -> str:
        """
        Internal method to use the internal color_formatter dict.

        Parameters
        ----------
        value_str: str
            the (already formatted) value to process
        header: str
            the column name of the value

        Returns
        -------
        str
            the result of the function registered for the column or the
            unchanged value if there is none
        """
        if header in self.color_formatter:
            return self.color_formatter[header](value_str)
        return value_str