From afdee20bd29ef751cf69d03200a75aabc6b23d96 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Mon, 10 Oct 2022 11:12:38 +0200 Subject: [PATCH] make loss weights in cgan loss configurable --- mu_map/training/cgan.py | 49 ++++++++++++++++++++++++++++++++++------- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 1716e26..25d9fc2 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -14,7 +14,11 @@ LABEL_FAKE = 0.0 class GeneratorLoss(torch.nn.Module): def __init__( - self, l2_weight: float = 1.0, gdl_weight: float = 1.0, adv_weight: float = 20.0 + self, + l2_weight: float = 1.0, + gdl_weight: float = 1.0, + adv_weight: float = 20.0, + logger=None, ): super().__init__() @@ -27,6 +31,12 @@ class GeneratorLoss(torch.nn.Module): self.adv = torch.nn.MSELoss(reduction="mean") self.adv_weight = adv_weight + if logger: + logger.debug(f"GeneratorLoss: {self}") + + def __repr__(self): + return f"{self.l2_weight:.3f} * MSELoss + {self.gdl_weight:.3f} * GDLLoss + {self.adv_weight:.3f} * AdversarialLoss" + def forward( self, mu_maps_real: Tensor, @@ -94,7 +104,10 @@ class cGANTraining: self.criterion_d = torch.nn.MSELoss(reduction="mean") self.criterion_g = GeneratorLoss( - l2_weight=l2_weight, gdl_weight=gdl_weight, adv_weight=adv_weight + l2_weight=l2_weight, + gdl_weight=gdl_weight, + adv_weight=adv_weight, + logger=self.logger, ) def run(self): @@ -174,7 +187,9 @@ class cGANTraining: # update generator outputs_d = self.discriminator(mu_maps_fake) - loss_g = self.criterion_g(mu_maps_real, mu_maps_fake, labels_real, outputs_d) + loss_g = self.criterion_g( + mu_maps_real, mu_maps_fake, labels_real, outputs_d + ) loss_g.backward() self.optimizer_g.step() @@ -189,7 +204,7 @@ class cGANTraining: loss = 0.0 updates = 0 - + data_loader = self.data_loaders[split_name] for i, (recons, mu_maps) in enumerate(data_loader): print( @@ -319,6 +334,24 @@ if __name__ == "__main__": default="cuda:0" if torch.cuda.is_available() else "cpu", help="the device (cpu or gpu) with which the training is performed", ) + parser.add_argument( + "--mse_loss_weight", + type=float, + default=1.0, + help="weight for the L2-Loss of the generator", + ) + parser.add_argument( + "--gdl_loss_weight", + type=float, + default=1.0, + help="weight for the Gradient-Difference-Loss of the generator", + ) + parser.add_argument( + "--adv_loss_weight", + type=float, + default=20.0, + help="weight for the Adversarial-Loss of the generator", + ) parser.add_argument( "--lr", type=float, default=0.001, help="the initial learning rate for training" ) @@ -372,7 +405,7 @@ if __name__ == "__main__": logger = get_logger_by_args(args) logger.info(args) - args.seed = args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1) + args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {args.seed}") random.seed(args.seed) torch.manual_seed(args.seed) @@ -425,9 +458,9 @@ if __name__ == "__main__": lr_g=0.001, lr_decay_factor_g=0.99, lr_decay_epoch_g=1, - l2_weight=0.25, - gdl_weight=0.25, - adv_weight=0.5, + l2_weight=args.mse_loss_weight, + gdl_weight=args.gdl_loss_weight, + adv_weight=args.adv_loss_weight, snapshot_dir=args.snapshot_dir, snapshot_epoch=args.snapshot_epoch, logger=logger, -- GitLab