Skip to content
Snippets Groups Projects
loss_curve.py 2.02 KiB
import argparse

import matplotlib.pyplot as plt
import numpy as np

from mu_map.logging import parse_file

# Global matplotlib styling for the figure produced below.
SIZE_DEFAULT = 12
plt.rc("font", family="Roboto")  # default font family (Roboto must be installed to take effect)
plt.rc("font", weight="normal")  # default font weight
plt.rc("font", size=SIZE_DEFAULT)  # default text size
plt.rc("axes", titlesize=18)  # font size of the axes title

# Line colors, one per plotted phase, from the ColorBrewer diverging RdBu scheme:
# https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3
COLORS = ["#ef8a62", "#67a9cf"]

parser = argparse.ArgumentParser(
    description="Plot the training and validation loss curves contained in a log file.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
    "logfile", type=str, help="path to the log file to read the losses from"
)
parser.add_argument(
    "--normalize",
    action="store_true",
    help="divide each loss curve by its maximum value",
)
args = parser.parse_args()

# Keep only INFO-level lines; the loss messages are presumably logged at that
# level (NOTE(review): confirm against the training code that writes the log).
logs = parse_file(args.logfile)
logs = [logline for logline in logs if logline.loglevel == "INFO"]

def parse_loss(logs, phase):
    """Extract the per-epoch loss values of one phase from parsed log lines.

    Only messages containing both ``phase`` and ``"Loss"`` are used.
    NOTE(review): the message layout is inferred from the parsing code and
    is presumably similar to ``"Epoch 3/100 - Loss TRAIN: 0.1234"``. Confirm
    this against the logger that writes these lines.

    Args:
        logs: iterable of log lines, each exposing a ``message`` string.
        phase: substring identifying the phase, e.g. ``"TRAIN"`` or ``"VAL"``.

    Returns:
        A tuple ``(epochs, losses)`` of numpy arrays in log order. Both are
        empty if no message matches.

    Raises:
        ValueError: if a matching message does not follow the expected layout.
    """
    messages = [
        logline.message
        for logline in logs
        if phase in logline.message and "Loss" in logline.message
    ]

    epochs = [_parse_epoch(message) for message in messages]
    losses = [_parse_loss_value(message) for message in messages]
    return np.array(epochs), np.array(losses)


def _parse_epoch(message):
    """Return the epoch number from the text before the first ``"-"``.

    The last word of that text is expected to look like ``"3/100"``, and
    the part before the ``"/"`` is the epoch.
    """
    head = message.split("-")[0].strip()
    return int(head.split(" ")[-1].split("/")[0])


def _parse_loss_value(message):
    """Return the loss value at the end of ``message`` as a float.

    The value is the text after the last ``":"``. Splitting the message on
    ``"-"`` instead would cut through the sign of a negative loss
    (``"-0.25"``) or the exponent of a value like ``"1.5e-05"``. Messages
    without any ``":"`` fall back to the text after the last ``"-"``.
    """
    if ":" in message:
        value = message.rpartition(":")[2]
    else:
        value = message.rpartition("-")[2]
    return float(value.strip())

# Phase names as they appear in the log messages and their legend labels;
# zip() pairs them positionally with COLORS.
phases = ["TRAIN", "VAL"]
labels = ["Training", "Validation"]

fig, ax = plt.subplots()
for phase, label, color in zip(phases, labels, COLORS):
    epochs, loss = parse_loss(logs, phase)

    # A log may contain no entries for a phase (e.g. when no validation was
    # run). Skip it rather than crash on ``loss.max()`` of an empty array or
    # add an empty entry to the legend.
    if loss.size == 0:
        print(f"No {phase} losses found in {args.logfile}, skipping this curve")
        continue

    if args.normalize:
        peak = loss.max()
        # Dividing by a zero maximum would fill the curve with NaN/inf values.
        if peak != 0:
            loss = loss / peak

    # Line plus markers so that individual epochs remain visible.
    ax.plot(epochs, loss, label=label, color=color)
    ax.scatter(epochs, loss, s=15, color=color)

# Minimal frame: keep only the bottom spine and horizontal grid lines.
ax.spines["left"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)

ax.grid(axis="y", alpha=0.7)
ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
plt.tight_layout()
plt.show()