Skip to content
Snippets Groups Projects
Commit fc46a5cb authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

remove plot title, tight layout and make definitions of labels optional

parent 3157e7c6
No related branches found
No related tags found
No related merge requests found
......@@ -8,6 +8,10 @@ import numpy as np
from scipy import stats
from termcolor import colored
plt.rcParams.update({
"text.usetex": True,
})
from mu_map.dataset.normalization import (
GaussianNormTransform,
MeanNormTransform,
......@@ -41,6 +45,12 @@ class ParameterGroups:
groups: Dict[str, Any]
keys: List[str]
labels: List[str] = None
def get_labels(self):
if self.labels is None:
return list(self.groups.keys())
return self.labels
"""
......@@ -62,7 +72,7 @@ parameter_groups = {
keys=["discriminator_type", "discriminator_conv_features"],
),
"discriminator_type": ParameterGroups(
groups={"Class": "class", "PatchGAN": "patch"}, keys=["discriminator_type"]
groups={"Class": "class", "PatchGAN": "patch"}, keys=["discriminator_type"],
),
"distance_loss": ParameterGroups(
groups={
......@@ -70,6 +80,7 @@ parameter_groups = {
"L2 + GDL": WeightedLoss.from_str("L2+GDL"),
},
keys=["criterion_dist"],
labels=[r"$L_{1}$", r"$L_{2} + L_{GDL}$"],
),
"loss_weights": ParameterGroups(
groups={
......@@ -102,7 +113,6 @@ def plot_param_groups(
data: Dict[str, Dict[str, Any]],
param_groups: ParameterGroups,
measure: str = "NMAE",
title: Optional[str] = None,
colors: ColorList = color_lists["printer_friendly"],
):
"""
......@@ -121,8 +131,6 @@ def plot_param_groups(
well as how to access
measure: str,
the measure plotted
title: str, optional
the title given to the plot
colors: ColorList,
a color list object defining which colors to use for different parameter
manifestations
......@@ -132,9 +140,6 @@ def plot_param_groups(
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
if title is not None:
ax.set_title(title)
y_max = 0
for i, (label, value) in enumerate(param_groups.groups.items()):
_data = filter_by_params(data, value, param_groups.keys)
......@@ -159,9 +164,10 @@ def plot_param_groups(
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
xticks = list(range(1, len(param_groups.groups) + 1))
labels = param_groups.get_labels()
xticks = list(range(1, len(labels) + 1))
ax.set_xticks(xticks)
ax.set_xticklabels(list(param_groups.groups.keys()))
ax.set_xticklabels(labels)
ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5)
......@@ -238,6 +244,11 @@ if __name__ == "__main__":
description="Compare different values for a parameter by plotting and performing statistical tests",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"param",
choices=list(parameter_groups.keys()),
help="the parameter to analyse",
)
parser.add_argument(
"--random_search_dir",
type=str,
......@@ -250,11 +261,6 @@ if __name__ == "__main__":
default="outliers.csv",
help="optional file defining outliers / runs to be ignored for analysis",
)
parser.add_argument(
"--param",
choices=list(parameter_groups.keys()),
help="the parameter to analyse",
)
parser.add_argument(
"--save",
type=str,
......@@ -281,19 +287,15 @@ if __name__ == "__main__":
)
param_groups = parameter_groups[args.param]
plot_title = " ".join(
map(lambda _str: _str[0].upper() + _str[1:], args.param.split("_"))
)
analyse_stats(data, param_groups, measure=args.measure)
plot_param_groups(
data,
param_groups,
measure=args.measure,
title=plot_title,
colors=color_lists[args.colors],
)
plt.tight_layout()
if args.save:
plt.savefig(args.save, dpi=300)
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment