From abe29b30ed49b65139df42ce5a484670a69d03e9 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 13 Oct 2022 15:03:42 +0200
Subject: [PATCH] use padcrop transform in cgan training

---
 mu_map/training/cgan.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index 21de59f..f578195 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -194,7 +194,7 @@ class cGANTraining:
             outputs_d_fake = self.discriminator(inputs_d_fake)
             loss_g_adv = self.criterion_d(outputs_d_fake, labels_real)
             loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real)
-            loss_g = loss_g_adv + 100.0 * loss_g_l1
+            loss_g = loss_g_adv + 50.0 * loss_g_l1
             loss_g.backward()
             self.optimizer_g.step()
 
@@ -252,7 +252,7 @@ if __name__ == "__main__":
         MaxNormTransform,
         GaussianNormTransform,
     )
-    from mu_map.dataset.transform import ScaleTransform
+    from mu_map.dataset.transform import PadCropTranform, SequenceTransform
     from mu_map.logging import add_logging_args, get_logger_by_args
     from mu_map.models.unet import UNet
     from mu_map.models.discriminator import Discriminator, PatchDiscriminator
@@ -438,6 +438,7 @@ if __name__ == "__main__":
         transform_normalization = MaxNormTransform()
     elif args.input_norm == "gaussian":
         transform_normalization = GaussianNormTransform()
+    transform_normalization = SequenceTransform([transform_normalization, PadCropTranform(dim=3, size=32)])
 
     data_loaders = {}
     for split in ["train", "validation"]:
-- 
GitLab