diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index 7d35bc44de5aa45659fb8297c79f20ada97e4dd1..5b98ff3c388a93b644f29002fdd89963e62a9814 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -132,55 +132,9 @@ class MuMapDataset(Dataset):
 
 __all__ = [MuMapDataset.__name__]
 
-if __name__ == "__main__":
-    import argparse
-
-    from mu_map.dataset.normalization import (
-        GaussianNormTransform,
-        MaxNormTransform,
-        MeanNormTransform,
-    )
-    from mu_map.logging import add_logging_args, get_logger_by_args
+def main(dataset):
     from mu_map.util import to_grayscale, COLOR_WHITE
 
-    parser = argparse.ArgumentParser(
-        description="Visualize the images of a MuMapDataset",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-    parser.add_argument(
-        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
-    )
-    parser.add_argument(
-        "--unaligned",
-        action="store_true",
-        help="do not perform center alignment of reconstruction an mu-map slices",
-    )
-    parser.add_argument(
-        "--show_bed",
-        action="store_true",
-        help="do not remove the bed contour from the mu map",
-    )
-    parser.add_argument(
-        "--full_mu_map",
-        action="store_true",
-        help="do not remove broken slices of the mu map",
-    )
-    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
-    args = parser.parse_args()
-
-    align = not args.unaligned
-    discard_mu_map_slices = not args.full_mu_map
-    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
-    logger = get_logger_by_args(args)
-
-    dataset = MuMapDataset(
-        args.dataset_dir,
-        align=align,
-        discard_mu_map_slices=discard_mu_map_slices,
-        bed_contours_file=bed_contours_file,
-        logger=logger,
-    )
-
     wname = "Dataset"
     cv.namedWindow(wname, cv.WINDOW_NORMAL)
     cv.resizeWindow(wname, 1600, 900)
@@ -238,3 +192,49 @@ if __name__ == "__main__":
             elif key == ord("s"):
                 cv.imwrite(f"{running:03d}.png", to_show)
                 running += 1
+
+
+if __name__ == "__main__":
+    import argparse
+
+    from mu_map.logging import add_logging_args, get_logger_by_args
+
+    parser = argparse.ArgumentParser(
+        description="Visualize the images of a MuMapDataset",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
+    )
+    parser.add_argument(
+        "--unaligned",
+        action="store_true",
+        help="do not perform center alignment of reconstruction an mu-map slices",
+    )
+    parser.add_argument(
+        "--show_bed",
+        action="store_true",
+        help="do not remove the bed contour from the mu map",
+    )
+    parser.add_argument(
+        "--full_mu_map",
+        action="store_true",
+        help="do not remove broken slices of the mu map",
+    )
+    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
+    args = parser.parse_args()
+
+    align = not args.unaligned
+    discard_mu_map_slices = not args.full_mu_map
+    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
+    logger = get_logger_by_args(args)
+
+    dataset = MuMapDataset(
+        args.dataset_dir,
+        align=align,
+        discard_mu_map_slices=discard_mu_map_slices,
+        bed_contours_file=bed_contours_file,
+        logger=logger,
+    )
+    main(dataset)
+
diff --git a/mu_map/dataset/mock.py b/mu_map/dataset/mock.py
index de564b0f9b224cc334febc338a46c01af907dd82..9f33b3c900f0a85b053a2e7efabedb04de4e783b 100644
--- a/mu_map/dataset/mock.py
+++ b/mu_map/dataset/mock.py
@@ -1,9 +1,10 @@
 from mu_map.dataset.default import MuMapDataset
+from mu_map.dataset.normalization import MaxNormTransform
 
 
 class MuMapMockDataset(MuMapDataset):
-    def __init__(self, dataset_dir: str = "data/initial/", num_images: int = 16):
-        super().__init__(dataset_dir=dataset_dir)
+    def __init__(self, dataset_dir: str = "data/initial/", num_images: int = 16, logger=None):
+        super().__init__(dataset_dir=dataset_dir, transform_normalization=MaxNormTransform(), logger=logger)
         self.len = num_images
 
     def __getitem__(self, index: int):
@@ -20,63 +21,7 @@ if __name__ == "__main__":
     import cv2 as cv
     import numpy as np
 
-    from mu_map.util import to_grayscale, COLOR_WHITE
+    from mu_map.dataset.default import main
 
     dataset = MuMapMockDataset()
-
-    wname = "Dataset"
-    cv.namedWindow(wname, cv.WINDOW_NORMAL)
-    cv.resizeWindow(wname, 1600, 900)
-    space = np.full((1024, 10), 239, np.uint8)
-
-    timeout = 100
-
-    def to_display_image(image, _slice):
-        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
-        _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
-        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
-        _image = cv.putText(
-            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
-        )
-        return _image
-
-    def combine_images(images, slices):
-        image_1 = to_display_image(images[0], slices[0])
-        image_2 = to_display_image(images[1], slices[1])
-        space = np.full((image_1.shape[0], 10), 239, np.uint8)
-        return np.hstack((image_1, space, image_2))
-
-    for i in range(len(dataset)):
-        ir = 0
-        im = 0
-
-        recon, mu_map = dataset[i]
-        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")
-
-        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
-        key = cv.waitKey(timeout)
-
-        running = 0
-        while True:
-            ir = (ir + 1) % recon.shape[0]
-            im = (im + 1) % mu_map.shape[0]
-
-            to_show = combine_images((recon, mu_map), (ir, im))
-            cv.imshow(wname, to_show)
-
-            key = cv.waitKey(timeout)
-
-            if key == ord("n"):
-                break
-            elif key == ord("q"):
-                exit(0)
-            elif key == ord("p"):
-                timeout = 0 if timeout > 0 else 100
-            elif key == 83:  # right arrow key
-                continue
-            elif key == 81:  # left arrow key
-                ir = max(ir - 2, 0)
-                im = max(im - 2, 0)
-            elif key == ord("s"):
-                cv.imwrite(f"{running:03d}.png", to_show)
-                running += 1
+    main(dataset)