Skip to content
Snippets Groups Projects
Commit 89f91450 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

Add file for parameter evaluation in random search

parent 60138094
No related branches found
No related tags found
No related merge requests found
...@@ -9,7 +9,7 @@ from mu_map.dataset.default import MuMapDataset ...@@ -9,7 +9,7 @@ from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.transform import SequenceTransform, PadCropTranform from mu_map.dataset.transform import SequenceTransform, PadCropTranform
from mu_map.models.unet import UNet from mu_map.models.unet import UNet
from mu_map.random_search.cgan import load_params from mu_map.random_search.cgan import load_params
from mu_map.random_search.show_predictions import main from mu_map.random_search.eval.show_predictions import main
from mu_map.random_search.eval.util import load_data from mu_map.random_search.eval.util import load_data
controls = """ controls = """
...@@ -71,7 +71,6 @@ for i, run in enumerate(runs): ...@@ -71,7 +71,6 @@ for i, run in enumerate(runs):
data["run"].append(int(run)) data["run"].append(int(run))
data["outlier"].append(False) data["outlier"].append(False)
dir_run = os.path.join(args.random_search_dir, runs[run]["dir"]) dir_run = os.path.join(args.random_search_dir, runs[run]["dir"])
params = runs[run]["params"] params = runs[run]["params"]
......
from dataclasses import dataclass
import itertools
import os
from typing import Any, Callable, Dict, List, Optional
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
from termcolor import colored
from mu_map.dataset.normalization import (
GaussianNormTransform,
MeanNormTransform,
MaxNormTransform,
)
from mu_map.random_search.eval.util import (
ColorList,
color_lists,
load_data,
remove_outliers,
filter_by_params,
jitter,
TablePrinter,
)
from mu_map.training.loss import WeightedLoss
@dataclass
class ParameterGroups:
"""
Dataclass describing all manifestations of a parameter
to compare.
Parameters
----------
groups: Dict[str, Any]
mappings of labels to manifestations of a parameter
keys: List[str]
keys to access to parameter manifestations from a parsed json file
"""
groups: Dict[str, Any]
keys: List[str]
"""
Definitions of parameters and their manifestations.
"""
parameter_groups = {
"generator_depth": ParameterGroups(
groups={"Small": [64, 128, 256, 512], "Large": [32, 64, 128, 256, 512]},
keys=["generator_features"],
),
"discriminator_depth": ParameterGroups(
groups={
"Class S": ("class", [32, 64, 128]),
"Class M": ("class", [64, 128, 256]),
"Class L": ("class", [32, 64, 128, 256]),
"PatchGAN M": ("patch", [32, 64, 128, 256]),
"PatchGAN L": ("patch", [64, 128, 256, 512]),
},
keys=["discriminator_type", "discriminator_conv_features"],
),
"discriminator_type": ParameterGroups(
groups={"Class": "class", "PatchGAN": "patch"}, keys=["discriminator_type"]
),
"distance_loss": ParameterGroups(
groups={
"L1": WeightedLoss.from_str("L1"),
"L2 + GDL": WeightedLoss.from_str("L2+GDL"),
},
keys=["criterion_dist"],
),
"loss_weights": ParameterGroups(
groups={
"100:1": (100, 1),
"20:1": (20, 1),
"5:1": (5, 1),
"1:1": (1, 1),
"1:5": (1, 5),
"1:20": (1, 20),
"1:100": (1, 100),
},
keys=["weight_crit_dist", "weight_crit_adv"],
),
"patch_size": ParameterGroups(groups={"32": 32, "64": 64}, keys=["patch_size"]),
"normalization": ParameterGroups(
groups={
"Gaussian": GaussianNormTransform(),
"Mean": MeanNormTransform(),
"Max": MaxNormTransform(),
},
keys=["normalization"],
),
"learning_rate_decay": ParameterGroups(
groups={"Decay": True, "No Decay": False}, keys=["lr_decay"]
),
}
def plot_param_groups(
data: Dict[str, Dict[str, Any]],
param_groups: ParameterGroups,
measure: str = "NMAE",
title: Optional[str] = None,
colors: ColorList = color_lists["printer_friendly"],
):
"""
Create a plot to visually compare all manifestations of a parameter.
The plot is a bar plot of the mean with an indicator of the standard deviation.
In addition, all values are scattered on top.
Parameters
----------
data: Dict[str, Dict[str, Any]]
data of a random search procedure as returned by the `load_data` method
in `mu_map.random_search.eval.util`
param_groups: ParameterGroups,
a parameter groups object defining all manifestations of a parameter as
well as how to access
measure: str,
the measure plotted
title: str, optional
the title given to the plot
colors: ColorList,
a color list object defining which colors to use for different parameter
manifestations
"""
fig_width = 6 + len(param_groups.groups) - 2
fig_height = 4
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
if title is not None:
ax.set_title(title)
y_max = 0
for i, (label, value) in enumerate(param_groups.groups.items()):
_data = filter_by_params(data, value, param_groups.keys)
ys = map(lambda run: _data[run]["measures"][measure].mean(), _data.keys())
ys = np.array(list(ys))
y_max = max(y_max, ys.max())
if len(ys) == 0:
continue
x = i + 1
ax.bar([x], ys.mean(), yerr=ys.std(), color=colors[i], capsize=5.0)
xs = jitter(np.full(ys.shape, x), amount=0.5)
ax.scatter(xs, ys, color=colors[i], alpha=0.4, edgecolor="black")
ax.set_ylabel(measure)
ax.set_ylim((0, y_max + 0.01))
ax.grid(axis="y", alpha=0.3)
ax.spines["left"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
xticks = list(range(1, len(param_groups.groups) + 1))
ax.set_xticks(xticks)
ax.set_xticklabels(list(param_groups.groups.keys()))
ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5)
def analyse_stats(
data: Dict[str, Dict[str, Any]],
param_groups: ParameterGroups,
measure: str = "NMAE",
):
"""
Perform a statistical analysis for the manifestation of a parameter.
Each manifestation is tested to be normally distributed and afterwards
pairs are compared to be different with the t-test.
Parameters
----------
data: Dict[str, Dict[str, Any]]
data of a random search procedure as returned by the `load_data` method
in `mu_map.random_search.eval.util`
param_groups: ParameterGroups,
a parameter groups object defining all manifestations of a parameter as
well as how to access
measure: str,
the measure analysed
"""
tp = TablePrinter()
tp.color_formatter["Normal"] = lambda _str: colored(
_str, color="green" if _str.strip().lower() == "yes" else None
)
tp.color_formatter["Significant"] = lambda _str: colored(
_str, color="green" if _str.strip().lower() == "yes" else None
)
tp.color_formatter["P Value"] = lambda _str: colored(
_str, color="green" if float(_str) < 0.05 else None
)
ys = {}
for label, param_value in param_groups.groups.items():
_data = filter_by_params(data, param_value, param_groups.keys)
_ys = map(lambda run: _data[run]["measures"][measure].mean(), _data.keys())
ys[label] = np.array(list(_ys))
table = {"Label": [], "Normal": [], "P Value": [], "Stat": [], "Mean±Std": []}
for i, (label, _ys) in enumerate(ys.items()):
if len(_ys) < 3:
print(f"Cannot evaluate {label} because of too little data n={len(_ys)}")
continue
stat, p_value = stats.shapiro(_ys)
table["Label"].append(label)
table["Normal"].append("YES" if p_value < 0.05 else "NO")
table["Stat"].append(stat)
table["P Value"].append(p_value)
table["Mean±Std"].append(f"{_ys.mean():.5f}±{_ys.std():.5f}")
tp.print(table)
print()
table = {"Label 1": [], "Label 2": [], "Significant": [], "Stat": [], "P Value": []}
for label_1, label_2 in itertools.combinations(ys.keys(), 2):
stat, p_value = stats.ttest_ind(ys[label_1], ys[label_2])
table["Label 1"].append(label_1)
table["Label 2"].append(label_2)
table["Significant"].append("YES" if p_value < 0.05 else " NO")
table["Stat"].append(stat)
table["P Value"].append(p_value)
tp.print(table)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(
description="Compare different values for a parameter by plotting and performing statistical tests",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--random_search_dir",
type=str,
default="cgan_random_search",
help="the directory containing the random search data",
)
parser.add_argument(
"--outliers_file",
type=str,
default="outliers.csv",
help="optional file defining outliers / runs to be ignored for analysis",
)
parser.add_argument(
"--param",
choices=list(parameter_groups.keys()),
help="the parameter to analyse",
)
parser.add_argument(
"--measure",
choices=["NMAE", "MSE"],
default="NMAE",
help="the measure used for plotting and analysis",
)
parser.add_argument(
"--colors",
choices=list(color_lists.keys()),
default="printer_friendly",
help="chose which colors to use for plotting",
)
args = parser.parse_args()
data = load_data(args.random_search_dir)
if args.outliers_file:
data = remove_outliers(
data, os.path.join(args.random_search_dir, args.outliers_file)
)
param_groups = parameter_groups[args.param]
plot_title = " ".join(
map(lambda _str: _str[0].upper() + _str[1:], args.param.split("_"))
)
analyse_stats(data, param_groups, measure=args.measure)
plot_param_groups(
data,
param_groups,
measure=args.measure,
title=plot_title,
colors=color_lists[args.colors],
)
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment