diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index bbec3236882e30797f5269d32092e24b3e2dfdec..722d1a4828288c868c1b15efbc884c8626f5be2a 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -1,3 +1,6 @@ +""" +Implementation of a cGAN training. +""" from logging import Logger from typing import Optional @@ -53,22 +56,7 @@ class GeneratorParams(TrainingParams): class cGANTraining(AbstractTraining): """ - Implementation of a conditional generative adversarial network training. - - To see all parameters, have a look at AbstractTraining. - - Parameters - ---------- - params_generator: GeneratorParams - training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator - params_discriminator: DiscriminatorParams - training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator - loss_func_dist: WeightedLoss - distance loss function for the generator - weight_criterion_dist: float - weight of the distance loss when training the generator - weight_criterion_adv: float - weight of the adversarial loss when training the generator + Implementation of a conditional generative adversarial network (cGAN) training. """ def __init__( @@ -97,6 +85,26 @@ class cGANTraining(AbstractTraining): snapshot_epoch=snapshot_epoch, logger=logger, ) + """ + Initialize a cGAN training. + + Parameters not described here are passed to the AbstractTraining super class. + + Parameters + ---------- + params_generator: GeneratorParams + training parameters containing a model an according optimizer and optionally a + learning rate scheduler for the generator + params_discriminator: DiscriminatorParams + training parameters containing a model an according optimizer and optionally a + learning rate scheduler for the discriminator + loss_func_dist: WeightedLoss + distance loss function for the generator + weight_criterion_dist: float + weight of the distance loss when training the generator + weight_criterion_adv: float + weight of the adversarial loss when training the generator + """ self.training_params.append(params_generator) self.training_params.append(params_discriminator) @@ -114,9 +122,10 @@ class cGANTraining(AbstractTraining): def _after_train_batch(self): """ - Overwrite calling step on all optimizers as this needs to be done - separately for the generator and discriminator during the training of - a batch. + Overwrite this function so that `optimizer.step()` is not called. + + This needs do be done separately for the generator and discriminator + during the training of a batch. """ pass diff --git a/mu_map/training/distance.py b/mu_map/training/distance.py index df2dfb95d9f5dff506e23673b3dcfea6388c7f84..aa54eb0f8b2189d6db23a1ec9d034a9c8e04cbb4 100644 --- a/mu_map/training/distance.py +++ b/mu_map/training/distance.py @@ -1,3 +1,6 @@ +""" +Implementation of training based on a distance loss. +""" from logging import Logger from typing import Optional @@ -12,15 +15,6 @@ class DistanceTraining(AbstractTraining): """ Implementation of a distance training: a model predicts a mu map from a reconstruction by optimizing a distance loss (e.g. L1). - - To see all parameters, have a look at AbstractTraining. - - Parameters - ---------- - params: TrainingParams - training parameters containing a model an according optimizer and optionally a learning rate scheduler - loss_func: WeightedLoss - the distance loss function """ def __init__( @@ -36,6 +30,18 @@ class DistanceTraining(AbstractTraining): early_stopping: Optional[int] = None, logger: Optional[Logger] = None, ): + """ + Initialize a distance training. + + Parameters not described here are passed to the AbstractTraining super class. + + Parameters + ---------- + params: TrainingParams + training parameters containing a model an according optimizer and optionally a learning rate scheduler + loss_func: WeightedLoss + the distance loss function + """ super().__init__( epochs=epochs, dataset=dataset, diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py index 57bb8473a34ba64c2862534bb0dab9bad261a31e..576eb03be52861512803c9444f35538c740e6a4c 100644 --- a/mu_map/training/lib.py +++ b/mu_map/training/lib.py @@ -18,7 +18,7 @@ from mu_map.logging import get_logger def init_random_seed(seed: Optional[int] = None) -> int: """ - Set the seed for all RNGs (default python, numpy and torch). + Set the seed for all RNGs (python, numpy and torch). Parameters ---------- @@ -28,7 +28,7 @@ def init_random_seed(seed: Optional[int] = None) -> int: Returns ------- int - the randoms seed used + the random seed used """ seed = seed if seed is not None else random.randint(0, 2**32 - 1) @@ -64,26 +64,6 @@ class AbstractTraining: This abstract class implement a common training procedure so that implementations can focus on the computations per batch and not iterating over the dataset, storing snapshots, etc. - - Parameters - ---------- - epochs: int - the number of epochs to train - dataset: MuMapDataset - the dataset to use for training - batch_size: int - the batch size used for training - device: torch.device - the device on which to perform computations (cpu or cuda) - snapshot_dir: str - the directory where snapshots are stored - snapshot_epoch: int - at each of these epochs a snapshot is stored - early_stopping: int, optional - if defined, training is stopped if the validation loss did not improve - for this many epochs - logger: Logger, optional - optional logger to print results """ def __init__( @@ -97,6 +77,29 @@ class AbstractTraining: early_stopping: Optional[int] = None, logger: Optional[Logger] = None, ): + """ + Initialize the training. + + Parameters + ---------- + epochs: int + the number of epochs to train + dataset: MuMapDataset + the dataset to use for training + batch_size: int + the batch size used for training + device: torch.device + the device on which to perform computations (cpu or cuda) + snapshot_dir: str + the directory where snapshots are stored + snapshot_epoch: int + at each of these epochs a snapshot is stored + early_stopping: int, optional + if defined, training is stopped if the validation loss did not improve + for this many epochs + logger: Logger, optional + optional logger to print results + """ self.epochs = epochs self.batch_size = batch_size self.dataset = dataset @@ -188,7 +191,10 @@ class AbstractTraining: """ Implementation of the training in a single epoch. - :return: a number representing the training loss + Returns + ------- + float + a number representing the training loss """ # activate gradients torch.set_grad_enabled(True) @@ -223,7 +229,10 @@ class AbstractTraining: """ Implementation of the evaluation in a single epoch. - :return: a number representing the validation loss + Returns + ------- + float + a number representing the validation loss """ # deactivate gradients torch.set_grad_enabled(False) diff --git a/mu_map/training/loss.py b/mu_map/training/loss.py index 7272b3b2b4f10f3aec4f5ef06e2a44c20eca4412..a0297f121b3bfa5d033c0cf846453bf8dc3a1ac2 100644 --- a/mu_map/training/loss.py +++ b/mu_map/training/loss.py @@ -1,3 +1,6 @@ +""" +Implementations of different loss functions. +""" from typing import Any, List import torch @@ -24,8 +27,15 @@ def loss_by_string(loss_str: str) -> nn.Module: Retrieve a loss function defined by a string. E.g., L1 returns the torch module of the l1 loss function. - :param loss_str: loss function defined as a string - :returns: an executable loss function + Parameters + ---------- + loss_str: str + loss function defined as a string + + Returns + ------- + nn.Module + a callable loss function """ loss_str = loss_str.lower() if "l1" in loss_str: @@ -42,12 +52,20 @@ class WeightedLoss(nn.Module): """ Definition of a weighted loss consisting of a number of losses with according weights. - - :param losses: the losses to be summed and weighted - :param weights: weights for each loss function """ def __init__(self, losses: List[nn.Module], weights: List[float]): + """ + Initialize a weighted loss. + + + Parameters + ---------- + losses: list of nn.Module + list of loss functions + weights: list of float + weights for each loss function + """ super().__init__() assert len(losses) == len(