diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py
index dbc86189370b3d85402d8878b418587cbdae7ee1..c17bb846fb4089652292b29152b0dd49147090cb 100644
--- a/mu_map/vis/loss_curve.py
+++ b/mu_map/vis/loss_curve.py
@@ -29,13 +29,20 @@ def parse_loss(logs, loss_type):
     return np.array(list(epochs)), np.array(list(losses))
 
 
-def plot_loss(logfile, loss_types, ax, normalize=False):
+def plot_loss(logfile, loss_types, ax, normalize=False, from_epoch=0, to_epoch=None):
     logs = parse_file(logfile)
     logs = list(filter(lambda logline: logline.loglevel == "INFO", logs))
 
     for i, loss_type in enumerate(loss_types):
         epochs, loss = parse_loss(logs, loss_type)
 
+        if to_epoch == None:
+            to_epoch = max(epochs)
+
+        _filter = np.logical_and(from_epoch <= epochs, epochs <= to_epoch)
+        epochs = epochs[_filter]
+        loss = loss[_filter]
+
         if normalize:
             loss = loss / loss.max()
 
@@ -95,10 +102,28 @@ if __name__ == "__main__":
         action="store_true",
         help="do not only save the figure but also attempt to visualize it (opens a window)",
     )
+    parser.add_argument(
+        "--from_epoch",
+        type=int,
+        default=0,
+        help="start plotting from this epoch",
+    )
+    parser.add_argument(
+        "--to_epoch",
+        type=int,
+        help="only plot to this epoch",
+    )
     args = parser.parse_args()
 
     fig, ax = plt.subplots()
-    plot_loss(args.logfile, args.loss_types, ax, normalize=args.normalize)
+    plot_loss(
+        args.logfile,
+        args.loss_types,
+        ax,
+        normalize=args.normalize,
+        from_epoch=args.from_epoch,
+        to_epoch=args.to_epoch,
+    )
 
     plt.tight_layout()
     plt.savefig(args.out, dpi=300)