diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index e820c07f6e0fcf9329c21e249a3e9a31336ac728..eee9a09bd886e2a96d3a96acfb85907c9369a3aa 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from logging import Logger
 import os
 from typing import Dict, Optional
 import sys
@@ -7,6 +8,7 @@ import torch
 from torch import Tensor
 
 from mu_map.dataset.default import MuMapDataset
+from mu_map.logging import get_logger
 
 
 @dataclass
@@ -26,7 +28,7 @@ class AbstractTraining:
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
-        logger,  # TODO make optional?
+        logger: Optional[Logger],
     ):
         self.epochs = epochs
         self.batch_size = batch_size
@@ -36,7 +38,7 @@ class AbstractTraining:
         self.snapshot_dir = snapshot_dir
         self.snapshot_epoch = snapshot_epoch
 
-        self.logger = logger
+        self.logger = logger if logger is not None else get_logger(name=self.__class__.__name__)
 
         self.training_params = []
         self.data_loaders = dict(
@@ -85,6 +87,14 @@ class AbstractTraining:
                     param.lr_scheduler.step()
         return loss_val_min
 
+    def _after_train_batch(self):
+        """
+        Hook called after the loss computation of a batch during training.
+        By default it steps all optimizers; subclasses may override this.
+        """
+        for param in self.training_params:
+            param.optimizer.step()
+
     def _train_epoch(self):
         torch.set_grad_enabled(True)
         for param in self.training_params:
@@ -106,8 +116,7 @@ class AbstractTraining:
 
             loss = loss + self._train_batch(inputs, targets)
 
-            for param in self.training_params:
-                param.optimizer.step()
+            self._after_train_batch()
         return loss / len(data_loader)
 
     def _eval_epoch(self):
@@ -132,7 +141,7 @@ class AbstractTraining:
     def store_snapshot(self, prefix: str):
         for param in self.training_params:
             snapshot_file = os.path.join(
-                self.snapshot_dir, f"{prefix}_{param.name}.pth"
+                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
             )
             self.logger.debug(f"Store snapshot at {snapshot_file}")
             torch.save(param.model.state_dict(), snapshot_file)
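
A minimal sketch (not part of the patch) of how a subclass might use the new _after_train_batch hook, e.g. to clip gradients before the optimizers step. The class name ClippedTraining and the max_norm value are hypothetical; the sketch assumes only what the diff shows, namely that training_params entries expose .model and .optimizer and that the base implementation steps every optimizer.

from torch.nn.utils import clip_grad_norm_

from mu_map.training.lib import AbstractTraining


class ClippedTraining(AbstractTraining):
    def _after_train_batch(self):
        # Clip the gradients of every trained model ...
        for param in self.training_params:
            clip_grad_norm_(param.model.parameters(), max_norm=1.0)
        # ... then let the base class perform the usual optimizer steps.
        super()._after_train_batch()

Factoring the optimizer step out of _train_epoch into its own method keeps the epoch loop unchanged while letting subclasses decide how, or whether, each optimizer steps after a batch.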
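
Also not part of the patch, a tiny hedged sketch of the new logger fallback: when the constructor receives logger=None, it calls get_logger with the concrete class name, exactly as the added line does. Only the name keyword argument is used here, since that is all the diff shows.

from mu_map.logging import get_logger

# Equivalent of what __init__ now does for logger=None: a logger named after
# the training class, so its debug output (e.g. the snapshot messages) can be
# attributed to that class.
logger = get_logger(name="AbstractTraining")
logger.debug("falling back to the default logger")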