diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index 94fa52aaca078dd6fbd61736d1c626d3bd2d036d..1568919a73905899e55062221701388ec0b71eb1 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Optional, Tuple
 
 import cv2 as cv
 import pandas as pd
@@ -9,48 +9,18 @@ import torch
 from torch.utils.data import Dataset
 
 from mu_map.data.prepare import headers
-from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
+from mu_map.data.remove_bed import (
+    DEFAULT_BED_CONTOURS_FILENAME,
+    load_contours,
+    remove_bed,
+)
 from mu_map.data.review_mu_map import discard_slices
 from mu_map.data.split import split_csv
 from mu_map.dataset.transform import Transform
+from mu_map.dataset.util import align_images, load_dcm_img
 from mu_map.logging import get_logger
 
 
-"""
-Since DICOM images only allow images stored in short integer format,
-the Siemens scanner software multiplies values by a factor before storing
-so that no precision is lost.
-The scale can be found in this private DICOM tag.
-"""
-DCM_TAG_PIXEL_SCALE_FACTOR = 0x00331038
-
-
-def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
-    """
-    Align one image to another on the first axis (z-axis).
-    It is assumed that the second image has less slices than the first.
-    Then, the first image is shortened in a way that the centers of both images lie on top of each other.
-
-    :param image_1: the image to be aligned
-    :param image_2: the image to which image_1 is aligned
-    :return: the aligned image_1
-    """
-    assert (
-        image_1.shape[0] > image_2.shape[0]
-    ), f"Alignment is based on the fact that image 1 has more slices {image_1.shape[0]} than image_2 {image_.shape[0]}"
-
-    # central slice of image 2
-    c_2 = image_2.shape[0] // 2
-    # image to the left and right of the center
-    left = c_2
-    right = image_2.shape[0] - c_2
-
-    # central slice of image 1
-    c_1 = image_1.shape[0] // 2
-    # select center and same amount to the left/right as image_2
-    return image_1[(c_1 - left) : (c_1 + right)]
-
-
 class MuMapDataset(Dataset):
     def __init__(
         self,
@@ -102,48 +72,39 @@ class MuMapDataset(Dataset):
 
         self.reconstructions = {}
         self.mu_maps = {}
-        self.pre_load_images()
-
-    def pre_load_images(self):
-        self.logger.debug("Pre-loading images ...")
-        for i in range(len(self.table)):
-            row = self.table.iloc[i]
-            _id = row[headers.id]
-
-            mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
-            mu_map = pydicom.dcmread(mu_map_file)
-            mu_map = mu_map.pixel_array / mu_map[DCM_TAG_PIXEL_SCALE_FACTOR].value
-            if self.discard_mu_map_slices:
-                mu_map = discard_slices(row, mu_map)
-            if self.bed_contours:
-                if _id in self.bed_contours:
-                    bed_contour = self.bed_contours[_id]
-                    for i in range(mu_map.shape[0]):
-                        mu_map[i] = cv.drawContours(
-                            mu_map[i], [bed_contour], -1, 0.0, -1
-                        )
-                else:
-                    logger.warning(f"Could not find bed contour for id {_id}")
-
-            recon_file = os.path.join(self.dir_images, row[self.header_recon])
-            recon = pydicom.dcmread(recon_file)
-            recon = recon.pixel_array / recon[DCM_TAG_PIXEL_SCALE_FACTOR].value
-            if self.align:
-                recon = align_images(recon, mu_map)
-
-            mu_map = mu_map.astype(np.float32)
-            mu_map = torch.from_numpy(mu_map)
-            mu_map = mu_map.unsqueeze(dim=0)
-
-            recon = recon.astype(np.float32)
-            recon = torch.from_numpy(recon)
-            recon = recon.unsqueeze(dim=0)
-
-            recon, mu_map = self.transform_normalization(recon, mu_map)
-
-            self.mu_maps[_id] = mu_map
-            self.reconstructions[_id] = recon
-        self.logger.debug("Pre-loading images done!")
+
+    def load_image(self, _id: int):
+        row = self.table[self.table[headers.id] == _id].iloc[0]
+        _id = row[headers.id]
+
+        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
+        mu_map = load_dcm_img(mu_map_file)
+        if self.discard_mu_map_slices:
+            mu_map = discard_slices(row, mu_map)
+        if self.bed_contours:
+            if _id in self.bed_contours:
+                bed_contour = self.bed_contours[_id]
+                mu_map = remove_bed(mu_map, bed_contour)
+            else:
+                logger.warning(f"Could not find bed contour for id {_id}")
+
+        recon_file = os.path.join(self.dir_images, row[self.header_recon])
+        recon = load_dcm_img(recon_file)
+        if self.align:
+            recon, mu_map = align_images(recon, mu_map)
+
+        mu_map = mu_map.astype(np.float32)
+        mu_map = torch.from_numpy(mu_map)
+        mu_map = mu_map.unsqueeze(dim=0)
+
+        recon = recon.astype(np.float32)
+        recon = torch.from_numpy(recon)
+        recon = recon.unsqueeze(dim=0)
+
+        recon, mu_map = self.transform_normalization(recon, mu_map)
+
+        self.mu_maps[_id] = mu_map
+        self.reconstructions[_id] = recon
 
     def __getitem__(self, index: int):
         row = self.table.iloc[index]
@@ -151,6 +112,9 @@ class MuMapDataset(Dataset):
         return self.getitem_by_id(_id)
 
     def getitem_by_id(self, _id: int):
+        if _id not in self.reconstructions:
+            self.load_image(_id)
+
         recon = self.reconstructions[_id]
         mu_map = self.mu_maps[_id]
 
@@ -165,7 +129,7 @@ class MuMapDataset(Dataset):
 __all__ = [MuMapDataset.__name__]
 
 
-def main(dataset):
+def main(dataset, ids, paused=False):
     from mu_map.util import to_grayscale, COLOR_WHITE
 
     wname = "Dataset"
@@ -173,7 +137,10 @@ def main(dataset):
     cv.resizeWindow(wname, 1600, 900)
     space = np.full((1024, 10), 239, np.uint8)
 
-    timeout = 100
+    TIMEOUT_PAUSED = 0
+    TIMEOUT_RUNNING = 1000 // 15
+
+    timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING
 
     def to_display_image(image, _slice):
         _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
@@ -187,8 +154,16 @@ def main(dataset):
     def combine_images(images, slices):
         image_1 = to_display_image(images[0], slices[0])
         image_2 = to_display_image(images[1], slices[1])
-        space = np.full((image_1.shape[0], 10), 239, np.uint8)
-        return np.hstack((image_1, space, image_2))
+
+        image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
+        image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))
+
+        image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
+        image_3_1 = image_2.copy()
+        image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)
+
+        space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
+        return np.hstack((image_1, space, image_3, space, image_2))
 
     for i in range(len(dataset)):
         ir = 0
