"""loss_curve.py - plot training and validation loss curves parsed from a log file.

Author: Tamino Huxohl
"""
import argparse
import matplotlib.pyplot as plt
import numpy as np
from mu_map.logging import parse_file
# Global matplotlib styling applied to every figure drawn by this script.
SIZE_DEFAULT = 12  # base font size (points) for regular text

# Default font family, weight and size in one call. If Roboto is not
# installed, matplotlib falls back to another family with a warning.
plt.rc("font", family="Roboto", weight="normal", size=SIZE_DEFAULT)
# Axes titles are drawn larger than the body text.
plt.rc("axes", titlesize=18)

# Curve colours: the two outer colours of the 3-class diverging RdBu palette,
# https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3
COLORS = ["#ef8a62", "#67a9cf"]
# Command-line interface: one log file plus an optional normalisation flag.
parser = argparse.ArgumentParser(
    description="Plot training and validation loss curves from a log file.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "logfile",
    type=str,
    help="log file to parse; loss values are read from its INFO entries",
)
parser.add_argument(
    "--normalize",
    action="store_true",
    help="divide each loss curve by its maximum value",
)
args = parser.parse_args()

# Keep only INFO-level entries; loss lines are searched for among these.
logs = parse_file(args.logfile)
logs = [logline for logline in logs if logline.loglevel == "INFO"]
def parse_loss(logs, phase):
    """Extract per-epoch loss values for one phase from parsed log lines.

    Only messages containing both ``phase`` and ``"Loss"`` are used. Each such
    message is expected to look roughly like
    ``"Epoch 3/100 - TRAIN - Loss: 0.123"``:

    * the epoch is the ``N`` of the ``N/M`` token that ends the text before
      the first ``-``;
    * the loss is the number after the colon that follows the last ``"Loss"``.

    :param logs: iterable of log lines exposing a ``message`` string attribute
    :param phase: phase tag to select, e.g. ``"TRAIN"`` or ``"VAL"``
    :return: tuple ``(epochs, losses)`` of numpy arrays (integer epochs and
        float losses; both are empty if no message matches)
    """
    epochs = []
    losses = []
    for message in (logline.message for logline in logs):
        if phase not in message or "Loss" not in message:
            continue
        # Epoch: last whitespace token before the first "-", e.g. "3/100" -> 3.
        epoch_field = message.split("-", 1)[0].split()[-1]
        epochs.append(int(epoch_field.split("/")[0]))
        # Loss: read after "Loss" rather than after the last "-" so that
        # negative values ("-0.25") and scientific notation ("1e-05") survive.
        loss_field = message.rsplit("Loss", 1)[1].split(":")[-1]
        losses.append(float(loss_field.strip()))
    return np.array(epochs), np.array(losses)
# Draw one curve (line plus markers) per phase on a shared axis.
phases = ["TRAIN", "VAL"]
labels = ["Training", "Validation"]

fig, ax = plt.subplots()
for phase, label, color in zip(phases, labels, COLORS):
    epochs, loss = parse_loss(logs, phase)
    if len(loss) == 0:
        # Nothing logged for this phase: skip it instead of drawing an empty
        # legend entry (and instead of crashing in loss.max() below).
        print(f"Warning: no {phase} loss entries found in {args.logfile}")
        continue
    if args.normalize:
        # Divide the curve by its maximum value; skip the division when that
        # maximum is zero to avoid producing inf/nan.
        peak = loss.max()
        if peak != 0:
            loss = loss / peak
    ax.plot(epochs, loss, label=label, color=color)
    ax.scatter(epochs, loss, s=15, color=color)

# Minimal frame: keep only the bottom spine plus horizontal grid lines.
for side in ("left", "right", "top"):
    ax.spines[side].set_visible(False)
ax.grid(axis="y", alpha=0.7)
ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")

plt.tight_layout()
plt.show()