"""Plot per-epoch training and validation loss curves parsed from a log file."""

import argparse

import matplotlib.pyplot as plt
import numpy as np

from mu_map.logging import parse_file

SIZE_DEFAULT = 12
plt.rc("font", family="Roboto")  # controls default font
plt.rc("font", weight="normal")  # controls default font
plt.rc("font", size=SIZE_DEFAULT)  # controls default text sizes
plt.rc("axes", titlesize=18)  # fontsize of the axes title

# Outer two colors of the 3-class RdBu scheme:
# https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3
COLORS = ["#ef8a62", "#67a9cf"]

# Substring that marks a phase in a log message, and its legend label.
# These are zipped with COLORS, so all three lists must stay aligned.
PHASES = ["TRAIN", "VAL"]
LABELS = ["Training", "Validation"]


def _parse_epoch(message):
    """Return the epoch number from the leading ``... <epoch>/<total>`` field.

    The epoch is taken from the text before the first ``"-"``: its last
    space-separated token, up to the ``"/"``.
    """
    head = message.split("-")[0].strip()
    return int(head.split(" ")[-1].split("/")[0])


def _parse_loss_value(message):
    """Return the float that follows the last ``":"`` in ``message``."""
    # Split on the last ":" instead of the last "-". Splitting on "-" would
    # corrupt negative values ("-0.5" -> 0.5) and exponents ("1e-05" -> 5.0).
    return float(message.rsplit(":", 1)[-1].strip())


def parse_loss(logs, phase):
    """Extract per-epoch loss values for one phase from parsed log lines.

    Only messages containing both ``phase`` and ``"Loss"`` are used. Each
    such message is expected to look roughly like
    ``"... <epoch>/<total> - ... Loss: <value>"``. That format is inferred
    from the parsing here; confirm it against the trainer's log format.

    :param logs: iterable of parsed log lines exposing a ``message`` attribute
    :param phase: substring identifying the phase, e.g. ``"TRAIN"`` or ``"VAL"``
    :return: tuple ``(epochs, losses)`` of numpy arrays, in log order
    :raises ValueError: if a matching message does not contain a parsable
        epoch or loss value
    """
    messages = [
        logline.message
        for logline in logs
        if phase in logline.message and "Loss" in logline.message
    ]
    epochs = [_parse_epoch(message) for message in messages]
    losses = [_parse_loss_value(message) for message in messages]
    return np.array(epochs), np.array(losses)


def main():
    """Parse CLI arguments, read the log file and show the loss plot."""
    parser = argparse.ArgumentParser(
        description="Plot training and validation loss curves from a log file.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("logfile", type=str, help="path to the log file to parse")
    parser.add_argument(
        "--normalize",
        action="store_true",
        help="divide each loss curve by its maximum value",
    )
    args = parser.parse_args()

    logs = parse_file(args.logfile)
    logs = [logline for logline in logs if logline.loglevel == "INFO"]

    _, ax = plt.subplots()
    for phase, label, color in zip(PHASES, LABELS, COLORS):
        epochs, loss = parse_loss(logs, phase)
        # Guard against phases with no loss lines: max() of an empty array raises.
        if args.normalize and loss.size > 0:
            loss = loss / loss.max()
        ax.plot(epochs, loss, label=label, color=color)
        ax.scatter(epochs, loss, s=15, color=color)

    for side in ("left", "right", "top"):
        ax.spines[side].set_visible(False)
    ax.grid(axis="y", alpha=0.7)
    ax.legend()
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()