diff --git a/mu_map/data/preprocessing.py b/mu_map/data/preprocessing.py
deleted file mode 100644
index dc57c6f1de0980238302c460af2d8d9278f8826b..0000000000000000000000000000000000000000
--- a/mu_map/data/preprocessing.py
+++ /dev/null
@@ -1,38 +0,0 @@
-import torch
-
-
-def norm_max(tensor: torch.Tensor):
-    return (tensor - tensor.min()) / (tensor.max() - tensor.min())
-
-
-class MaxNorm:
-    def __call__(self, tensor: torch.Tensor):
-        return norm_max(tensor)
-
-
-def norm_mean(tensor: torch.Tensor):
-    return tensor / tensor.mean()
-
-
-class MeanNorm:
-    def __call__(self, tensor: torch.Tensor):
-        return norm_mean(tensor)
-
-
-def norm_gaussian(tensor: torch.Tensor):
-    return (tensor - tensor.mean()) / tensor.std()
-
-
-class GaussianNorm:
-    def __call__(self, tensor: torch.Tensor):
-        return norm_gaussian(tensor)
-
-
-__all__ = [
-    norm_max.__name__,
-    norm_mean.__name__,
-    norm_gaussian.__name__,
-    MaxNorm.__name__,
-    MeanNorm.__name__,
-    GaussianNorm.__name__,
-]
diff --git a/mu_map/data/split.py b/mu_map/data/split.py
index b21dda72dc67fa462f3409aedf8108a3a205ea44..a685e07a15497ef492180fa2e08f36cb69a210c9 100644
--- a/mu_map/data/split.py
+++ b/mu_map/data/split.py
@@ -1,6 +1,6 @@
 import pandas as pd
 
-from typing import List
+from typing import Dict, List
 
 
 def parse_split_str(_str: str, delimitier: str = "/") -> List[float]:
@@ -23,7 +23,7 @@ def parse_split_str(_str: str, delimitier: str = "/") -> List[float]:
     return split
 
 
