From 4b3bf47271b4f3f3c646e0063f0d8aa90eef4daa Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Mon, 10 Oct 2022 14:03:57 +0200 Subject: [PATCH] add parameters to splot specific epochs --- mu_map/vis/loss_curve.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py index dbc8618..c17bb84 100644 --- a/mu_map/vis/loss_curve.py +++ b/mu_map/vis/loss_curve.py @@ -29,13 +29,20 @@ def parse_loss(logs, loss_type): return np.array(list(epochs)), np.array(list(losses)) -def plot_loss(logfile, loss_types, ax, normalize=False): +def plot_loss(logfile, loss_types, ax, normalize=False, from_epoch=0, to_epoch=None): logs = parse_file(logfile) logs = list(filter(lambda logline: logline.loglevel == "INFO", logs)) for i, loss_type in enumerate(loss_types): epochs, loss = parse_loss(logs, loss_type) + if to_epoch == None: + to_epoch = max(epochs) + + _filter = np.logical_and(from_epoch <= epochs, epochs <= to_epoch) + epochs = epochs[_filter] + loss = loss[_filter] + if normalize: loss = loss / loss.max() @@ -95,10 +102,28 @@ if __name__ == "__main__": action="store_true", help="do not only save the figure but also attempt to visualize it (opens a window)", ) + parser.add_argument( + "--from_epoch", + type=int, + default=0, + help="start plotting from this epoch", + ) + parser.add_argument( + "--to_epoch", + type=int, + help="only plot to this epoch", + ) args = parser.parse_args() fig, ax = plt.subplots() - plot_loss(args.logfile, args.loss_types, ax, normalize=args.normalize) + plot_loss( + args.logfile, + args.loss_types, + ax, + normalize=args.normalize, + from_epoch=args.from_epoch, + to_epoch=args.to_epoch, + ) plt.tight_layout() plt.savefig(args.out, dpi=300) -- GitLab