From 5d9d695e6de9688bb518cbb6f3cefffa791aa351 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 13 Jan 2023 10:18:47 +0100
Subject: [PATCH] update distance training to use early stopping

---
 mu_map/training/distance.py | 37 ++++++++++++++++++++++++++++++-------
 1 file changed, 30 insertions(+), 7 deletions(-)

diff --git a/mu_map/training/distance.py b/mu_map/training/distance.py
index 78c94a6..5e29323 100644
--- a/mu_map/training/distance.py
+++ b/mu_map/training/distance.py
@@ -12,7 +12,17 @@ class DistanceTraining(AbstractTraining):
     """
     Implementation of a distance training: a model predicts a mu map
     from a reconstruction by optimizing a distance loss (e.g. L1).
+
+    To see all parameters, have a look at AbstractTraining.
+
+    Parameters
+    ----------
+    params: TrainingParams
+        training parameters containing a model, a corresponding optimizer, and optionally a learning rate scheduler
+    loss_func: WeightedLoss
+        the distance loss function
     """
+
     def __init__(
         self,
         epochs: int,
@@ -23,13 +33,19 @@ class DistanceTraining(AbstractTraining):
         snapshot_epoch: int,
         params: TrainingParams,
         loss_func: WeightedLoss,
+        early_stopping: Optional[int] = None,
         logger: Optional[Logger] = None,
     ):
-        """
-        :param params: training parameters containing a model an according optimizer and optionally a learning rate scheduler
-        :param loss_func: the distance loss function
-        """
-        super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger)
+        super().__init__(
+            epochs=epochs,
+            dataset=dataset,
+            batch_size=batch_size,
+            device=device,
+            snapshot_dir=snapshot_dir,
+            snapshot_epoch=snapshot_epoch,
+            early_stopping=early_stopping,
+            logger=logger,
+        )
         self.training_params.append(params)
 
         self.loss_func = loss_func
@@ -140,6 +156,11 @@ if __name__ == "__main__":
         default=100,
         help="the number of epochs for which the model is trained",
     )
+    parser.add_argument(
+        "--early_stopping",
+        type=int,
+        help="define early stopping as the minimum number of epochs in which the validation loss must improve",
+    )
     parser.add_argument(
         "--device",
         type=str,
@@ -216,7 +237,6 @@ if __name__ == "__main__":
     torch.manual_seed(args.seed)
     np.random.seed(args.seed)
 
-
     transform_normalization = None
     if args.input_norm == "mean":
         transform_normalization = MeanNormTransform()
@@ -244,7 +264,9 @@ if __name__ == "__main__":
         if args.decay_lr
         else None
     )
-    params = TrainingParams(name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler)
+    params = TrainingParams(
+        name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
+    )
 
     criterion = WeightedLoss.from_str(args.loss_func)
 
@@ -257,6 +279,7 @@ if __name__ == "__main__":
         snapshot_epoch=args.snapshot_epoch,
         params=params,
         loss_func=criterion,
+        early_stopping=args.early_stopping,
         logger=logger,
     )
     training.run()
-- 
GitLab