-def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]:
+def split_csv(data: pd.DataFrame, split_csv: str) -> Dict[str, pd.DataFrame]:
     """
     Split a data frames based on a file defining a split.
 
@@ -40,7 +40,9 @@ def split_csv(data: pd.DataFrame, split_csv: str) -> List[pd.DataFrame]:
         lambda patient_ids: data[data["patient_id"].isin(patient_ids)],
         split_patient_ids,
     )
-    return list(splits)
+    splits = zip(split_names, splits)
+    splits = dict(splits)
+    return splits
 
 
 if __name__ == "__main__":
diff --git a/mu_map/dataset/__init__.py b/mu_map/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/mu_map/data/datasets.py b/mu_map/dataset/default.py
similarity index 76%
rename from mu_map/data/datasets.py
rename to mu_map/dataset/default.py
index e92ad5acaa1a92f5877fdc36867256d8f9615b24..8291615b2802dc2c23685b18a8bdefac2a170b27 100644
--- a/mu_map/data/datasets.py
+++ b/mu_map/dataset/default.py
@@ -5,11 +5,15 @@ import cv2 as cv
 import pandas as pd
 import pydicom
 import numpy as np
+import torch
 from torch.utils.data import Dataset
 
 from mu_map.data.prepare import headers
 from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
 from mu_map.data.review_mu_map import discard_slices
+from mu_map.data.split import split_csv
+from mu_map.dataset.transform import Transform
+from mu_map.logging import get_logger
 
 
 def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
@@ -43,16 +47,26 @@ class MuMapDataset(Dataset):
         self,
         dataset_dir: str,
         csv_file: str = "meta.csv",
+        split_file: str = "split.csv",
+        split_name: str = None,
         images_dir: str = "images",
         bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
         discard_mu_map_slices: bool = True,
         align: bool = True,
+        transform_normalization: Transform = Transform(),
+        transform_augmentation: Transform = Transform(),
+        logger=None,
     ):
         super().__init__()
 
         self.dir = dataset_dir
         self.dir_images = os.path.join(dataset_dir, images_dir)
         self.csv_file = os.path.join(dataset_dir, csv_file)
+        self.split_file = os.path.join(dataset_dir, split_file)
+
+        self.transform_normalization = transform_normalization
+        self.transform_augmentation = transform_augmentation
+        self.logger = logger if logger is not None else get_logger()
 
         self.bed_contours_file = (
             os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
@@ -63,6 +77,8 @@ class MuMapDataset(Dataset):
 
         # read CSV file and from that access DICOM files
         self.table = pd.read_csv(self.csv_file)
+        if split_name:
+            self.table = split_csv(self.table, self.split_file)[split_name]
         self.table["id"] = self.table["id"].apply(int)
 
         self.discard_mu_map_slices = discard_mu_map_slices
@@ -73,7 +89,7 @@ class MuMapDataset(Dataset):
         self.pre_load_images()
 
     def pre_load_images(self):
-        print("Pre-loading images ...", end="\r")
+        self.logger.debug("Pre-loading images ...")
         for i in range(len(self.table)):
             row = self.table.iloc[i]
             _id = row["id"]
@@ -86,14 +102,25 @@ class MuMapDataset(Dataset):
                 bed_contour = self.bed_contours[row["id"]]
                 for i in range(mu_map.shape[0]):
                     mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)
-            self.mu_maps[_id] = mu_map
 
             recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
             recon = pydicom.dcmread(recon_file).pixel_array
             if self.align:
                 recon = align_images(recon, mu_map)
+
+            mu_map = mu_map.astype(np.float32)
+            mu_map = torch.from_numpy(mu_map)
+            mu_map = mu_map.unsqueeze(dim=0)
+
+            recon = recon.astype(np.float32)
+            recon = torch.from_numpy(recon)
+            recon = recon.unsqueeze(dim=0)
+
+            recon, mu_map = self.transform_normalization(recon, mu_map)
+
+            self.mu_maps[_id] = mu_map
             self.reconstructions[_id] = recon
-        print("Pre-loading images done!")
+        self.logger.debug("Pre-loading images done!")
 
     def __getitem__(self, index: int):
         row = self.table.iloc[index]
@@ -102,22 +129,7 @@ class MuMapDataset(Dataset):
         recon = self.reconstructions[_id]
         mu_map = self.mu_maps[_id]
 
-        # recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
-        # mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
-
-        # recon = pydicom.dcmread(recon_file).pixel_array
-        # mu_map = pydicom.dcmread(mu_map_file).pixel_array
-
-        # if self.discard_mu_map_slices:
-        # mu_map = discard_slices(row, mu_map)
-
-        # if self.bed_contours:
-        # bed_contour = self.bed_contours[row["id"]]
-        # for i in range(mu_map.shape[0]):
-        # mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)
-
-        # if self.align:
-        # recon = align_images(recon, mu_map)
+        recon, mu_map = self.transform_augmentation(recon, mu_map)
 
         return recon, mu_map
 
@@ -127,46 +139,9 @@ class MuMapDataset(Dataset):
 
 __all__ = [MuMapDataset.__name__]
 
-if __name__ == "__main__":
-    import argparse
-
+def main(dataset):
     from mu_map.util import to_grayscale, COLOR_WHITE
 
-    parser = argparse.ArgumentParser(
-        description="Visualize the images of a MuMapDataset",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-    parser.add_argument(
-        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
-    )
-    parser.add_argument(
-        "--unaligned",
-        action="store_true",
-        help="do not perform center alignment of reconstruction an mu-map slices",
-    )
-    parser.add_argument(
-        "--show_bed",
-        action="store_true",
-        help="do not remove the bed contour from the mu map",
-    )
-    parser.add_argument(
-        "--full_mu_map",
-        action="store_true",
-        help="do not remove broken slices of the mu map",
-    )
-    args = parser.parse_args()
-
-    align = not args.unaligned
-    discard_mu_map_slices = not args.full_mu_map
-    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
-
-    dataset = MuMapDataset(
-        args.dataset_dir,
-        align=align,
-        discard_mu_map_slices=discard_mu_map_slices,
-        bed_contours_file=bed_contours_file,
-    )
-
     wname = "Dataset"
     cv.namedWindow(wname, cv.WINDOW_NORMAL)
     cv.resizeWindow(wname, 1600, 900)
@@ -194,17 +169,20 @@ if __name__ == "__main__":
         im = 0
 
         recon, mu_map = dataset[i]
+        recon = recon.squeeze().numpy()
+        mu_map = mu_map.squeeze().numpy()
         print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")
 
         cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
         key = cv.waitKey(timeout)
 
+        running = 0
         while True:
             ir = (ir + 1) % recon.shape[0]
             im = (im + 1) % mu_map.shape[0]
 
-            cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
-
+            to_show = combine_images((recon, mu_map), (ir, im))
+            cv.imshow(wname, to_show)
             key = cv.waitKey(timeout)
 
             if key == ord("n"):
@@ -218,3 +196,59 @@ if __name__ == "__main__":
             elif key == 81:  # left arrow key
                 ir = max(ir - 2, 0)
                 im = max(im - 2, 0)
+            elif key == ord("s"):
+                cv.imwrite(f"{running:03d}.png", to_show)
+                running += 1
+
+
+if __name__ == "__main__":
+    import argparse
+
+    from mu_map.logging import add_logging_args, get_logger_by_args
+
+    parser = argparse.ArgumentParser(
+        description="Visualize the images of a MuMapDataset",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
+    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        choices=["train", "validation", "test"],
+        help="choose the split of the data for the dataset",
+    )
+    parser.add_argument(
+        "--unaligned",
+        action="store_true",
+        help="do not perform center alignment of reconstruction an mu-map slices",
+    )
+    parser.add_argument(
+        "--show_bed",
+        action="store_true",
+        help="do not remove the bed contour from the mu map",
+    )
+    parser.add_argument(
+        "--full_mu_map",
+        action="store_true",
+        help="do not remove broken slices of the mu map",
+    )
+    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
+    args = parser.parse_args()
+
+    align = not args.unaligned
+    discard_mu_map_slices = not args.full_mu_map
+    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
+    logger = get_logger_by_args(args)
+
+    dataset = MuMapDataset(
+        args.dataset_dir,
+        align=align,
+        discard_mu_map_slices=discard_mu_map_slices,
+        bed_contours_file=bed_contours_file,
+        split_name=args.split,
+        logger=logger,
+    )
+    main(dataset)
+
diff --git a/mu_map/dataset/mock.py b/mu_map/dataset/mock.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cfd95122c21b805f574abc44a1155fd7a49ffcc
--- /dev/null
+++ b/mu_map/dataset/mock.py
@@ -0,0 +1,27 @@
+from mu_map.dataset.default import MuMapDataset
+from mu_map.dataset.normalization import MaxNormTransform
+
+
+class MuMapMockDataset(MuMapDataset):
+    def __init__(self, dataset_dir: str = "data/initial/", num_images: int = 16, logger=None):
+        super().__init__(dataset_dir=dataset_dir, transform_normalization=MaxNormTransform(), logger=logger)
+        self.len = num_images
+
+    def __getitem__(self, index: int):
+        recon, mu_map = super().__getitem__(0)
+        recon = recon[:, :32, :, :]
+        mu_map = mu_map[:, :32, :, :]
+        return recon, mu_map / 40206.0
+
+    def __len__(self):
+        return self.len
+
+
+if __name__ == "__main__":
+    import cv2 as cv
+    import numpy as np
+
+    from mu_map.dataset.default import main
+
+    dataset = MuMapMockDataset()
+    main(dataset)
diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cfbfc6cbdf31c203ae30e595be956b9c41741c9
--- /dev/null
+++ b/mu_map/dataset/normalization.py
@@ -0,0 +1,53 @@
+from typing import Tuple
+
+from torch import Tensor
+
+from mu_map.dataset.transform import Transform
+
+
+def norm_max(tensor: Tensor) -> Tensor:
+    return (tensor - tensor.min()) / (tensor.max() - tensor.min())
+
+
+class MaxNormTransform(Transform):
+    def __init__(self, max_vals: Tuple[float, float] = None):
+        self.max_vals = max_vals
+
+    def __call__(
+        self, inputs: Tensor, outputs_expected: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        if self.max_vals:
+            return inputs / self.max_vals[0], outputs_expected / self.max_vals[1]
+        return norm_max(inputs), outputs_expected
+
+
+def norm_mean(tensor: Tensor):
+    return tensor / tensor.mean()
+
+
+class MeanNormTransform(Transform):
+    def __call__(
+        self, inputs: Tensor, outputs_expected: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        return norm_mean(inputs), outputs_expected
+
+
+def norm_gaussian(tensor: Tensor):
+    return (tensor - tensor.mean()) / tensor.std()
+
+
+class GaussianNormTransform(Transform):
+    def __call__(
+        self, inputs: Tensor, outputs_expected: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        return norm_gaussian(inputs), outputs_expected
+
+
+__all__ = [
+    norm_max.__name__,
+    norm_mean.__name__,
+    norm_gaussian.__name__,
+    MaxNormTransform.__name__,
+    MeanNormTransform.__name__,
+    GaussianNormTransform.__name__,
+]
diff --git a/mu_map/data/patch_dataset.py b/mu_map/dataset/patches.py
similarity index 64%
rename from mu_map/data/patch_dataset.py
rename to mu_map/dataset/patches.py
index cf6cb786551095c314e88f45ce40818488847f83..df9f2c971b991c2278a7e4af9a6376c601af6d71 100644
--- a/mu_map/data/patch_dataset.py
+++ b/mu_map/dataset/patches.py
@@ -1,7 +1,10 @@
 import math
 import random
 
-from mu_map.data.datasets import MuMapDataset
+import numpy as np
+import torch
+
+from mu_map.dataset.default import MuMapDataset
 
 
 class MuMapPatchDataset(MuMapDataset):
@@ -16,44 +19,44 @@ class MuMapPatchDataset(MuMapDataset):
         self.generate_patches()
 
     def generate_patches(self):
-        for i, (recon, mu_map) in enumerate(zip(self.reconstructions, self.mu_maps)):
+        for _id in self.reconstructions:
+            recon = self.reconstructions[_id].squeeze()
+            mu_map = self.mu_maps[_id].squeeze()
+
             assert (
                 recon.shape[0] == mu_map.shape[0]
             ), f"Reconstruction and MuMap were not aligned for patch dataset"
 
-            _id = self.table.iloc[i]["id"]
-
             z_range = (0, max(recon.shape[0] - self.patch_size, 0))
             # sometimes the mu_maps have fewer than 32 slices
             # in this case the z-axis will be padded to the patch size, but this means we only have a single option for z
-            y_range = (0, recon.shape[1] - self.patch_size)
-            x_range = (0, recon.shape[2] - self.patch_size)
+            y_range = (20, recon.shape[1] - self.patch_size - 20)
+            x_range = (20, recon.shape[2] - self.patch_size - 20)
 
-            padding = [(0, 0), (0, 0), (0, 0)]
+            padding = [0, 0, 0, 0, 0, 0, 0, 0]
             if recon.shape[0] < self.patch_size:
                 diff = self.patch_size - recon.shape[0]
-                padding_bef = math.ceil(diff / 2)
-                padding_aft = math.floor(diff / 2)
-                padding[0] = (padding_bef, padding_aft)
+                padding[4] = math.ceil(diff / 2)
+                padding[5] = math.floor(diff / 2)
 
             for j in range(self.patches_per_image):
                 z = random.randint(*z_range)
                 y = random.randint(*y_range)
                 x = random.randint(*x_range)
-                self.patches.append(_id, z, y, x)
+                self.patches.append((_id, z, y, x, padding))
 
-    def __getitem___(self, index: int):
+    def __getitem__(self, index: int):
         _id, z, y, x, padding = self.patches[index]
-        s = self.patches
+        s = self.patch_size
 
         recon = self.reconstructions[_id]
         mu_map = self.mu_maps[_id]
 
-        recon = np.pad(recon, padding, mode="constant", constant_values=0)
-        mu_map = np.pad(mu_map, padding, mode="constant", constant_values=0)
+        recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
+        mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
 
-        recon = recon[z : z + s, y : y + s, x : x + s]
-        mu_map = mu_map[z : z + s, y : y + s, x : x + s]
+        recon = recon[:, z : z + s, y : y + s, x : x + s]
+        mu_map = mu_map[:, z : z + s, y : y + s, x : x + s]
 
         return recon, mu_map
 
@@ -70,19 +73,18 @@ if __name__ == "__main__":
     cv.namedWindow(wname, cv.WINDOW_NORMAL)
     cv.resizeWindow(wname, 1600, 900)
 
-    dataset = MuMapPatchDataset("data/initial/", patches_per_image=5)
+    dataset = MuMapPatchDataset("data/initial/", patches_per_image=1)
 
     print(f"Images (Patches) in the dataset {len(dataset)}")
 
     def create_image(recon, mu_map, recon_orig, patch, _slice):
-        s = recon.shape[0]
+        s = dataset.patch_size
         _id, _, y, x, padding = patch
 
-        _recon_orig = np.pad(recon_orig, patch, mode="constant", constant_values=0)
         _recon_orig = recon_orig[_slice]
         _recon_orig = to_grayscale(_recon_orig)
         _recon_orig = grayscale_to_rgb(_recon_orig)
-        _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), thickness=1)
+        _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1)
         _recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA)
 
         _recon = recon[_slice]
@@ -95,7 +97,7 @@ if __name__ == "__main__":
         _mu_map = cv.resize(_mu_map, (512, 512), cv.INTER_AREA)
         _mu_map = grayscale_to_rgb(_mu_map)
 
-        space = np.full((3, 512, 10), 239, np.uint8)
+        space = np.full((512, 10, 3), 239, np.uint8)
         return np.hstack((_recon, space, _mu_map, space, _recon_orig))
 
     for i in range(len(dataset)):
@@ -104,18 +106,24 @@ if __name__ == "__main__":
         patch = dataset.patches[i]
         _id, z, y, x, padding = patch
         print(
-            "Patch {str(i+1):>len(str(len(dataset)))}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[0][0], padding[0][0]}]"
+            f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[5], padding[6]}]"
         )
         recon, mu_map = dataset[i]
+        recon = recon.squeeze().numpy()
+        mu_map = mu_map.squeeze().numpy()
+
         recon_orig = dataset.reconstructions[_id]
+        recon_orig = torch.nn.functional.pad(recon_orig, padding, mode="constant", value=0)
+        recon_orig = recon_orig.squeeze().numpy()
 
-        cv.imshow(combine_images(recon, mu_map, recon_orig, patch, _i))
+        cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
         key = cv.waitKey(100)
 
         while True:
             _i = (_i + 1) % recon.shape[0]
 
-            cv.imshow(combine_images(recon, mu_map, recon_orig, patch, _i))
+            cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
+            key = cv.waitKey(100)
 
             if key == ord("n"):
                 break
diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b2685e0cebd568b611c4ac4bf22f69304906b33
--- /dev/null
+++ b/mu_map/dataset/transform.py
@@ -0,0 +1,36 @@
+from typing import List, Tuple
+
+
+from torch import Tensor
+
+
+class Transform:
+    """
+    Interface of a transformer. A transformer can be initialized and then applied to
+    an input tensor and expected output tensor as returned by a dataset. It can be
+    used for normalization and data augmentation.
+    """
+
+    def __call__(
+        self, inputs: Tensor, outputs_expected: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Apply the transformer to a pair of inputs and expected outputs in a dataset.
+        """
+        return inputs, outputs_expected
+
+
+class SequenceTransform(Transform):
+    """
+    A transformer that applies a sequence of transformers sequentially.
+    """
+
+    def __init__(self, transforms: List[Transform]):
+        self.transforms = transforms
+
+    def __call__(
+        self, inputs: Tensor, outputs_expected: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        for transforms in self.transforms:
+            inputs, outputs_expected = transforms(inputs, outputs_expected)
+        return inputs, outputs_expected
diff --git a/mu_map/logging.py b/mu_map/logging.py
index 29bf6215e745399b151392f0a45ba8b44c2467c0..fb43420249f05405719411dd880b2e8d45ab157c 100644
--- a/mu_map/logging.py
+++ b/mu_map/logging.py
@@ -1,15 +1,16 @@
 import argparse
-import datetime
+from dataclasses import dataclass
+from datetime import datetime
 import logging
 from logging import Formatter, getLogger, StreamHandler
 from logging.handlers import WatchedFileHandler
 import os
 import shutil
-from typing import Dict, Optional
-
+from typing import Dict, Optional, List
 
+date_format="%m/%d/%Y %I:%M:%S"
 FORMATTER = Formatter(
-    fmt="%(asctime)s - %(levelname)s: %(message)s", datefmt="%m/%d/%Y %I:%M:%S"
+    fmt="%(asctime)s - %(levelname)7s - %(message)s", datefmt=date_format
 )
 
 
@@ -35,7 +36,7 @@ def add_logging_args(parser: argparse.ArgumentParser, defaults: Dict[str, str]):
 
 
 def timestamp_filename(filename: str):
-    timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
+    timestamp = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
     basename, ext = os.path.splitext(filename)
     return f"{basename}_{timestamp}{ext}"
 
@@ -71,6 +72,34 @@ def get_logger_by_args(args):
     return get_logger(args.logfile, args.loglevel)
 
 
+@dataclass
+class LogLine:
+    time: datetime
+    loglevel: str
+    message: str
+
+    def __repr__(self):
+        return f"{self.time.strftime(date_format)} - {self.loglevel:>7} - {self.message}"
+
+
+def parse_line(logline):
+    _split = logline.strip().split("-")
+    assert len(_split) >= 3, f"A logged line should consists of a least three elements with the format [TIME - LOGLEVEL - MESSAGE] but got [{logline.strip()}]"
+
+    time_str = _split[0].strip()
+    time = datetime.strptime(time_str, date_format)
+
+    loglevel = _split[1].strip()
+
+    message = "-".join(_split[2:]).strip()
+    return LogLine(time=time, loglevel=loglevel, message=message)
+
+def parse_file(logfile: str) -> List[LogLine]:
+    with open(logfile, mode="r") as f:
+        lines = f.readlines()
+    lines = map(parse_line, lines)
+    return list(lines)
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     add_logging_args(parser, defaults={"--loglevel": "DEBUG", "--logfile": "tmp.log"})
diff --git a/mu_map/models/unet.py b/mu_map/models/unet.py
index 9a4d5c875b4ecfa2678d358a3504db084dcac1a6..2d0f6a536784f242e3724860cfe29023af6e5fa0 100644
--- a/mu_map/models/unet.py
+++ b/mu_map/models/unet.py
@@ -1,5 +1,6 @@
 from typing import Optional, List
 
+import torch
 import torch.nn as nn
 
 
diff --git a/mu_map/test.py b/mu_map/test.py
index 1da35dfb4f733be69e8ff23f51d1fbc744d4b9bc..33c04b243422651967c4bd96bdff2b8ee945acaf 100644
--- a/mu_map/test.py
+++ b/mu_map/test.py
@@ -1,26 +1,143 @@
+import cv2 as cv
+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 
-from .data.preprocessing import *
+from mu_map.dataset.default import MuMapDataset
+from mu_map.dataset.mock import MuMapMockDataset
+from mu_map.dataset.normalization import norm_max
+from mu_map.models.unet import UNet
+from mu_map.util import to_grayscale, COLOR_WHITE
 
-means = torch.full((10, 10, 10), 5.0)
-stds = torch.full((10, 10, 10), 10.0)
-x = torch.normal(means, stds)
+torch.set_grad_enabled(False)
 
-print(f"Before: mean={x.mean():.3f} std={x.std():.3f}")
+dataset = MuMapMockDataset("data/initial/")
 
-y = norm_gaussian(x)
-print(f" After: mean={y.mean():.3f} std={y.std():.3f}")
-y = GaussianNorm()(x)
-print(f" After: mean={y.mean():.3f} std={y.std():.3f}")
+model = UNet(in_channels=1, features=[8, 16])
+device = torch.device("cpu")
+weights = torch.load("tmp/10.pth", map_location=device)
+model.load_state_dict(weights)
+model = model.eval()
 
+recon, mu_map = dataset[0]
+recon = recon.unsqueeze(dim=0)
+recon = norm_max(recon)
 
-import cv2 as cv
-import numpy as np
+output = model(recon)
+output = output * 40206.0
+
+diff = ((mu_map - output) ** 2).mean()
+print(f"Diff: {diff:.3f}")
+
+output = output.squeeze().numpy()
+mu_map = mu_map.squeeze().numpy()
+
+assert output.shape[0] == mu_map.shape[0]
+
+wname = "Dataset"
+cv.namedWindow(wname, cv.WINDOW_NORMAL)
+cv.resizeWindow(wname, 1600, 900)
+space = np.full((1024, 10), 239, np.uint8)
+
+def to_display_image(image, _slice):
+    _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
+    _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
+    _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
+    _image = cv.putText(
+        _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
+    )
+    return _image
+
+def com(image1, image2, _slice):
+    image1 = to_display_image(image1, _slice)
+    image2 = to_display_image(image2, _slice)
+    space = np.full((image1.shape[0], 10), 239, np.uint8)
+    return np.hstack((image1, space, image2))
+
+
+i = 0
+while True:
+    x = com(output, mu_map, i)
+    cv.imshow(wname, x)
+    key = cv.waitKey(100)
+
+    if key == ord("q"):
+        break
+
+    i = (i + 1) % output.shape[0]
+
+
+
+
+# dataset = MuMapDataset("data/initial")
+
+# # print("                Recon ||                MuMap")
+# # print("     Min |      Max | Average ||     Min |      Max | Average")
+# r_max = []
+# r_avg = []
+
+# r_max_p = []
+# r_avg_p = []
+
+# r_avg_x = []
+
+# m_max = []
+# for recon, mu_map in dataset:
+    # r_max.append(recon.max())
+    # r_avg.append(recon.mean())
+
+    # recon_p = recon[:, :, 16:112, 16:112]
+    # r_max_p.append(recon_p.max())
+    # r_avg_p.append(recon_p.mean())
+
+    # r_avg_x.append(recon.sum() / (recon > 0.0).sum())
+    # # r_min = f"{recon.min():5.3f}"
+    # # r_max = f"{recon.max():5.3f}"
+    # # r_avg = f"{recon.mean():5.3f}"
+    # # m_min = f"{mu_map.min():5.3f}"
+    # # m_max = f"{mu_map.max():5.3f}"
+    # # m_avg = f"{mu_map.mean():5.3f}"
+    # # print(f"{r_min} | {r_max} |    {r_avg} || {m_min} | {m_max} |    {m_avg}")
+    # m_max.append(mu_map.max())
+    # # print(mu_map.max())
+# r_max = np.array(r_max)
+# r_avg = np.array(r_avg)
+
+# r_max_p = np.array(r_max_p)
+# r_avg_p = np.array(r_avg_p)
+
+# r_avg_x = np.array(r_avg_x)
+
+# m_max = np.array(m_max)
+# fig, ax = plt.subplots()
+# ax.scatter(r_max, m_max)
+
+# # fig, axs = plt.subplots(4, 3, figsize=(16, 12))
+# # axs[0, 0].hist(r_max)
+# # axs[0, 0].set_title("Max")
+# # axs[1, 0].hist(r_avg)
+# # axs[1, 0].set_title("Mean")
+# # axs[2, 0].hist(r_max / r_avg)
+# # axs[2, 0].set_title("Max / Mean")
+# # axs[3, 0].hist(recon.flatten())
+# # axs[3, 0].set_title("Example Reconstruction")
+
+# # axs[0, 1].hist(r_max_p)
+# # axs[0, 1].set_title("Max")
+# # axs[1, 1].hist(r_avg_p)
+# # axs[1, 1].set_title("Mean")
+# # axs[2, 1].hist(r_max_p / r_avg_p)
+# # axs[2, 1].set_title("Max / Mean")
+# # axs[3, 1].hist(recon_p.flatten())
+# # axs[3, 1].set_title("Example Reconstruction")
 
-x = np.zeros((512, 512), np.uint8)
-cv.imshow("X", x)
-key = cv.waitKey(0)
-while key != ord("q"):
-    print(key)
-    key = cv.waitKey(0)
+# # axs[0, 2].hist(r_max_p)
+# # axs[0, 2].set_title("Max")
+# # axs[1, 2].hist(r_avg_x)
+# # axs[1, 2].set_title("Mean")
+# # axs[2, 2].hist(r_max_p / r_avg_x)
+# # axs[2, 2].set_title("Max / Mean")
+# # axs[3, 2].hist(torch.masked_select(recon, (recon > 0.0)))
+# # axs[3, 2].set_title("Example Reconstruction")
 
+# plt.show()
diff --git a/mu_map/training/default.py b/mu_map/training/default.py
index 15330982e772dcb3298f0a99a6b8e77adebc0a97..adf89cefbd3a4a518893710f05f26769a0a81c92 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/default.py
@@ -1,44 +1,220 @@
+import os
+from typing import Dict
 
+import torch
 
-class Training():
+from mu_map.logging import get_logger
 
-    def __init__(self, epochs):
+
+class Training:
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        data_loaders: Dict[str, torch.utils.data.DataLoader],
+        epochs: int,
+        device: torch.device,
+        lr: float,
+        lr_decay_factor: float,
+        lr_decay_epoch: int,
+        snapshot_dir: str,
+        snapshot_epoch: int,
+        logger=None,
+    ):
+        self.model = model
+        self.data_loaders = data_loaders
         self.epochs = epochs
