From b08f5d081d04fb984437c0e9eaad918b8c74878d Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 13 Jan 2023 10:15:29 +0100
Subject: [PATCH] add early stopping param to cgan training and update doc

---
 mu_map/training/cgan.py | 59 ++++++++++++++++++++++++++++++++---------
 1 file changed, 46 insertions(+), 13 deletions(-)
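
Note for reviewers (text between the `---` marker and the diff is ignored by
`git am`): a minimal sketch of how the new early_stopping parameter is meant
to be used. Only cGANTraining, GeneratorParams, DiscriminatorParams, and
early_stopping are taken from this patch; the `optimizer` keyword is assumed
from the TrainingParams docstring, and the models, dataset, and distance loss
are placeholders for illustration.

    import torch
    from mu_map.training.cgan import (
        DiscriminatorParams,
        GeneratorParams,
        cGANTraining,
    )

    # stand-in models; the real ones come from the mu_map package
    generator = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)
    discriminator = torch.nn.Conv3d(2, 1, kernel_size=3, padding=1)

    params_g = GeneratorParams(
        model=generator,
        optimizer=torch.optim.Adam(generator.parameters(), lr=2e-4),
        lr_scheduler=None,
    )
    params_d = DiscriminatorParams(
        model=discriminator,
        optimizer=torch.optim.Adam(discriminator.parameters(), lr=2e-4),
        lr_scheduler=None,
    )

    dataset = ...    # placeholder: the dataset of (recon, mu map) pairs
    loss_dist = ...  # placeholder: a WeightedLoss instance for the distance term

    training = cGANTraining(
        epochs=100,
        dataset=dataset,
        batch_size=8,
        device=torch.device("cuda:0"),
        snapshot_dir="snapshots",
        snapshot_epoch=10,
        params_generator=params_g,
        params_discriminator=params_d,
        loss_func_dist=loss_dist,
        weight_criterion_dist=100.0,
        weight_criterion_adv=1.0,
        # stop if the validation loss has not improved within the last 10 epochs
        early_stopping=10,
    )

On the command line, the same behavior is enabled with the new flag, e.g.
`--early_stopping 10`; omitting the flag leaves the value at None, which
presumably disables early stopping.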

diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index 8880c46..d5bad7d 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -17,6 +17,7 @@ class DiscriminatorParams(TrainingParams):
     """
     Wrap training parameters to always carry the name 'Discriminator'.
     """
+
     def __init__(
         self,
         model: torch.nn.Module,
@@ -30,10 +31,12 @@ class DiscriminatorParams(TrainingParams):
             lr_scheduler=lr_scheduler,
         )
 
+
 class GeneratorParams(TrainingParams):
     """
     Wrap training parameters to always carry the name 'Generator'.
     """
+
     def __init__(
         self,
         model: torch.nn.Module,
@@ -48,11 +51,26 @@ class GeneratorParams(TrainingParams):
         )
 
 
-
 class cGANTraining(AbstractTraining):
     """
     Implementation of a conditional generative adversarial network training.
+
+    Parameters shared with AbstractTraining, such as early_stopping, are documented there.
+
+    Parameters
+    ----------
+    params_generator: GeneratorParams
+        training parameters containing a model, a corresponding optimizer, and optionally a learning rate scheduler for the generator
+    params_discriminator: DiscriminatorParams
+        training parameters containing a model, a corresponding optimizer, and optionally a learning rate scheduler for the discriminator
+    loss_func_dist: WeightedLoss
+        distance loss function for the generator
+    weight_criterion_dist: float
+        weight of the distance loss when training the generator
+    weight_criterion_adv: float
+        weight of the adversarial loss when training the generator
     """
+
     def __init__(
         self,
         epochs: int,
@@ -66,17 +84,18 @@ class cGANTraining(AbstractTraining):
         loss_func_dist: WeightedLoss,
         weight_criterion_dist: float,
         weight_criterion_adv: float,
+        early_stopping: Optional[int],
         logger: Optional[Logger] = None,
     ):
-        """
-        :param params_generator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator
-        :param params_discriminator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator
-        :param loss_func_dist: distance loss function for the generator
-        :param weight_criterion_dist: weight of the distance loss when training the generator
-        :param weight_criterion_adv: weight of the adversarial loss when training the generator
-        """
         super().__init__(
-            epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
+            epochs=epochs,
+            dataset=dataset,
+            batch_size=batch_size,
+            device=device,
+            early_stopping=early_stopping,
+            snapshot_dir=snapshot_dir,
+            snapshot_epoch=snapshot_epoch,
+            logger=logger,
         )
         self.training_params.append(params_generator)
         self.training_params.append(params_discriminator)
@@ -102,7 +121,7 @@ class cGANTraining(AbstractTraining):
         pass
 
     def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
-        mu_maps_real = mu_maps # rename real mu maps for clarification
+        mu_maps_real = mu_maps  # rename real mu maps for clarity
         # compute fake mu maps with generator
         mu_maps_fake = self.generator(recons)
 
@@ -111,8 +130,16 @@ class cGANTraining(AbstractTraining):
         inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
 
         # prepare labels/targets for the discriminator
-        labels_fake = torch.full(self.discriminator.get_output_shape(inputs_d_fake.shape), LABEL_FAKE, device=self.device)
-        labels_real = torch.full(self.discriminator.get_output_shape(inputs_d_real.shape), LABEL_REAL, device=self.device)
+        labels_fake = torch.full(
+            self.discriminator.get_output_shape(inputs_d_fake.shape),
+            LABEL_FAKE,
+            device=self.device,
+        )
+        labels_real = torch.full(
+            self.discriminator.get_output_shape(inputs_d_real.shape),
+            LABEL_REAL,
+            device=self.device,
+        )
 
         # ======================= Discriminator =====================================
         # compute discriminator loss for fake mu maps
@@ -131,7 +158,7 @@ class cGANTraining(AbstractTraining):
         # ===========================================================================
 
         # ======================= Generator =========================================
-        outputs_d_fake = self.discriminator(inputs_d_fake) # this time no detach
+        outputs_d_fake = self.discriminator(inputs_d_fake)  # no detach this time: gradients flow to the generator
         loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real)
         loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real)
         loss_g = (
@@ -256,6 +283,11 @@ if __name__ == "__main__":
         default="cuda:0" if torch.cuda.is_available() else "cpu",
         help="the device (cpu or gpu) with which the training is performed",
     )
+    parser.add_argument(
+        "--early_stopping",
+        type=int,
+        help="define early stopping as the least amount of epochs in which the validation loss must improve",
+    )
     parser.add_argument(
         "--dist_loss_func",
         type=str,
@@ -400,6 +432,7 @@ if __name__ == "__main__":
         dataset=dataset,
         batch_size=args.batch_size,
         device=device,
+        early_stopping=args.early_stopping,
         snapshot_dir=args.snapshot_dir,
         snapshot_epoch=args.snapshot_epoch,
         params_generator=params_g,
-- 
GitLab