diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index 651325b1b2134de50e19eaec164e2a81a185a1ad..385ca93470888ba413da1701b5680dafd82d0dc2 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -76,7 +76,7 @@ class MuMapDataset(Dataset):
         self.reconstructions = {}
         self.mu_maps = {}
 
-    def split_copy(self, split_name: str) -> MuMapDataset:
+    def split_copy(self, split_name: str):
         return MuMapDataset(
             dataset_dir=self.dir,
             csv_file=os.path.basename(self.csv_file),
diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py
index fbed83e5aa62621aa94294ef69380648fcce998d..74872696de73f18bfaca7e73390cec811fd78a21 100644
--- a/mu_map/dataset/patches.py
+++ b/mu_map/dataset/patches.py
@@ -45,7 +45,7 @@ class MuMapPatchDataset(MuMapDataset):
         self.patches = []
         self.generate_patches()
 
-    def split_copy(self, split_name: str) ->MuMapPatchDataset:
+    def split_copy(self, split_name: str):
         return MuMapPatchDataset(
             dataset_dir=self.dir,
             patches_per_image=self.patches_per_image,
diff --git a/mu_map/training/default.py b/mu_map/training/distance.py
similarity index 83%
rename from mu_map/training/default.py
rename to mu_map/training/distance.py
index 14898a06ac0a5f88ed39a6f344ac4e1e0d0e0e0e..e607dffd86d9227e0da770b467233c1e37db8f53 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/distance.py
@@ -3,6 +3,7 @@ from typing import Dict
 
 import torch
 
+from mu_map.dataset.default import MuMapDataset
 from mu_map.logging import get_logger
 from mu_map.training.lib import TrainingParams, AbstractTraining
 from mu_map.training.loss import WeightedLoss
@@ -17,25 +18,28 @@ class Training(AbstractTraining):
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
-
         params: TrainingParams,
         loss_func: WeightedLoss,
+        logger,
     ):
-        super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch)
+        super().__init__(
+            epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
+        )
         self.training_params.append(params)
 
         self.loss_func = loss_func
         self.model = params.model
 
     def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
-        outputs = self.model(inputs)
-        loss = self.loss_func(outputs, mu_maps)
+        mu_maps_predicted = self.model(recons)
+        loss = self.loss_func(mu_maps_predicted, mu_maps)
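+        # compute gradients only; zero_grad and optimizer.step run in AbstractTraining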
         loss.backward()
         return loss.item()
 
     def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
-        outpus = self.model(inputs)
-        loss = torch.nn.functional.loss.l1(outpus, mu_maps)
+        mu_maps_predicted = self.model(recons)
+        loss = torch.nn.functional.l1_loss(mu_maps_predicted, mu_maps)
         return loss.item()
 
 
@@ -144,6 +148,11 @@ if __name__ == "__main__":
         default="l1",
         help="define the loss function used for training, e.g. 0.75*l1+0.25*gdl",
     )
+    parser.add_argument(
+        "--decay_lr",
+        action="store_true",
+        help="decay the learning rate",
+    )
     parser.add_argument(
         "--lr", type=float, default=0.001, help="the initial learning rate for training"
     )
@@ -203,8 +212,6 @@ if __name__ == "__main__":
     torch.manual_seed(args.seed)
     np.random.seed(args.seed)
 
-    model = UNet(in_channels=1, features=args.features)
-    model = model.to(device)
 
     transform_normalization = None
     if args.input_norm == "mean":
@@ -214,41 +221,41 @@
     elif args.input_norm == "gaussian":
         transform_normalization = GaussianNormTransform()
 
-    data_loaders = {}
-    for split in ["train", "validation"]:
-        dataset = MuMapPatchDataset(
-            args.dataset_dir,
-            patches_per_image=args.number_of_patches,
-            patch_size=args.patch_size,
-            patch_offset=args.patch_offset,
-            shuffle=not args.no_shuffle,
-            split_name=split,
-            transform_normalization=transform_normalization,
-            logger=logger,
-        )
-        data_loader = torch.utils.data.DataLoader(
-            dataset=dataset,
-            batch_size=args.batch_size,
-            shuffle=True,
-            pin_memory=True,
-            num_workers=1,
+    dataset = MuMapPatchDataset(
+        args.dataset_dir,
+        patches_per_image=args.number_of_patches,
+        patch_size=args.patch_size,
+        patch_offset=args.patch_offset,
+        shuffle=not args.no_shuffle,
+        transform_normalization=transform_normalization,
+        logger=logger,
+    )
+
+    model = UNet(in_channels=1, features=args.features).to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
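+    # StepLR scales the learning rate by gamma once every step_size epochs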
+    lr_scheduler = (
+        torch.optim.lr_scheduler.StepLR(
+            optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
         )
-        data_loaders[split] = data_loader
+        if args.decay_lr
+        else None
+    )
+    params = TrainingParams(
+        name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
+    )
 
     criterion = WeightedLoss.from_str(args.loss_func)
-    logger.debug(f"Criterion: {criterion}")
 
     training = Training(
-        model=model,
-        data_loaders=data_loaders,
         epochs=args.epochs,
+        dataset=dataset,
+        batch_size=args.batch_size,
         device=device,
-        loss_func=criterion,
-        lr=args.lr,
-        lr_decay_factor=args.lr_decay_factor,
-        lr_decay_epoch=args.lr_decay_epoch,
         snapshot_dir=args.snapshot_dir,
         snapshot_epoch=args.snapshot_epoch,
+        params=params,
+        loss_func=criterion,
         logger=logger,
     )
     training.run()
diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index 00811cd63d651220f13f9fbf857536b63a5d872a..dc92337f1d3a28aea466bd39595d26bd5d09bcf6 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -1,12 +1,14 @@
 from dataclasses import dataclass
 import os
-from typing import Dict
+from typing import Dict, Optional
+import sys
 
 import torch
 from torch import Tensor
 
 from mu_map.dataset.default import MuMapDataset
 
+
 @dataclass
 class TrainingParams:
     name: str
@@ -14,8 +16,8 @@ class TrainingParams:
     optimizer: torch.optim.Optimizer
     lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
 
-class AbstractTraining:
 
+class AbstractTraining:
     def __init__(
         self,
         epochs: int,
@@ -24,7 +26,7 @@ class AbstractTraining:
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
-        logger=None,
+        logger,  # TODO make optional?
     ):
         self.epochs = epochs
         self.batch_size = batch_size
@@ -37,7 +39,17 @@ class AbstractTraining:
         self.logger = logger
 
         self.training_params = []
-
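+        # build a train and a validation DataLoader from split copies of the dataset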
+        self.data_loaders = {
+            split_name: torch.utils.data.DataLoader(
+                dataset.split_copy(split_name),
+                batch_size=self.batch_size,
+                shuffle=True,
+                pin_memory=True,
+                num_workers=1,
+            )
+            for split_name in ["train", "validation"]
+        }
 
     def run(self) -> float:
         loss_val_min = sys.maxsize
@@ -88,7 +100,7 @@
             for param in self.training_params:
                 param.optimizer.zero_grad()
 
-            loss = loss + self._train_batch(self, inputs, targets)
+            loss = loss + self._train_batch(inputs, targets)
 
             for param in self.training_params:
                 param.optimizer.step()
@@ -110,12 +122,14 @@
             inputs = inputs.to(self.device)
             targets = targets.to(self.device)
 
-            loss = loss + self._eval_batch(self, inputs, targets)
+            loss = loss + self._eval_batch(inputs, targets)
         return loss / len(data_loader)
 
     def store_snapshot(self, prefix: str):
         for param in self.training_params:
-            snapshot_file = os.path.join(self.snapshot_dir, f"{prefix}_{param.name}.pth")
+            snapshot_file = os.path.join(
+                self.snapshot_dir, f"{prefix}_{param.name}.pth"
+            )
             self.logger.debug(f"Store snapshot at {snapshot_file}")
             torch.save(param.model.state_dict(), snapshot_file)
 
@@ -124,4 +138,3 @@
 
     def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
         return 0
-