diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py index 336b1ca949066b82662ff2628f64c883b2d973b8..8262b5e4d99018f1e92988765435eda08c2e173a 100644 --- a/mu_map/vis/loss_curve.py +++ b/mu_map/vis/loss_curve.py @@ -14,14 +14,38 @@ plt.rc("axes", titlesize=18) # fontsize of the axes title # https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3lk COLORS = ["#ef8a62", "#67a9cf"] -parser = argparse.ArgumentParser(description="TODO", formatter_class=argparse.ArgumentDefaultsHelpFormatter) -parser.add_argument("logfile", type=str, help="TODO") -parser.add_argument("--normalize", action="store_true", help="TODO") +parser = argparse.ArgumentParser( + description="plot the losses written to a logfile", formatter_class=argparse.ArgumentDefaultsHelpFormatter +) +parser.add_argument( + "logfile", + type=str, + help="the logfile from which the losses (training and validation) are parsed", +) +parser.add_argument( + "--normalize", + action="store_true", + help="normalize the loss values (both training and validation losses are normalized separately)", +) +parser.add_argument( + "--out", + "-o", + type=str, + default="loss.png", + help="the file into which the resulting plot is saved", +) +parser.add_argument( + "--verbose", + "-v", + action="store_true", + help="do not only save the figure but also attempt to visualize it (opens a window)", +) args = parser.parse_args() logs = parse_file(args.logfile) logs = list(filter(lambda logline: logline.loglevel == "INFO", logs)) + def parse_loss(logs, phase): _logs = map(lambda logline: logline.message, logs) _logs = filter(lambda log: phase in log, _logs) @@ -40,6 +64,7 @@ def parse_loss(logs, phase): return np.array(list(epochs)), np.array(list(losses)) + phases = ["TRAIN", "VAL"] labels = ["Training", "Validation"] @@ -62,5 +87,7 @@ ax.legend() ax.set_xlabel("Epoch") ax.set_ylabel("Loss") plt.tight_layout() -plt.show() +plt.savefig(args.out, dpi=300) +if args.verbose: + plt.show()