From abe29b30ed49b65139df42ce5a484670a69d03e9 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Thu, 13 Oct 2022 15:03:42 +0200 Subject: [PATCH] use padcrop transform in cgan training --- mu_map/training/cgan.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 21de59f..f578195 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -194,7 +194,7 @@ class cGANTraining: outputs_d_fake = self.discriminator(inputs_d_fake) loss_g_adv = self.criterion_d(outputs_d_fake, labels_real) loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real) - loss_g = loss_g_adv + 100.0 * loss_g_l1 + loss_g = loss_g_adv + 50.0 * loss_g_l1 loss_g.backward() self.optimizer_g.step() @@ -252,7 +252,7 @@ if __name__ == "__main__": MaxNormTransform, GaussianNormTransform, ) - from mu_map.dataset.transform import ScaleTransform + from mu_map.dataset.transform import PadCropTranform, SequenceTransform from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.models.unet import UNet from mu_map.models.discriminator import Discriminator, PatchDiscriminator @@ -438,6 +438,7 @@ if __name__ == "__main__": transform_normalization = MaxNormTransform() elif args.input_norm == "gaussian": transform_normalization = GaussianNormTransform() + transform_normalization = SequenceTransform([transform_normalization, PadCropTranform(dim=3, size=32)]) data_loaders = {} for split in ["train", "validation"]: -- GitLab