From afdee20bd29ef751cf69d03200a75aabc6b23d96 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Mon, 10 Oct 2022 11:12:38 +0200
Subject: [PATCH] make loss weights in cgan loss configurable

---
 mu_map/training/cgan.py | 49 ++++++++++++++++++++++++++++++++++-------
 1 file changed, 41 insertions(+), 8 deletions(-)

diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index 1716e26..25d9fc2 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -14,7 +14,11 @@ LABEL_FAKE = 0.0
 
 class GeneratorLoss(torch.nn.Module):
     def __init__(
-        self, l2_weight: float = 1.0, gdl_weight: float = 1.0, adv_weight: float = 20.0
+        self,
+        l2_weight: float = 1.0,
+        gdl_weight: float = 1.0,
+        adv_weight: float = 20.0,
+        logger=None,
     ):
         super().__init__()
 
@@ -27,6 +31,12 @@ class GeneratorLoss(torch.nn.Module):
         self.adv = torch.nn.MSELoss(reduction="mean")
         self.adv_weight = adv_weight
 
+        if logger:
+            logger.debug(f"GeneratorLoss: {self}")
+
+    def __repr__(self):
+        return f"{self.l2_weight:.3f} * MSELoss + {self.gdl_weight:.3f} * GDLLoss + {self.adv_weight:.3f} * AdversarialLoss"
+
     def forward(
         self,
         mu_maps_real: Tensor,
@@ -94,7 +104,10 @@ class cGANTraining:
 
         self.criterion_d = torch.nn.MSELoss(reduction="mean")
         self.criterion_g = GeneratorLoss(
-            l2_weight=l2_weight, gdl_weight=gdl_weight, adv_weight=adv_weight
+            l2_weight=l2_weight,
+            gdl_weight=gdl_weight,
+            adv_weight=adv_weight,
+            logger=self.logger,
         )
 
     def run(self):
@@ -174,7 +187,9 @@ class cGANTraining:
 
             # update generator
             outputs_d = self.discriminator(mu_maps_fake)
-            loss_g = self.criterion_g(mu_maps_real, mu_maps_fake, labels_real, outputs_d)
+            loss_g = self.criterion_g(
+                mu_maps_real, mu_maps_fake, labels_real, outputs_d
+            )
             loss_g.backward()
             self.optimizer_g.step()
 
@@ -189,7 +204,7 @@ class cGANTraining:
 
         loss = 0.0
         updates = 0
-        
+
         data_loader = self.data_loaders[split_name]
         for i, (recons, mu_maps) in enumerate(data_loader):
             print(
@@ -319,6 +334,24 @@ if __name__ == "__main__":
         default="cuda:0" if torch.cuda.is_available() else "cpu",
         help="the device (cpu or gpu) with which the training is performed",
     )
+    parser.add_argument(
+        "--mse_loss_weight",
+        type=float,
+        default=1.0,
+        help="weight for the L2-Loss of the generator",
+    )
+    parser.add_argument(
+        "--gdl_loss_weight",
+        type=float,
+        default=1.0,
+        help="weight for the Gradient-Difference-Loss of the generator",
+    )
+    parser.add_argument(
+        "--adv_loss_weight",
+        type=float,
+        default=20.0,
+        help="weight for the Adversarial-Loss of the generator",
+    )
     parser.add_argument(
         "--lr", type=float, default=0.001, help="the initial learning rate for training"
     )
@@ -372,7 +405,7 @@ if __name__ == "__main__":
     logger = get_logger_by_args(args)
     logger.info(args)
 
-    args.seed = args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1)
+    args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
     logger.info(f"Seed: {args.seed}")
     random.seed(args.seed)
     torch.manual_seed(args.seed)
@@ -425,9 +458,9 @@ if __name__ == "__main__":
         lr_g=0.001,
         lr_decay_factor_g=0.99,
         lr_decay_epoch_g=1,
-        l2_weight=0.25,
-        gdl_weight=0.25,
-        adv_weight=0.5,
+        l2_weight=args.mse_loss_weight,
+        gdl_weight=args.gdl_loss_weight,
+        adv_weight=args.adv_loss_weight,
         snapshot_dir=args.snapshot_dir,
         snapshot_epoch=args.snapshot_epoch,
         logger=logger,
-- 
GitLab