+        self.device = device
+
+        self.lr = lr
+        self.lr_decay_factor = lr_decay_factor
+        self.lr_decay_epoch = lr_decay_epoch
+
+        self.snapshot_dir = snapshot_dir
+        self.snapshot_epoch = snapshot_epoch
+
+        self.logger = logger if logger is not None else get_logger()
+
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
+            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
+        )
+        self.loss_func = torch.nn.MSELoss(reduction="mean")
+
 
     def run(self):
         for epoch in range(1, self.epochs + 1):
-            self.run_epoch(self.data_loader["train"], phase="train")
-            loss_training = self.run_epoch(self.data_loader["train"], phase="eval")
-            loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval")
+            logger.debug(
+                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
+            )
+            self._run_epoch(self.data_loaders["train"], phase="train")
+
+            loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
+            logger.info(
+                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
+            )
+            loss_validation = self._run_epoch(
+                self.data_loaders["validation"], phase="val"
+            )
+            logger.info(
+                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
+            )
 
             # ToDo: log outputs and time
+            _previous = self.lr_scheduler.get_last_lr()[0]
             self.lr_scheduler.step()
+            logger.debug(
+                f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
+            )
 
-            if epoch % self.snapshot_epoch:
+            if epoch % self.snapshot_epoch == 0:
                 self.store_snapshot(epoch)
 
