From b08f5d081d04fb984437c0e9eaad918b8c74878d Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Fri, 13 Jan 2023 10:15:29 +0100 Subject: [PATCH] add early stopping param to cgan training and update doc --- mu_map/training/cgan.py | 59 ++++++++++++++++++++++++++++++++--------- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 8880c46..d5bad7d 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -17,6 +17,7 @@ class DiscriminatorParams(TrainingParams): """ Wrap training parameters to always carry the name 'Discriminator'. """ + def __init__( self, model: torch.nn.Module, @@ -30,10 +31,12 @@ class DiscriminatorParams(TrainingParams): lr_scheduler=lr_scheduler, ) + class GeneratorParams(TrainingParams): """ Wrap training parameters to always carry the name 'Generator'. """ + def __init__( self, model: torch.nn.Module, @@ -48,11 +51,26 @@ class GeneratorParams(TrainingParams): ) - class cGANTraining(AbstractTraining): """ Implementation of a conditional generative adversarial network training. + + To see all parameters, have a look at AbstractTraining. 
+ + Parameters + ---------- + params_generator: GeneratorParams + training parameters containing a model, a corresponding optimizer and optionally a learning rate scheduler for the generator + params_discriminator: DiscriminatorParams + training parameters containing a model, a corresponding optimizer and optionally a learning rate scheduler for the discriminator + loss_func_dist: WeightedLoss + distance loss function for the generator + weight_criterion_dist: float + weight of the distance loss when training the generator + weight_criterion_adv: float + weight of the adversarial loss when training the generator """ + def __init__( self, epochs: int, @@ -66,17 +84,18 @@ class cGANTraining(AbstractTraining): loss_func_dist: WeightedLoss, weight_criterion_dist: float, weight_criterion_adv: float, + early_stopping: Optional[int], logger: Optional[Logger] = None, ): - """ - :param params_generator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator - :param params_discriminator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator - :param loss_func_dist: distance loss function for the generator - :param weight_criterion_dist: weight of the distance loss when training the generator - :param weight_criterion_adv: weight of the adversarial loss when training the generator - """ super().__init__( - epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger + epochs=epochs, + dataset=dataset, + batch_size=batch_size, + device=device, + early_stopping=early_stopping, + snapshot_dir=snapshot_dir, + snapshot_epoch=snapshot_epoch, + logger=logger, ) self.training_params.append(params_generator) self.training_params.append(params_discriminator) @@ -102,7 +121,7 @@ class cGANTraining(AbstractTraining): pass def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float: - mu_maps_real = mu_maps # rename real mu maps 
for clarification + mu_maps_real = mu_maps # rename real mu maps for clarification # compute fake mu maps with generator mu_maps_fake = self.generator(recons) @@ -111,8 +130,16 @@ class cGANTraining(AbstractTraining): inputs_d_real = torch.cat((recons, mu_maps_real), dim=1) # prepare labels/targets for the discriminator - labels_fake = torch.full(self.discriminator.get_output_shape(inputs_d_fake.shape), LABEL_FAKE, device=self.device) - labels_real = torch.full(self.discriminator.get_output_shape(inputs_d_real.shape), LABEL_REAL, device=self.device) + labels_fake = torch.full( + self.discriminator.get_output_shape(inputs_d_fake.shape), + LABEL_FAKE, + device=self.device, + ) + labels_real = torch.full( + self.discriminator.get_output_shape(inputs_d_real.shape), + LABEL_REAL, + device=self.device, + ) # ======================= Discriminator ===================================== # compute discriminator loss for fake mu maps @@ -131,7 +158,7 @@ class cGANTraining(AbstractTraining): # =========================================================================== # ======================= Generator ========================================= - outputs_d_fake = self.discriminator(inputs_d_fake) # this time no detach + outputs_d_fake = self.discriminator(inputs_d_fake) # this time no detach loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real) loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real) loss_g = ( @@ -256,6 +283,11 @@ if __name__ == "__main__": default="cuda:0" if torch.cuda.is_available() else "cpu", help="the device (cpu or gpu) with which the training is performed", ) + parser.add_argument( + "--early_stopping", + type=int, + help="define early stopping as the least amount of epochs in which the validation loss must improve", + ) parser.add_argument( "--dist_loss_func", type=str, @@ -400,6 +432,7 @@ if __name__ == "__main__": dataset=dataset, batch_size=args.batch_size, device=device, + early_stopping=args.early_stopping, 
snapshot_dir=args.snapshot_dir, snapshot_epoch=args.snapshot_epoch, params_generator=params_g, -- GitLab