diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 21de59f985a6363e7c6c1073490c9ab9bf30849c..f578195465520da14ad42407e15d5cd326e7765f 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -194,7 +194,7 @@ class cGANTraining: outputs_d_fake = self.discriminator(inputs_d_fake) loss_g_adv = self.criterion_d(outputs_d_fake, labels_real) loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real) - loss_g = loss_g_adv + 100.0 * loss_g_l1 + loss_g = loss_g_adv + 50.0 * loss_g_l1 loss_g.backward() self.optimizer_g.step() @@ -252,7 +252,7 @@ if __name__ == "__main__": MaxNormTransform, GaussianNormTransform, ) - from mu_map.dataset.transform import ScaleTransform + from mu_map.dataset.transform import PadCropTranform, SequenceTransform from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.models.unet import UNet from mu_map.models.discriminator import Discriminator, PatchDiscriminator @@ -438,6 +438,7 @@ if __name__ == "__main__": transform_normalization = MaxNormTransform() elif args.input_norm == "gaussian": transform_normalization = GaussianNormTransform() + transform_normalization = SequenceTransform([transform_normalization, PadCropTranform(dim=3, size=32)]) data_loaders = {} for split in ["train", "validation"]: