diff --git a/mu_map/eval/measures.py b/mu_map/eval/measures.py
index f91c3e6e7cd450320af390b87729bf36b14e97c1..b128bf58e0b125a5112ddb4f426e874d96d15b74 100644
--- a/mu_map/eval/measures.py
+++ b/mu_map/eval/measures.py
@@ -42,6 +42,7 @@ if __name__ == "__main__":
         help="the model weights which should be scored",
     )
     parser.add_argument("--out", type=str, help="write results as a csv file")
+    parser.add_argument("--scatter_corrected", action="store_true")
 
     parser.add_argument(
         "--dataset_dir",
@@ -88,9 +89,10 @@ if __name__ == "__main__":
         ]
     )
     dataset = MuMapDataset(
-        "data/initial/",
+        args.dataset_dir,
         transform_normalization=transform_normalization,
         split_name=args.split,
+        scatter_correction=args.scatter_corrected,
     )
 
     measures = {"NMAE": nmae, "MSE": mse}
diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index f578195465520da14ad42407e15d5cd326e7765f..2668175e27d1fa2a5de169417bc7fb48feed0e84 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -194,7 +194,7 @@ class cGANTraining:
             outputs_d_fake = self.discriminator(inputs_d_fake)
             loss_g_adv = self.criterion_d(outputs_d_fake, labels_real)
             loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real)
-            loss_g = loss_g_adv + 50.0 * loss_g_l1
+            loss_g = loss_g_adv + 20.0 * loss_g_l1
             loss_g.backward()
             self.optimizer_g.step()
 
@@ -308,6 +308,11 @@ if __name__ == "__main__":
         action="store_true",
         help="do not shuffle patches in the dataset",
     )
+    parser.add_argument(
+        "scatter_correction",
+        action="store_true",
+        help="use the scatter corrected reconstructions in the dataset",
+    )
 
     # Training Args
     parser.add_argument(
@@ -450,6 +455,7 @@ if __name__ == "__main__":
             shuffle=not args.no_shuffle,
             split_name=split,
             transform_normalization=transform_normalization,
+            scatter_correction=args.scatter_correction,
             logger=logger,
         )
         data_loader = torch.utils.data.DataLoader(