diff --git a/mu_map/training/default.py b/mu_map/training/default.py
index 15330982e772dcb3298f0a99a6b8e77adebc0a97..d0d627b663b1893fe0c7d7ad7c4c7de57ded6ba1 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/default.py
@@ -1,44 +1,97 @@
+import os
 
+import torch
 
 class Training():
 
-    def __init__(self, epochs):
+    def __init__(self, model, data_loaders, epochs, logger):
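+        """Default training loop.
+
+        :param model: the model to be trained
+        :param data_loaders: dict mapping the phases "train" and "validation" to torch DataLoaders
+        :param epochs: number of epochs to train for
+        :param logger: logger used to report training progress
+        """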
+        self.model = model
+        self.data_loaders = data_loaders
         self.epochs = epochs
+        self.device = torch.device("cpu")
+        self.snapshot_dir = "tmp"
+        os.makedirs(self.snapshot_dir, exist_ok=True)  # make sure the snapshot directory exists
+        self.snapshot_epoch = 5
+        self.loss_func = torch.nn.MSELoss()
+
+        # self.lr = 1e-3
+        # self.lr_decay_factor = 0.99
+        self.lr = 0.1
+        self.lr_decay_factor = 0.5
+        self.lr_decay_epoch = 1
+
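+        # Adam optimizer with a StepLR schedule: every lr_decay_epoch epochs
+        # the learning rate is multiplied by lr_decay_factor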
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
+            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
+        )
+
+        self.logger = logger
+
 
     def run(self):
         for epoch in range(1, self.epochs + 1):
-            self.run_epoch(self.data_loader["train"], phase="train")
-            loss_training = self.run_epoch(self.data_loader["train"], phase="eval")
-            loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval")
+            self.logger.debug(f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ...")
+            self._run_epoch(self.data_loaders["train"], phase="train")
+
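+            # re-run the training data in eval mode so that its loss is
+            # computed the same way as the validation loss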
+            loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
+            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}")
+            loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
+            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}")
 
             # ToDo: log outputs and time
+            _previous = self.lr_scheduler.get_last_lr()[0]
             self.lr_scheduler.step()
+            self.logger.debug(f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}")
 
-            if epoch % self.snapshot_epoch:
+            if epoch % self.snapshot_epoch == 0:
                 self.store_snapshot(epoch)
 
+            self.logger.debug(f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}")
+
             
 
-    def run_epoch(self, data_loader, phase):
+    def _run_epoch(self, data_loader, phase):
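+        """Iterate over a data loader for a single epoch.
+
+        :param data_loader: the DataLoader providing input and label batches
+        :param phase: "train" to update the model weights, "eval" to only compute losses
+        :return: the loss accumulated over all batches
+        """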
+        self.logger.debug(f"Run epoch in phase {phase}")
         self.model.train() if phase == "train" else self.model.eval()
 
         epoch_loss = 0
-        for inputs, labels in self.data_loader:
+        for i, (inputs, labels) in enumerate(data_loader):
+            print(f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r")
             inputs = inputs.to(self.device)
             labels = labels.to(self.device)
 
             self.optimizer.zero_grad()
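+            # track gradients only in the train phase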
             with torch.set_grad_enabled(phase == "train"):
                 outputs = self.model(inputs)
-                loss = self.loss(outputs, labels)
+                loss = self.loss_func(outputs, labels)
 
                 if phase == "train":
                     loss.backward()
-                    optimizer.step()
+                    self.optimizer.step()
 
-            epoch_loss += loss.item() / inputs.size[0]
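+            # accumulate the per-batch loss, normalized by the batch size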
+            epoch_loss += loss.item() / inputs.shape[0]
         return epoch_loss
 
 
     def store_snapshot(self, epoch):
-        pass
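+        """Save the current model weights to a snapshot file.
+
+        The file is named after the zero-padded epoch so that snapshots
+        sort chronologically in the snapshot directory.
+        """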
+        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
+        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
+        self.logger.debug(f"Store snapshot at {snapshot_file}")
+        torch.save(self.model.state_dict(), snapshot_file)
+
+
+if __name__ == "__main__":
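+    # smoke test: fit a small UNet on mock data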
+    from mu_map.data.mock import MuMapMockDataset
+    from mu_map.logging import get_logger
+    from mu_map.models.unet import UNet
+
+    logger = get_logger(logfile="train.log", loglevel="DEBUG")
+
+    model = UNet(in_channels=1, features=[8, 16])
+    print(model)
+    dataset = MuMapMockDataset()
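+    # the same mock dataset backs both the training and the validation loader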
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
+    )
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
+    )
+    data_loaders = {"train": data_loader_train, "validation": data_loader_val}
+
+    training = Training(model, data_loaders, 10, logger)
+    training.run()