Skip to content
Snippets Groups Projects
Commit efb1b34a authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

mu_map/random_search/eval/params.py allows the evaluation of multiple params with the same call

parent e6dca898
No related branches found
No related tags found
No related merge requests found
......@@ -7,6 +7,7 @@ import itertools
import os
from typing import Any, Callable, Dict, List, Optional
import matplotlib as mlp
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats
......@@ -121,6 +122,7 @@ def plot_param_groups(
param_groups: ParameterGroups,
measure: str = "NMAE",
colors: ColorList = color_lists["printer_friendly"],
ax: mlp.axes.Axes = None,
):
"""
Create a plot to visually compare all manifestations of a parameter.
......@@ -142,10 +144,11 @@ def plot_param_groups(
a color list object defining which colors to use for different parameter
manifestations
"""
fig_width = 6 + len(param_groups.groups) - 2
fig_height = 4
if ax is None:
fig_width = 6 + len(param_groups.groups) - 2
fig_height = 4
fig, ax = plt.subplots(figsize=(fig_width, fig_height))
_, ax = plt.subplots(figsize=(fig_width, fig_height))
y_max = 0
for i, (label, value) in enumerate(param_groups.groups.items()):
......@@ -254,6 +257,7 @@ if __name__ == "__main__":
parser.add_argument(
"param",
choices=list(parameter_groups.keys()),
nargs="+",
help="the parameter to analyse",
)
parser.add_argument(
......@@ -293,16 +297,38 @@ if __name__ == "__main__":
data, os.path.join(args.random_search_dir, args.outliers_file)
)
param_groups = parameter_groups[args.param]
analyse_stats(data, param_groups, measure=args.measure)
plot_param_groups(
data,
param_groups,
measure=args.measure,
colors=color_lists[args.colors],
)
axs = [None]
if len(args.param) > 1:
_, axs = plt.subplots(1, len(args.param), figsize=(1 + len(args.param) * 4, 4))
print(type(axs[0]))
for i, (_param, ax) in enumerate(zip(args.param, axs)):
title = " ".join(map(lambda s: s[0].upper() + s[1:], _param.split("_")))
param_groups = parameter_groups[_param]
print()
print(title)
analyse_stats(data, param_groups, measure=args.measure)
plot_param_groups(
data,
param_groups,
measure=args.measure,
colors=color_lists[args.colors],
ax=ax,
)
if len(args.param) > 1:
ax.set_title(title)
if i == 0:
continue
ax.set_ylabel("")
ax.set_yticklabels([])
plt.tight_layout()
plt.subplots_adjust(wspace=0.15)
if args.save:
plt.savefig(args.save, dpi=300)
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment