diff --git a/mu_map/eval/measures.py b/mu_map/eval/measures.py index f91c3e6e7cd450320af390b87729bf36b14e97c1..b128bf58e0b125a5112ddb4f426e874d96d15b74 100644 --- a/mu_map/eval/measures.py +++ b/mu_map/eval/measures.py @@ -42,6 +42,7 @@ if __name__ == "__main__": help="the model weights which should be scored", ) parser.add_argument("--out", type=str, help="write results as a csv file") + parser.add_argument("--scatter_corrected", action="store_true") parser.add_argument( "--dataset_dir", @@ -88,9 +89,10 @@ if __name__ == "__main__": ] ) dataset = MuMapDataset( - "data/initial/", + args.dataset_dir, transform_normalization=transform_normalization, split_name=args.split, + scatter_correction=args.scatter_corrected, ) measures = {"NMAE": nmae, "MSE": mse} diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index f578195465520da14ad42407e15d5cd326e7765f..2668175e27d1fa2a5de169417bc7fb48feed0e84 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -194,7 +194,7 @@ class cGANTraining: outputs_d_fake = self.discriminator(inputs_d_fake) loss_g_adv = self.criterion_d(outputs_d_fake, labels_real) loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real) - loss_g = loss_g_adv + 50.0 * loss_g_l1 + loss_g = loss_g_adv + 20.0 * loss_g_l1 loss_g.backward() self.optimizer_g.step() @@ -308,6 +308,11 @@ if __name__ == "__main__": action="store_true", help="do not shuffle patches in the dataset", ) + parser.add_argument( + "scatter_correction", + action="store_true", + help="use the scatter corrected reconstructions in the dataset", + ) # Training Args parser.add_argument( @@ -450,6 +455,7 @@ if __name__ == "__main__": shuffle=not args.no_shuffle, split_name=split, transform_normalization=transform_normalization, + scatter_correction=args.scatter_correction, logger=logger, ) data_loader = torch.utils.data.DataLoader(