diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py index 8262b5e4d99018f1e92988765435eda08c2e173a..71585f315c3d731e3d0ea521da938a358b401880 100644 --- a/mu_map/vis/loss_curve.py +++ b/mu_map/vis/loss_curve.py @@ -15,13 +15,21 @@ plt.rc("axes", titlesize=18) # fontsize of the axes title COLORS = ["#ef8a62", "#67a9cf"] parser = argparse.ArgumentParser( - description="plot the losses written to a logfile", formatter_class=argparse.ArgumentDefaultsHelpFormatter + description="plot the losses written to a logfile", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "logfile", type=str, help="the logfile from which the losses (training and validation) are parsed", ) +parser.add_argument( + "--loss_types", + type=str, + nargs="+", + default=["train", "validation"], + help="the types of losses the log is searched for", +) parser.add_argument( "--normalize", action="store_true", @@ -46,9 +54,9 @@ logs = parse_file(args.logfile) logs = list(filter(lambda logline: logline.loglevel == "INFO", logs)) -def parse_loss(logs, phase): +def parse_loss(logs, loss_type): _logs = map(lambda logline: logline.message, logs) - _logs = filter(lambda log: phase in log, _logs) + _logs = filter(lambda log: loss_type in log, _logs) _logs = filter(lambda log: "Loss" in log, _logs) _logs = list(_logs) @@ -65,16 +73,15 @@ def parse_loss(logs, phase): return np.array(list(epochs)), np.array(list(losses)) -phases = ["TRAIN", "VAL"] -labels = ["Training", "Validation"] - fig, ax = plt.subplots() -for phase, label, color in zip(phases, labels, COLORS): - epochs, loss = parse_loss(logs, phase) +for i, loss_type in enumerate(args.loss_types): + epochs, loss = parse_loss(logs, loss_type) if args.normalize: loss = loss / loss.max() + label = loss_type[0].upper() + loss_type[1:] + color = COLORS[i % len(COLORS)] ax.plot(epochs, loss, label=label, color=color) ax.scatter(epochs, loss, s=15, color=color)