Skip to content
Snippets Groups Projects
Commit b68b2c72 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

restructure loss curve so that it may be called from another module

parent 2dbb2527
No related branches found
No related tags found
No related merge requests found
...@@ -5,54 +5,10 @@ import numpy as np ...@@ -5,54 +5,10 @@ import numpy as np
from mu_map.logging import parse_file from mu_map.logging import parse_file
SIZE_DEFAULT = 12
plt.rc("font", family="Roboto") # controls default font
plt.rc("font", weight="normal") # controls default font
plt.rc("font", size=SIZE_DEFAULT) # controls default text sizes
plt.rc("axes", titlesize=18) # fontsize of the axes title
# https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3lk # https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3lk
COLORS = ["#ef8a62", "#67a9cf"] COLORS = ["#ef8a62", "#67a9cf"]
parser = argparse.ArgumentParser(
description="plot the losses written to a logfile",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"logfile",
type=str,
help="the logfile from which the losses (training and validation) are parsed",
)
parser.add_argument(
"--loss_types",
type=str,
nargs="+",
default=["train", "validation"],
help="the types of losses the log is searched for",
)
parser.add_argument(
"--normalize",
action="store_true",
help="normalize the loss values (both training and validation losses are normalized separately)",
)
parser.add_argument(
"--out",
"-o",
type=str,
default="loss.png",
help="the file into which the resulting plot is saved",
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="do not only save the figure but also attempt to visualize it (opens a window)",
)
args = parser.parse_args()
logs = parse_file(args.logfile)
logs = list(filter(lambda logline: logline.loglevel == "INFO", logs))
def parse_loss(logs, loss_type): def parse_loss(logs, loss_type):
_logs = map(lambda logline: logline.message, logs) _logs = map(lambda logline: logline.message, logs)
...@@ -73,28 +29,79 @@ def parse_loss(logs, loss_type): ...@@ -73,28 +29,79 @@ def parse_loss(logs, loss_type):
return np.array(list(epochs)), np.array(list(losses)) return np.array(list(epochs)), np.array(list(losses))
fig, ax = plt.subplots() def plot_loss(logfile, loss_types, ax, normalize=False):
for i, loss_type in enumerate(args.loss_types): logs = parse_file(logfile)
epochs, loss = parse_loss(logs, loss_type) logs = list(filter(lambda logline: logline.loglevel == "INFO", logs))
if args.normalize: for i, loss_type in enumerate(loss_types):
loss = loss / loss.max() epochs, loss = parse_loss(logs, loss_type)
label = loss_type[0].upper() + loss_type[1:] if normalize:
color = COLORS[i % len(COLORS)] loss = loss / loss.max()
ax.plot(epochs, loss, label=label, color=color)
ax.scatter(epochs, loss, s=15, color=color) label = loss_type[0].upper() + loss_type[1:]
color = COLORS[i % len(COLORS)]
ax.spines["left"].set_visible(False) ax.plot(epochs, loss, label=label, color=color)
ax.spines["right"].set_visible(False) ax.scatter(epochs, loss, s=15, color=color)
ax.spines["top"].set_visible(False)
ax.spines["left"].set_visible(False)
ax.grid(axis="y", alpha=0.7) ax.spines["right"].set_visible(False)
ax.legend() ax.spines["top"].set_visible(False)
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss") ax.grid(axis="y", alpha=0.7)
plt.tight_layout() ax.legend()
plt.savefig(args.out, dpi=300) ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
if args.verbose:
plt.show()
if __name__ == "__main__":
SIZE_DEFAULT = 12
plt.rc("font", family="Roboto") # controls default font
plt.rc("font", weight="normal") # controls default font
plt.rc("font", size=SIZE_DEFAULT) # controls default text sizes
plt.rc("axes", titlesize=18) # fontsize of the axes title
parser = argparse.ArgumentParser(
description="plot the losses written to a logfile",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"logfile",
type=str,
help="the logfile from which the losses (training and validation) are parsed",
)
parser.add_argument(
"--loss_types",
type=str,
nargs="+",
default=["train", "validation"],
help="the types of losses the log is searched for",
)
parser.add_argument(
"--normalize",
action="store_true",
help="normalize the loss values (both training and validation losses are normalized separately)",
)
parser.add_argument(
"--out",
"-o",
type=str,
default="loss.png",
help="the file into which the resulting plot is saved",
)
parser.add_argument(
"--verbose",
"-v",
action="store_true",
help="do not only save the figure but also attempt to visualize it (opens a window)",
)
args = parser.parse_args()
fig, ax = plt.subplots()
plot_loss(args.logfile, args.loss_types, ax, normalize=args.normalize)
plt.tight_layout()
plt.savefig(args.out, dpi=300)
if args.verbose:
plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment