diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py index dbc86189370b3d85402d8878b418587cbdae7ee1..c17bb846fb4089652292b29152b0dd49147090cb 100644 --- a/mu_map/vis/loss_curve.py +++ b/mu_map/vis/loss_curve.py @@ -29,13 +29,20 @@ def parse_loss(logs, loss_type): return np.array(list(epochs)), np.array(list(losses)) -def plot_loss(logfile, loss_types, ax, normalize=False): +def plot_loss(logfile, loss_types, ax, normalize=False, from_epoch=0, to_epoch=None): logs = parse_file(logfile) logs = list(filter(lambda logline: logline.loglevel == "INFO", logs)) for i, loss_type in enumerate(loss_types): epochs, loss = parse_loss(logs, loss_type) + if to_epoch == None: + to_epoch = max(epochs) + + _filter = np.logical_and(from_epoch <= epochs, epochs <= to_epoch) + epochs = epochs[_filter] + loss = loss[_filter] + if normalize: loss = loss / loss.max() @@ -95,10 +102,28 @@ if __name__ == "__main__": action="store_true", help="do not only save the figure but also attempt to visualize it (opens a window)", ) + parser.add_argument( + "--from_epoch", + type=int, + default=0, + help="start plotting from this epoch", + ) + parser.add_argument( + "--to_epoch", + type=int, + help="only plot to this epoch", + ) args = parser.parse_args() fig, ax = plt.subplots() - plot_loss(args.logfile, args.loss_types, ax, normalize=args.normalize) + plot_loss( + args.logfile, + args.loss_types, + ax, + normalize=args.normalize, + from_epoch=args.from_epoch, + to_epoch=args.to_epoch, + ) plt.tight_layout() plt.savefig(args.out, dpi=300)