diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index 7fc67b1650d896212f1dccac89dac18cb5c00c8e..61d1a960449a35720b12cfddaa5fc0f4b4720d00 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -41,13 +41,25 @@ class AbstractTraining:
     implementations can focus on the computations per batch and not iterating over
     the dataset, storing snapshots, etc.
 
-    :param epochs: the number of epochs to train
-    :param dataset: the dataset to use for training
-    :param batch_size: the batch size used for training
-    :param device: the device on which to perform computations (cpu or cuda)
-    :param snapshot_dir: the directory where snapshots are stored
-    :param snapshot_epoch: at each of these epochs a snapshot is stored
-    :param logger: optional logger to print results
+    Parameters
+    ----------
+    epochs: int
+        the number of epochs to train
+    dataset: MuMapDataset
+        the dataset to use for training
+    batch_size: int
+        the batch size used for training
+    device: torch.device
+        the device on which to perform computations (cpu or cuda)
+    snapshot_dir: str
+        the directory where snapshots are stored
+    snapshot_epoch: int
+        at each of these epochs a snapshot is stored
+    early_stopping: int, optional
+        if defined, training is stopped if the validation loss did not improve
+        for this many epochs
+    logger: Logger, optional
+        optional logger to print results
     """
 
     def __init__(
@@ -56,9 +68,9 @@ class AbstractTraining:
         dataset: MuMapDataset,
         batch_size: int,
         device: torch.device,
-        early_stopping: Optional[int],
         snapshot_dir: str,
         snapshot_epoch: int,
+        early_stopping: Optional[int],
         logger: Optional[Logger],
     ):
         self.epochs = epochs
@@ -215,7 +227,10 @@ class AbstractTraining:
         """
         Store snapshots of all models.
 
-        :param prefix: prefix for all stored snapshot files
+        Parameters
+        ----------
+        prefix: str
+            prefix for all stored snapshot files
         """
         for param in self.training_params:
             snapshot_file = os.path.join(
@@ -255,9 +270,17 @@ class AbstractTraining:
         """
         Implementation of training a single batch.
 
-        :param inputs: batch of input data
-        :param targets: batch of target data
-        :return: a number representing the loss
+        Parameters
+        ----------
+        inputs: torch.Tensor
+            batch of input data
+        targets: torch.Tensor
+            batch of target data
+
+        Returns
+        -------
+        float
+            a number representing the loss
         """
         return 0
 
@@ -265,8 +288,16 @@ class AbstractTraining:
         """
         Implementation of evaluating a single batch.
 
-        :param inputs: batch of input data
-        :param targets: batch of target data
-        :return: a number representing the loss
+        Parameters
+        ----------
+        inputs: torch.Tensor
+            batch of input data
+        targets: torch.Tensor
+            batch of target data
+
+        Returns
+        -------
+        float
+            a number representing the loss
         """
         return 0