From 4b3bf47271b4f3f3c646e0063f0d8aa90eef4daa Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Mon, 10 Oct 2022 14:03:57 +0200
Subject: [PATCH] add parameters to splot specific epochs

---
 mu_map/vis/loss_curve.py | 29 +++++++++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py
index dbc8618..c17bb84 100644
--- a/mu_map/vis/loss_curve.py
+++ b/mu_map/vis/loss_curve.py
@@ -29,13 +29,20 @@ def parse_loss(logs, loss_type):
     return np.array(list(epochs)), np.array(list(losses))
 
 
-def plot_loss(logfile, loss_types, ax, normalize=False):
+def plot_loss(logfile, loss_types, ax, normalize=False, from_epoch=0, to_epoch=None):
     logs = parse_file(logfile)
     logs = list(filter(lambda logline: logline.loglevel == "INFO", logs))
 
     for i, loss_type in enumerate(loss_types):
         epochs, loss = parse_loss(logs, loss_type)
 
+        if to_epoch == None:
+            to_epoch = max(epochs)
+
+        _filter = np.logical_and(from_epoch <= epochs, epochs <= to_epoch)
+        epochs = epochs[_filter]
+        loss = loss[_filter]
+
         if normalize:
             loss = loss / loss.max()
 
@@ -95,10 +102,28 @@ if __name__ == "__main__":
         action="store_true",
         help="do not only save the figure but also attempt to visualize it (opens a window)",
     )
+    parser.add_argument(
+        "--from_epoch",
+        type=int,
+        default=0,
+        help="start plotting from this epoch",
+    )
+    parser.add_argument(
+        "--to_epoch",
+        type=int,
+        help="only plot to this epoch",
+    )
     args = parser.parse_args()
 
     fig, ax = plt.subplots()
-    plot_loss(args.logfile, args.loss_types, ax, normalize=args.normalize)
+    plot_loss(
+        args.logfile,
+        args.loss_types,
+        ax,
+        normalize=args.normalize,
+        from_epoch=args.from_epoch,
+        to_epoch=args.to_epoch,
+    )
 
     plt.tight_layout()
     plt.savefig(args.out, dpi=300)
-- 
GitLab