diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index 05de5caee053aa089ded8f1551e3b8c146814b23..289fb248210ba8894f05cfb7c3910e0eac3e5730 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -60,7 +60,8 @@ class Discriminator(nn.Module):
             nn.Linear(in_features=512, out_features=128),
             nn.ReLU(inplace=True),
             nn.Linear(in_features=128, out_features=1),
-            nn.Sigmoid(),
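+            # removed: the adversarial criterion is an MSE loss on raw outputs (LSGAN-style), so no sigmoid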
+            # nn.Sigmoid(),
         )
 
     def forward(self, x: torch.Tensor):
diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index 6352b21c27efe0bf8af81a78cb108aec329ca6fb..3aa83ec2d64990ffe46b7e5593402517648b1cff 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -15,14 +15,15 @@ LABEL_FAKE = 0.0
 class GeneratorLoss(torch.nn.Module):
     def __init__(
         self,
         l2_weight: float = 1.0,
         gdl_weight: float = 1.0,
         adv_weight: float = 20.0,
         logger=None,
     ):
         super().__init__()
 
-        self.l2 = torch.nn.MSELoss(reduction="mean")
+        # self.l2 = torch.nn.MSELoss(reduction="mean")
+        self.l2 = torch.nn.L1Loss(reduction="mean")  # L1 in place of MSE; tends to blur less (cf. pix2pix)
         self.l2_weight = l2_weight
 
         self.gdl = GradientDifferenceLoss()
@@ -88,19 +89,21 @@ class cGANTraining:
 
         self.logger = logger if logger is not None else get_logger()
 
-        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)
-        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g)
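+        # beta1=0.5 rather than the Adam default of 0.9, the usual GAN choice (DCGAN/pix2pix)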
+        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))
+        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
 
-        self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR(
-            self.optimizer_d,
-            step_size=lr_decay_epoch_d,
-            gamma=lr_decay_factor_d,
-        )
-        self.lr_scheduler_g = torch.optim.lr_scheduler.StepLR(
-            self.optimizer_g,
-            step_size=lr_decay_epoch_g,
-            gamma=lr_decay_factor_g,
-        )
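+        # step-decay schedulers disabled for now, so lr_d and lr_g stay constant over training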
+        # self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR(
+        #     self.optimizer_d,
+        #     step_size=lr_decay_epoch_d,
+        #     gamma=lr_decay_factor_d,
+        # )
+        # self.lr_scheduler_g = torch.optim.lr_scheduler.StepLR(
+        #     self.optimizer_g,
+        #     step_size=lr_decay_epoch_g,
+        #     gamma=lr_decay_factor_g,
+        # )
 
         self.criterion_d = torch.nn.MSELoss(reduction="mean")
         self.criterion_g = GeneratorLoss(
@@ -109,6 +110,8 @@ class cGANTraining:
             adv_weight=adv_weight,
             logger=self.logger,
         )
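+        # separate L1 criterion for the pix2pix-style generator objective in _step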
+        self.criterion_l1 = torch.nn.L1Loss(reduction="mean")
 
     def run(self):
         losses_d = []
@@ -124,8 +126,8 @@ class cGANTraining:
             self._eval_epoch(epoch, "train")
             self._eval_epoch(epoch, "validation")
 
-            self.lr_scheduler_d.step()
-            self.lr_scheduler_g.step()
+            # self.lr_scheduler_d.step()
+            # self.lr_scheduler_g.step()
 
             if epoch % self.snapshot_epoch == 0:
                 self.store_snapshot(epoch)
@@ -162,34 +164,39 @@ class cGANTraining:
 
     def _step(self, recons, mu_maps_real):
         batch_size = recons.shape[0]
-
-        self.optimizer_d.zero_grad()
-        self.optimizer_g.zero_grad()
-
         labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device)
         labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device)
 
         with torch.set_grad_enabled(True):
+            self.optimizer_d.zero_grad()
+            self.optimizer_g.zero_grad()
+
             # compute fake mu maps with generator
             mu_maps_fake = self.generator(recons)
 
-            # update discriminator based on real mu maps
-            outputs_d = self.discriminator(mu_maps_real)
-            loss_d_real = self.criterion_d(outputs_d, labels_real)
-            loss_d_real.backward()  # compute gradients
-            # update discriminator based on fake mu maps
-            outputs_d = self.discriminator(
-                mu_maps_fake.detach()
-            )  # note the detach, so that gradients are not computed for the generator
-            loss_d_fake = self.criterion_d(outputs_d, labels_fake)
-            loss_d_fake.backward()  # compute gradients
-            self.optimizer_d.step()  # update discriminator based on gradients
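+            # condition the discriminator on the input: it scores (recon, mu_map) pairs stacked channel-wise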
+            # compute discriminator loss for fake mu maps
+            inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
+            outputs_d_fake = self.discriminator(inputs_d_fake.detach())  # detach: no gradients flow into the generator
+            loss_d_fake = self.criterion_d(outputs_d_fake, labels_fake)
+
+            # compute discriminator loss for real mu maps
+            inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
+            outputs_d_real = self.discriminator(inputs_d_real)  # no detach needed: real mu maps carry no generator graph
+            loss_d_real = self.criterion_d(outputs_d_real, labels_real)
+
+            # update discriminator
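+            # the 0.5 factor slows the discriminator relative to the generator (pix2pix heuristic)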
+            loss_d = 0.5 * (loss_d_fake + loss_d_real)
+            loss_d.backward()  # compute gradients
+            self.optimizer_d.step()
 
             # update generator
-            outputs_d = self.discriminator(mu_maps_fake)
-            loss_g = self.criterion_g(
-                mu_maps_real, mu_maps_fake, labels_real, outputs_d
-            )
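+            # adversarial term: push D's output on fakes toward the real label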
+            outputs_d_fake = self.discriminator(inputs_d_fake)  # reuse the non-detached (recon, fake) pair from above
+            loss_g_adv = self.criterion_d(outputs_d_fake, labels_real)
+            loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real)
+            loss_g = loss_g_adv + 100.0 * loss_g_l1  # 100.0 is the pix2pix default L1 weight; criterion_g is bypassed here
             loss_g.backward()
             self.optimizer_g.step()
 
@@ -416,7 +421,8 @@ if __name__ == "__main__":
     torch.manual_seed(args.seed)
     np.random.seed(args.seed)
 
-    discriminator = Discriminator(in_channels=1, input_size=args.patch_size)
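+    # two input channels: the conditional discriminator sees recon and mu map stacked along dim=1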
+    discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
     discriminator = discriminator.to(device)
 
     generator = UNet(in_channels=1, features=args.features)
@@ -460,10 +465,10 @@ if __name__ == "__main__":
         data_loaders=data_loaders,
         epochs=args.epochs,
         device=device,
-        lr_d=0.0005,
+        lr_d=0.0002,  # 2e-4, the common DCGAN/pix2pix learning rate
         lr_decay_factor_d=0.99,
         lr_decay_epoch_d=1,
-        lr_g=0.001,
+        lr_g=0.0002,  # generator matched to the discriminator rate
         lr_decay_factor_g=0.99,
         lr_decay_epoch_g=1,
         l2_weight=args.mse_loss_weight,