@@ -197,6 +172,9 @@ def main(dataset):
         row = dataset.table.iloc[i]
         _id = row[headers.id]
 
+        if ids is not None and _id not in ids:
+            continue
+
         recon, mu_map = dataset[i]
         recon = recon.squeeze().numpy()
         mu_map = mu_map.squeeze().numpy()
@@ -219,11 +197,18 @@ def main(dataset):
             elif key == ord("q"):
                 exit(0)
             elif key == ord("p"):
-                timeout = 0 if timeout > 0 else 100
+                timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
+            elif key == 82:  # up arrow key
+                ir = ir - 1
+                continue
             elif key == 83:  # right arrow key
+                im = im - 1
                 continue
             elif key == 81:  # left arrow key
+                im = im - 1
                 ir = max(ir - 2, 0)
+            elif key == 84:  # down arrow key
+                ir = ir - 1
                 im = max(im - 2, 0)
             elif key == ord("s"):
                 cv.imwrite(f"{running:03d}.png", to_show)
@@ -233,6 +218,7 @@ def main(dataset):
 if __name__ == "__main__":
     import argparse
 
+    from mu_map.dataset.transform import PadCropTranform
     from mu_map.logging import add_logging_args, get_logger_by_args
 
     parser = argparse.ArgumentParser(
@@ -263,6 +249,22 @@ if __name__ == "__main__":
         action="store_true",
         help="do not remove broken slices of the mu map",
     )
+    parser.add_argument(
+        "--ids",
+        type=int,
+        nargs="*",
+        help="only display certain ids",
+    )
+    parser.add_argument(
+        "--paused",
+        action="store_true",
+        help="start in paused mode",
+    )
+    parser.add_argument(
+        "--pad_crop",
+        type=int,
+        help="pad crop images to this size",
+    )
     add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
     args = parser.parse_args()
 
@@ -271,12 +273,17 @@ if __name__ == "__main__":
     bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
     logger = get_logger_by_args(args)
 
+    transform_normalization = (
+        PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
+    )
+
     dataset = MuMapDataset(
         args.dataset_dir,
         align=align,
         discard_mu_map_slices=discard_mu_map_slices,
         bed_contours_file=bed_contours_file,
         split_name=args.split,
+        transform_normalization=transform_normalization,
         logger=logger,
     )
-    main(dataset)
+    main(dataset, args.ids, paused=args.paused)