-            
+            logger.debug(
+                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
+            )
 
-    def run_epoch(self, data_loader, phase):
+    def _run_epoch(self, data_loader, phase):
+        logger.debug(f"Run epoch in phase {phase}")
         self.model.train() if phase == "train" else self.model.eval()
 
         epoch_loss = 0
-        for inputs, labels in self.data_loader:
+        loss_updates = 0
+        for i, (inputs, labels) in enumerate(data_loader):
+            print(
+                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
+                end="\r",
+            )
             inputs = inputs.to(self.device)
             labels = labels.to(self.device)
 
             self.optimizer.zero_grad()
             with torch.set_grad_enabled(phase == "train"):
                 outputs = self.model(inputs)
-                loss = self.loss(outputs, labels)
+                loss = self.loss_func(outputs, labels)
 
                 if phase == "train":
                     loss.backward()
-                    optimizer.step()
-
-            epoch_loss += loss.item() / inputs.size[0]
-        return epoch_loss
+                    self.optimizer.step()
 
+            epoch_loss += loss.item()
+            loss_updates += 1
+        return epoch_loss / loss_updates
 
     def store_snapshot(self, epoch):
-        pass
+        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
+        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
+        logger.debug(f"Store snapshot at {snapshot_file}")
+        torch.save(self.model.state_dict(), snapshot_file)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    from mu_map.dataset.mock import MuMapMockDataset
+    from mu_map.logging import add_logging_args, get_logger_by_args
+    from mu_map.models.unet import UNet
+
+    parser = argparse.ArgumentParser(
+        description="Train a UNet model to predict μ-maps from reconstructed scatter images",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Model Args
+    parser.add_argument(
+        "--features",
+        type=int,
+        nargs="+",
+        default=[8, 16],
+        help="number of features in the layers of the UNet structure",
+    )
+
+    # Dataset Args
+    # parser.add_argument("--features", type=int, nargs="+", default=[8, 16], help="number of features in the layers of the UNet structure")
+
+    # Training Args
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="train_data",
+        help="directory in which results (snapshots and logs) of this training are saved",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=10,
+        help="the number of epochs for which the model is trained",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0" if torch.cuda.is_available() else "cpu",
+        help="the device (cpu or gpu) with which the training is performed",
+    )
+    parser.add_argument(
+        "--lr", type=float, default=0.1, help="the initial learning rate for training"
+    )
+    parser.add_argument(
+        "--lr_decay_factor",
+        type=float,
+        default=0.99,
+        help="decay factor for the learning rate",
+    )
+    parser.add_argument(
+        "--lr_decay_epoch",
+        type=int,
+        default=1,
+        help="frequency in epochs at which the learning rate is decayed",
+    )
+    parser.add_argument(
+        "--snapshot_dir",
+        type=str,
+        default="snapshots",
+        help="directory under --output_dir where snapshots are stored",
+    )
+    parser.add_argument(
+        "--snapshot_epoch",
+        type=int,
+        default=10,
+        help="frequency in epochs at which snapshots are stored",
+    )
+
+    # Logging Args
+    add_logging_args(parser, defaults={"--logfile": "train.log"})
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.output_dir):
+        os.mkdir(args.output_dir)
+
+    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
+    if not os.path.exists(args.snapshot_dir):
+        os.mkdir(args.snapshot_dir)
+
+    args.logfile = os.path.join(args.output_dir, args.logfile)
+
+    device = torch.device(args.device)
+    logger = get_logger_by_args(args)
+
+    model = UNet(in_channels=1, features=args.features)
+    dataset = MuMapMockDataset(logger=logger)
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
+    )
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
+    )
+    data_loaders = {"train": data_loader_train, "validation": data_loader_val}
+
+    training = Training(
+        model=model,
+        data_loaders=data_loaders,
+        epochs=args.epochs,
+        device=device,
+        lr=args.lr,
+        lr_decay_factor=args.lr_decay_factor,
+        lr_decay_epoch=args.lr_decay_epoch,
+        snapshot_dir=args.snapshot_dir,
+        snapshot_epoch=args.snapshot_epoch,
+        logger=logger,
+    )
+    training.run()
diff --git a/mu_map/util.py b/mu_map/util.py
index 259a7bbb970f2b2aa8132355f8eb367664d8adb4..914793b49953f190912a5456a0e9207945711c6b 100644
--- a/mu_map/util.py
+++ b/mu_map/util.py
@@ -23,6 +23,9 @@ def to_grayscale(
     if max_val is None:
         max_val = img.max()
 
+    if (max_val - min_val) == 0:
+        return np.zeros(img.shape, np.uint8)
+
     _img = (img - min_val) / (max_val - min_val)
     _img = (_img * 255).astype(np.uint8)
     return _img
@@ -36,4 +39,4 @@ def grayscale_to_rgb(img: np.ndarray):
     :return: the image in rgb
     """
     assert img.ndim == 2, f"grascale image has more than 2 dimensions {img.shape}"
-    return img.repeat(3).reshape((3, *img.shape))
+    return img.repeat(3).reshape((*img.shape, 3))
diff --git a/mu_map/vis/loss_curve.py b/mu_map/vis/loss_curve.py
new file mode 100644
index 0000000000000000000000000000000000000000..336b1ca949066b82662ff2628f64c883b2d973b8
--- /dev/null
+++ b/mu_map/vis/loss_curve.py
@@ -0,0 +1,66 @@
+import argparse
+
+import matplotlib.pyplot as plt
+import numpy as np
+
+from mu_map.logging import parse_file
+
+SIZE_DEFAULT = 12
+plt.rc("font", family="Roboto")  # controls default font
+plt.rc("font", weight="normal")  # controls default font
+plt.rc("font", size=SIZE_DEFAULT)  # controls default text sizes
+plt.rc("axes", titlesize=18)  # fontsize of the axes title
+
+# https://colorbrewer2.org/#type=diverging&scheme=RdBu&n=3lk
+COLORS = ["#ef8a62", "#67a9cf"]
+
+parser = argparse.ArgumentParser(description="TODO", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument("logfile", type=str, help="TODO")
+parser.add_argument("--normalize", action="store_true", help="TODO")
+args = parser.parse_args()
+
+logs = parse_file(args.logfile)
+logs = list(filter(lambda logline: logline.loglevel == "INFO", logs))
+
+def parse_loss(logs, phase):
+    _logs = map(lambda logline: logline.message, logs)
+    _logs = filter(lambda log: phase in log, _logs)
+    _logs = filter(lambda log: "Loss" in log, _logs)
+    _logs = list(_logs)
+
+    losses = map(lambda log: log.split("-")[-1].strip(), _logs)
+    losses = map(lambda log: log.split(":")[-1].strip(), losses)
+    losses = map(float, losses)
+
+    epochs = map(lambda log: log.split("-")[0].strip(), _logs)
+    epochs = list(epochs)
+    epochs = map(lambda log: log.split(" ")[-1], epochs)
+    epochs = map(lambda log: log.split("/")[0], epochs)
+    epochs = map(int, epochs)
+
+    return np.array(list(epochs)), np.array(list(losses))
+
+phases = ["TRAIN", "VAL"]
+labels = ["Training", "Validation"]
+
+fig, ax = plt.subplots()
+for phase, label, color in zip(phases, labels, COLORS):
+    epochs, loss = parse_loss(logs, phase)
+
+    if args.normalize:
+        loss = loss / loss.max()
+
+    ax.plot(epochs, loss, label=label, color=color)
+    ax.scatter(epochs, loss, s=15, color=color)
+
+ax.spines["left"].set_visible(False)
+ax.spines["right"].set_visible(False)
+ax.spines["top"].set_visible(False)
+
+ax.grid(axis="y", alpha=0.7)
+ax.legend()
+ax.set_xlabel("Epoch")
+ax.set_ylabel("Loss")
+plt.tight_layout()
+plt.show()
+