diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index bbec3236882e30797f5269d32092e24b3e2dfdec..722d1a4828288c868c1b15efbc884c8626f5be2a 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -1,3 +1,6 @@
+"""
+Implementation of a cGAN training.
+"""
 from logging import Logger
 from typing import Optional
 
@@ -53,22 +56,7 @@ class GeneratorParams(TrainingParams):
 
 class cGANTraining(AbstractTraining):
     """
-    Implementation of a conditional generative adversarial network training.
-
-    To see all parameters, have a look at AbstractTraining.
-
-    Parameters
-    ----------
-    params_generator: GeneratorParams
-        training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator
-    params_discriminator: DiscriminatorParams
-        training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator
-    loss_func_dist: WeightedLoss
-        distance loss function for the generator
-    weight_criterion_dist: float
-        weight of the distance loss when training the generator
-    weight_criterion_adv: float
-        weight of the adversarial loss when training the generator
+    Implementation of a conditional generative adversarial network (cGAN) training.
     """
 
     def __init__(
@@ -97,6 +85,26 @@ class cGANTraining(AbstractTraining):
             snapshot_epoch=snapshot_epoch,
             logger=logger,
         )
+        """
+        Initialize a cGAN training.
+
+        Parameters not described here are passed to the AbstractTraining superclass.
+
+        Parameters
+        ----------
+        params_generator: GeneratorParams
+            training parameters containing a model, a corresponding optimizer, and
+            optionally a learning rate scheduler for the generator
+        params_discriminator: DiscriminatorParams
+            training parameters containing a model, a corresponding optimizer, and
+            optionally a learning rate scheduler for the discriminator
+        loss_func_dist: WeightedLoss
+            distance loss function for the generator
+        weight_criterion_dist: float
+            weight of the distance loss when training the generator
+        weight_criterion_adv: float
+            weight of the adversarial loss when training the generator
+        """
         self.training_params.append(params_generator)
         self.training_params.append(params_discriminator)
 
@@ -114,9 +122,10 @@ class cGANTraining(AbstractTraining):
 
     def _after_train_batch(self):
         """
-        Overwrite calling step on all optimizers as this needs to be done
-        separately for the generator and discriminator during the training of
-        a batch.
+        Override this method so that `optimizer.step()` is not called.
+
+        This needs to be done separately for the generator and discriminator
+        during the training of a batch.
         """
         pass
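
The no-op `_after_train_batch` above means `optimizer.step()` has to be called
by hand for each network. As a rough illustration of that pattern (a generic
sketch only: the names `generator`, `discriminator`, `opt_g` and `opt_d` are
placeholders, not code from this repository):

    import torch
    import torch.nn as nn

    bce = nn.BCELoss()

    def train_batch(recon, mu_map, generator, discriminator, opt_g, opt_d,
                    w_adv=1.0, w_dist=100.0):
        # discriminator update: real pairs vs. generated pairs
        opt_d.zero_grad()
        fake = generator(recon)
        pred_real = discriminator(torch.cat([recon, mu_map], dim=1))
        pred_fake = discriminator(torch.cat([recon, fake.detach()], dim=1))
        loss_d = bce(pred_real, torch.ones_like(pred_real)) + \
                 bce(pred_fake, torch.zeros_like(pred_fake))
        loss_d.backward()
        opt_d.step()  # first separate step

        # generator update: weighted adversarial loss plus distance loss
        opt_g.zero_grad()
        pred_fake = discriminator(torch.cat([recon, fake], dim=1))
        loss_g = w_adv * bce(pred_fake, torch.ones_like(pred_fake)) + \
                 w_dist * nn.functional.l1_loss(fake, mu_map)
        loss_g.backward()
        opt_g.step()  # second separate step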
 
diff --git a/mu_map/training/distance.py b/mu_map/training/distance.py
index df2dfb95d9f5dff506e23673b3dcfea6388c7f84..aa54eb0f8b2189d6db23a1ec9d034a9c8e04cbb4 100644
--- a/mu_map/training/distance.py
+++ b/mu_map/training/distance.py
@@ -1,3 +1,6 @@
+"""
+Implementation of training based on a distance loss.
+"""
 from logging import Logger
 from typing import Optional
 
@@ -12,15 +15,6 @@ class DistanceTraining(AbstractTraining):
     """
     Implementation of a distance training: a model predicts a mu map
     from a reconstruction by optimizing a distance loss (e.g. L1).
-
-    To see all parameters, have a look at AbstractTraining.
-
-    Parameters
-    ----------
-    params: TrainingParams
-        training parameters containing a model an according optimizer and optionally a learning rate scheduler
-    loss_func: WeightedLoss
-        the distance loss function
     """
 
     def __init__(
@@ -36,6 +30,18 @@ class DistanceTraining(AbstractTraining):
         early_stopping: Optional[int] = None,
         logger: Optional[Logger] = None,
     ):
+        """
+        Initialize a distance training.
+
+        Parameters not described here are passed to the AbstractTraining superclass.
+
+        Parameters
+        ----------
+        params: TrainingParams
+            training parameters containing a model, a corresponding optimizer, and optionally a learning rate scheduler
+        loss_func: WeightedLoss
+            the distance loss function
+        """
         super().__init__(
             epochs=epochs,
             dataset=dataset,
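
For context, wiring this class up end to end might look as follows. This is a
hypothetical sketch: the keyword names accepted by TrainingParams and its
module path are not shown in this diff, and `model`, `train_dataset` and all
hyperparameter values are placeholders:

    import torch
    from torch import nn

    from mu_map.training.distance import DistanceTraining
    from mu_map.training.lib import TrainingParams  # location assumed
    from mu_map.training.loss import WeightedLoss

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # keyword names assumed from the TrainingParams description above
    params = TrainingParams(model=model, optimizer=optimizer, lr_scheduler=None)

    training = DistanceTraining(
        params=params,
        loss_func=WeightedLoss([nn.L1Loss()], [1.0]),
        epochs=100,
        dataset=train_dataset,
        batch_size=8,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        snapshot_dir="snapshots",
        snapshot_epoch=10,
        early_stopping=20,
    )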
diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index 57bb8473a34ba64c2862534bb0dab9bad261a31e..576eb03be52861512803c9444f35538c740e6a4c 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -18,7 +18,7 @@ from mu_map.logging import get_logger
 
 def init_random_seed(seed: Optional[int] = None) -> int:
     """
-    Set the seed for all RNGs (default python, numpy and torch).
+    Set the seed for all RNGs (python, numpy and torch).
 
     Parameters
     ----------
@@ -28,7 +28,7 @@ def init_random_seed(seed: Optional[int] = None) -> int:
     Returns
     -------
     int
-        the randoms seed used
+        the random seed used
     """
     seed = seed if seed is not None else random.randint(0, 2**32 - 1)
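
For reference, seeding the three RNG sources named in the docstring typically
amounts to the following; this is a generic sketch, not the function's actual
body, which the hunk truncates:

    import random

    import numpy as np
    import torch

    def seed_everything(seed: int) -> None:
        # seed python's built-in RNG, numpy and torch
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)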
 
@@ -64,26 +64,6 @@ class AbstractTraining:
     This abstract class implements a common training procedure so that
     implementations can focus on the computations per batch and not iterating over
     the dataset, storing snapshots, etc.
-
-    Parameters
-    ----------
-    epochs: int
-        the number of epochs to train
-    dataset: MuMapDataset
-        the dataset to use for training
-    batch_size: int
-        the batch size used for training
-    device: torch.device
-        the device on which to perform computations (cpu or cuda)
-    snapshot_dir: str
-        the directory where snapshots are stored
-    snapshot_epoch: int
-        at each of these epochs a snapshot is stored
-    early_stopping: int, optional
-        if defined, training is stopped if the validation loss did not improve
-        for this many epochs
-    logger: Logger, optional
-        optional logger to print results
     """
 
     def __init__(
@@ -97,6 +77,29 @@ class AbstractTraining:
         early_stopping: Optional[int] = None,
         logger: Optional[Logger] = None,
     ):
+        """
+        Initialize the training.
+
+        Parameters
+        ----------
+        epochs: int
+            the number of epochs to train
+        dataset: MuMapDataset
+            the dataset to use for training
+        batch_size: int
+            the batch size used for training
+        device: torch.device
+            the device on which to perform computations (cpu or cuda)
+        snapshot_dir: str
+            the directory where snapshots are stored
+        snapshot_epoch: int
+            the interval of epochs at which a snapshot is stored
+        early_stopping: int, optional
+            if defined, training is stopped early if the validation loss has not
+            improved for this many epochs
+        logger: Logger, optional
+            optional logger to print results
+        """
         self.epochs = epochs
         self.batch_size = batch_size
         self.dataset = dataset
@@ -188,7 +191,10 @@ class AbstractTraining:
         """
         Implementation of the training in a single epoch.
 
-            :return: a number representing the training loss
+        Returns
+        -------
+        float
+            a number representing the training loss
         """
         # activate gradients
         torch.set_grad_enabled(True)
@@ -223,7 +229,10 @@ class AbstractTraining:
         """
         Implementation of the evaluation in a single epoch.
 
-        :return: a number representing the validation loss
+        Returns
+        -------
+        float
+            a number representing the validation loss
         """
         # deactivate gradients
         torch.set_grad_enabled(False)
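
Both reworked Returns sections describe epoch methods that reduce to a single
float. For orientation, an evaluation epoch of that shape commonly looks like
the sketch below; apart from `torch.set_grad_enabled`, every name (`model`,
`data_loader`, `loss_func`, even the method name) is a placeholder rather than
an attribute taken from lib.py:

    def _eval_epoch(self) -> float:
        torch.set_grad_enabled(False)  # deactivate gradients
        self.model.eval()

        loss_sum, n_batches = 0.0, 0
        for inputs, targets in self.data_loader:
            outputs = self.model(inputs.to(self.device))
            loss_sum += float(self.loss_func(outputs, targets.to(self.device)))
            n_batches += 1
        return loss_sum / n_batches  # the validation loss as a single number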
diff --git a/mu_map/training/loss.py b/mu_map/training/loss.py
index 7272b3b2b4f10f3aec4f5ef06e2a44c20eca4412..a0297f121b3bfa5d033c0cf846453bf8dc3a1ac2 100644
--- a/mu_map/training/loss.py
+++ b/mu_map/training/loss.py
@@ -1,3 +1,6 @@
+"""
+Implementations of different loss functions.
+"""
 from typing import Any, List
 
 import torch
@@ -24,8 +27,15 @@ def loss_by_string(loss_str: str) -> nn.Module:
     Retrieve a loss function defined by a string.
     E.g., L1 returns the torch module of the l1 loss function.
 
-    :param loss_str: loss function defined as a string
-    :returns: an executable loss function
+    Parameters
+    ----------
+    loss_str: str
+        loss function defined as a string
+
+    Returns
+    -------
+    nn.Module
+        a callable loss function
     """
     loss_str = loss_str.lower()
     if "l1" in loss_str:
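
A quick usage example of the documented behavior ("L1 returns the torch module
of the l1 loss function"); the isinstance check assumes the branch above
returns nn.L1Loss():

    import torch
    from torch import nn

    loss_func = loss_by_string("L1")
    assert isinstance(loss_func, nn.L1Loss)
    value = loss_func(torch.rand(2, 3), torch.rand(2, 3))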
@@ -42,12 +52,19 @@ class WeightedLoss(nn.Module):
     """
     Definition of a weighted loss consisting of a number of losses
     with according weights.
-
-    :param losses: the losses to be summed and weighted
-    :param weights: weights for each loss function
     """
 
     def __init__(self, losses: List[nn.Module], weights: List[float]):
+        """
+        Initialize a weighted loss.
+
+        Parameters
+        ----------
+        losses: list of nn.Module
+            the loss functions to be weighted and summed
+        weights: list of float
+            weights for each loss function
+        """
         super().__init__()
 
         assert len(losses) == len(
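
A minimal usage sketch for WeightedLoss, assuming (per the class docstring)
that the forward pass computes the weighted sum of the individual losses;
`prediction` and `target` are placeholder tensors:

    import torch
    from torch import nn

    prediction, target = torch.rand(2, 3), torch.rand(2, 3)

    # 1.0 * L1 + 0.5 * MSE; both lists must have equal length (see the assert)
    criterion = WeightedLoss([nn.L1Loss(), nn.MSELoss()], [1.0, 0.5])
    loss = criterion(prediction, target)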