Newer
Older
import argparse
import matplotlib.pyplot as plt
import numpy as np
from mu_map.logging import parse_file
# Curve colours taken from the ColorBrewer "RdBu" diverging scheme:
# https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3
# plot_loss cycles through them in order (loss type i gets COLORS[i % len(COLORS)]).
COLORS = ["#ef8a62", "#67a9cf"]
def parse_loss(logs, loss_type):
    """Extract the epochs and loss values of one loss type from parsed log lines.

    A line is used if its ``message`` contains both ``loss_type`` and
    ``"Loss"``. The parser expects messages shaped like
    ``"Epoch 3/100 - train Loss: 0.123"``:

    * the epoch is the last whitespace-separated token of the text before the
      first ``"-"``, cut at ``"/"`` (``"Epoch 3/100 "`` -> ``3``);
    * the loss is the text after the last ``":"`` (``" 0.123"`` -> ``0.123``).

    :param logs: iterable of parsed log lines exposing a ``message`` attribute
    :param loss_type: substring that identifies the loss (e.g. "train")
    :return: tuple ``(epochs, losses)`` of numpy arrays (int and float);
        both are empty if no line matches
    :raises ValueError: if a matching line has an unparsable epoch or loss
    """
    messages = [
        logline.message
        for logline in logs
        if loss_type in logline.message and "Loss" in logline.message
    ]

    epochs = []
    losses = []
    for message in messages:
        prefix = message.split("-")[0].strip()
        epochs.append(int(prefix.split(" ")[-1].split("/")[0]))
        # Do not split on "-" before reading the value: a minus sign
        # ("-0.5") or an exponent ("1.5e-05") would be cut off and the
        # remainder silently misparsed (as 0.5 and 5.0 respectively).
        losses.append(float(message.split(":")[-1].strip()))
    return np.array(epochs), np.array(losses)
def plot_loss(
    logfile,
    loss_types,
    ax,
    normalize=False,
    from_epoch=0,
    to_epoch=None,
    start_idx_message=3,
):
    """Plot one curve per loss type, parsed from ``logfile``, onto ``ax``.

    Only log lines with level ``"INFO"`` are considered. Each loss type is
    extracted with ``parse_loss``, restricted to the inclusive epoch range
    ``[from_epoch, to_epoch]``, and drawn as a line plus scatter markers,
    cycling through ``COLORS``.

    :param logfile: path of the log file, passed to ``parse_file``
    :param loss_types: iterable of loss-type substrings (e.g. "train", "validation")
    :param ax: matplotlib axes to draw on
    :param normalize: if True, divide each curve by its largest absolute value
        (curves that are empty or all zero are left unchanged)
    :param from_epoch: first epoch to plot (inclusive)
    :param to_epoch: last epoch to plot (inclusive); ``None`` means no upper limit
    :param start_idx_message: forwarded unchanged to ``parse_file``
    :raises ValueError: if a requested loss type does not occur in the log
    """
    logs = parse_file(logfile, start_idx_message=start_idx_message)
    logs = [logline for logline in logs if logline.loglevel == "INFO"]

    for i, loss_type in enumerate(loss_types):
        epochs, loss = parse_loss(logs, loss_type)
        if epochs.size == 0:
            raise ValueError(f"No {loss_type!r} losses found in {logfile}")

        # The epoch window is evaluated per loss type, and ``to_epoch`` is
        # never overwritten: reusing the first curve's last epoch as the
        # bound would clip later curves that run for more epochs.
        in_range = epochs >= from_epoch
        if to_epoch is not None:
            in_range &= epochs <= to_epoch
        epochs, loss = epochs[in_range], loss[in_range]

        if normalize and loss.size:
            # Largest absolute value, so negative losses keep their sign;
            # skip all-zero curves to avoid dividing by zero.
            scale = np.abs(loss).max()
            if scale > 0:
                loss = loss / scale

        label = loss_type[:1].upper() + loss_type[1:]
        color = COLORS[i % len(COLORS)]
        ax.plot(epochs, loss, label=label, color=color)
        ax.scatter(epochs, loss, s=15, color=color)

    # Axis cosmetics and the legend are applied once, after all curves exist.
    for side in ("left", "right", "top"):
        ax.spines[side].set_visible(False)
    ax.grid(axis="y", alpha=0.7)
    ax.legend()
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
if __name__ == "__main__":
    SIZE_DEFAULT = 12
    plt.rc("font", family="Roboto")  # default font family
    plt.rc("font", weight="normal")  # default font weight
    plt.rc("font", size=SIZE_DEFAULT)  # default text size
    plt.rc("axes", titlesize=18)  # font size of the axes title

    parser = argparse.ArgumentParser(
        description="plot the losses written to a logfile",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "logfile",
        type=str,
        help="the logfile from which the losses (training and validation) are parsed",
    )
    parser.add_argument(
        "--loss_types",
        type=str,
        nargs="+",
        default=["train", "validation"],
        help="the types of losses the log is searched for",
    )
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="normalize the loss values (each loss type is normalized separately)",
    )
    parser.add_argument(
        "--out",
        "-o",
        type=str,
        default="loss.png",
        help="the file into which the resulting plot is saved",
    )
    parser.add_argument(
        "--verbose",
        "-v",
        action="store_true",
        help="do not only save the figure but also attempt to visualize it (opens a window)",
    )
    parser.add_argument(
        "--from_epoch",
        type=int,
        default=0,
        help="start plotting from this epoch",
    )
    parser.add_argument(
        "--to_epoch",
        type=int,
        help="only plot to this epoch",
    )
    args = parser.parse_args()

    # An inverted range would otherwise silently yield an empty plot.
    if args.to_epoch is not None and args.to_epoch < args.from_epoch:
        parser.error(
            f"--to_epoch ({args.to_epoch}) must not be smaller than "
            f"--from_epoch ({args.from_epoch})"
        )

    _, ax = plt.subplots()
    plot_loss(
        args.logfile,
        args.loss_types,
        ax,
        normalize=args.normalize,
        from_epoch=args.from_epoch,
        to_epoch=args.to_epoch,
    )
    plt.tight_layout()
    plt.savefig(args.out, dpi=300)
    if args.verbose:
        plt.show()