diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index 05de5caee053aa089ded8f1551e3b8c146814b23..289fb248210ba8894f05cfb7c3910e0eac3e5730 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -60,7 +60,7 @@ class Discriminator(nn.Module):
             nn.Linear(in_features=512, out_features=128),
             nn.ReLU(inplace=True),
             nn.Linear(in_features=128, out_features=1),
-            nn.Sigmoid(),
+            # nn.Sigmoid(),
         )
 
     def forward(self, x: torch.Tensor):
diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index 6352b21c27efe0bf8af81a78cb108aec329ca6fb..3aa83ec2d64990ffe46b7e5593402517648b1cff 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -15,14 +15,15 @@ LABEL_FAKE = 0.0
 class GeneratorLoss(torch.nn.Module):
     def __init__(
         self,
-        l2_weight: float = 1.0,
-        gdl_weight: float = 1.0,
-        adv_weight: float = 20.0,
-        logger=None,
+        # l2_weight: float = 1.0,
+        # gdl_weight: float = 1.0,
+        # adv_weight: float = 20.0,
+        # logger=None,
     ):
         super().__init__()
 
-        self.l2 = torch.nn.MSELoss(reduction="mean")
-        self.l2_weight = l2_weight
+        # self.l2 = torch.nn.MSELoss(reduction="mean")
+        self.l2 = torch.nn.L1Loss(reduction="mean")
+        # self.l2_weight = l2_weight  # commented out to match the constructor parameter above
 
         self.gdl = GradientDifferenceLoss()
@@ -88,19 +89,19 @@ class cGANTraining:
 
         self.logger = logger if logger is not None else get_logger()
 
-        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)
-        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g)
+        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))
+        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
 
-        self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR(
-            self.optimizer_d,
-            step_size=lr_decay_epoch_d,
-            gamma=lr_decay_factor_d,
-        )
-        self.lr_scheduler_g = torch.optim.lr_scheduler.StepLR(
-            self.optimizer_g,
-            step_size=lr_decay_epoch_g,
-            gamma=lr_decay_factor_g,
-        )
+        # self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR(
+        #     self.optimizer_d,
+        #     step_size=lr_decay_epoch_d,
+        #     gamma=lr_decay_factor_d,
+        # )
+        # self.lr_scheduler_g = torch.optim.lr_scheduler.StepLR(
+        #     self.optimizer_g,
+        #     step_size=lr_decay_epoch_g,
+        #     gamma=lr_decay_factor_g,
+        # )
 
         self.criterion_d = torch.nn.MSELoss(reduction="mean")
         self.criterion_g = GeneratorLoss(
@@ -109,6 +110,7 @@ class cGANTraining:
             adv_weight=adv_weight,
             logger=self.logger,
         )
+        self.criterion_l1 = torch.nn.L1Loss(reduction="mean")
 
     def run(self):
         losses_d = []
@@ -124,8 +126,8 @@ class cGANTraining:
             self._eval_epoch(epoch, "train")
             self._eval_epoch(epoch, "validation")
 
-            self.lr_scheduler_d.step()
-            self.lr_scheduler_g.step()
+            # self.lr_scheduler_d.step()
+            # self.lr_scheduler_g.step()
 
             if epoch % self.snapshot_epoch == 0:
                 self.store_snapshot(epoch)
@@ -162,34 +164,37 @@ class cGANTraining:
 
     def _step(self, recons, mu_maps_real):
         batch_size = recons.shape[0]
-
-        self.optimizer_d.zero_grad()
-        self.optimizer_g.zero_grad()
-
         labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device)
         labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device)
 
         with torch.set_grad_enabled(True):
+            self.optimizer_d.zero_grad()
+            self.optimizer_g.zero_grad()
+
             # compute fake mu maps with generator
             mu_maps_fake = self.generator(recons)
 
-            # update discriminator based on real mu maps
-            outputs_d = self.discriminator(mu_maps_real)
-            loss_d_real = self.criterion_d(outputs_d, labels_real)
-            loss_d_real.backward()  # compute gradients
-            # update discriminator based on fake mu maps
-            outputs_d = self.discriminator(
-                mu_maps_fake.detach()
-            )  # note the detach, so that gradients are not computed for the generator
-            loss_d_fake = self.criterion_d(outputs_d, labels_fake)
-            loss_d_fake.backward()  # compute gradients
-            self.optimizer_d.step()  # update discriminator based on gradients
+            # compute discriminator loss for fake mu maps
+            inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
+            outputs_d_fake = self.discriminator(inputs_d_fake.detach())  # note the detach, so that gradients are not computed for the generator
+            loss_d_fake = self.criterion_d(outputs_d_fake, labels_fake)
+
+            # compute discriminator loss for real mu maps
+            inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
+            outputs_d_real = self.discriminator(inputs_d_real)  # no detach needed: the real mu maps do not come from the generator
+            loss_d_real = self.criterion_d(outputs_d_real, labels_real)
+
+            # update discriminator
+            loss_d = 0.5 * (loss_d_fake + loss_d_real)
+            loss_d.backward()  # compute gradients
+            self.optimizer_d.step()
 
             # update generator
-            outputs_d = self.discriminator(mu_maps_fake)
-            loss_g = self.criterion_g(
-                mu_maps_real, mu_maps_fake, labels_real, outputs_d
-            )
+            inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
+            outputs_d_fake = self.discriminator(inputs_d_fake)
+            loss_g_adv = self.criterion_d(outputs_d_fake, labels_real)
+            loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real)
+            loss_g = loss_g_adv + 100.0 * loss_g_l1
             loss_g.backward()
             self.optimizer_g.step()
 
@@ -416,7 +421,7 @@ if __name__ == "__main__":
     torch.manual_seed(args.seed)
    np.random.seed(args.seed)
 
-    discriminator = Discriminator(in_channels=1, input_size=args.patch_size)
+    discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
     discriminator = discriminator.to(device)
 
     generator = UNet(in_channels=1, features=args.features)
@@ -460,10 +465,10 @@ if __name__ == "__main__":
         data_loaders=data_loaders,
         epochs=args.epochs,
         device=device,
-        lr_d=0.0005,
+        lr_d=0.0002,
         lr_decay_factor_d=0.99,
         lr_decay_epoch_d=1,
-        lr_g=0.001,
+        lr_g=0.0002,
         lr_decay_factor_g=0.99,
         lr_decay_epoch_g=1,
         l2_weight=args.mse_loss_weight,
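
The rewritten _step follows the conditional GAN training recipe popularized by pix2pix (Isola et al., 2017): the discriminator judges concatenated (reconstruction, mu map) pairs instead of bare mu maps (hence in_channels=2), its real and fake losses are averaged, the removed nn.Sigmoid pairs with the least-squares MSELoss criterion on raw scores, and the generator minimizes an adversarial term plus an L1 term weighted by 100, with Adam at lr 2e-4 and betas (0.5, 0.999). The sketch below restates that update order outside of cGANTraining; the tiny stand-in networks, tensor shapes, and variable names are illustrative assumptions, not code from this repository.

import torch
import torch.nn as nn

# Stand-in networks (assumptions): the repository uses a UNet generator and a
# patch-based Discriminator with in_channels=2.
generator = nn.Conv2d(1, 1, kernel_size=3, padding=1)
discriminator = nn.Sequential(
    nn.Conv2d(2, 1, kernel_size=3, padding=1),  # 2 channels: input + mu map
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # raw score per sample, no Sigmoid
)

opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion_gan = nn.MSELoss()  # least-squares GAN criterion on raw scores
criterion_l1 = nn.L1Loss()

recons = torch.rand(4, 1, 32, 32)        # network input (reconstruction)
mu_maps_real = torch.rand(4, 1, 32, 32)  # target mu map
labels_real = torch.ones(4, 1)
labels_fake = torch.zeros(4, 1)

mu_maps_fake = generator(recons)

# Discriminator update: average the losses on fake and real (input, mu map) pairs.
opt_d.zero_grad()
pair_fake = torch.cat((recons, mu_maps_fake), dim=1).detach()  # detach: no generator gradients
pair_real = torch.cat((recons, mu_maps_real), dim=1)
loss_d = 0.5 * (
    criterion_gan(discriminator(pair_fake), labels_fake)
    + criterion_gan(discriminator(pair_real), labels_real)
)
loss_d.backward()
opt_d.step()

# Generator update: fool the discriminator plus a heavily weighted L1 term.
opt_g.zero_grad()
pair_fake = torch.cat((recons, mu_maps_fake), dim=1)  # no detach here
loss_g = criterion_gan(discriminator(pair_fake), labels_real) + 100.0 * criterion_l1(
    mu_maps_fake, mu_maps_real
)
loss_g.backward()
opt_g.step()

The only difference between the two torch.cat calls is the detach(): the discriminator update must not propagate gradients into the generator, while the generator update relies on exactly those gradients.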