diff --git a/mu_map/file/interfile.py b/mu_map/file/interfile.py
index b7202e8e199e1078b8e9e63475915ef389b5e9ea..85cef0ad767225bc3a2e84d31d4a2215caccd051 100644
--- a/mu_map/file/interfile.py
+++ b/mu_map/file/interfile.py
@@ -8,6 +8,9 @@ Interfile = Tuple[Dict[str, str], np.ndarray]
 
 @dataclass
 class _InterfileKeys:
+    """
+    Data class defining keys for an Interfile header.
+    """
     placeholder: str = "{_}"
 
     _dim: str = f"matrix size [{placeholder}]"
@@ -18,6 +21,9 @@ class _InterfileKeys:
     bytes_per_pixel: str = "number of bytes per pixel"
     number_format: str = "number format"
 
+    patient_orientation: str = "patient orientation"
+    patient_rotation: str = "supine"
+
     def dim(self, index: int) -> str:
         return self._dim.replace(self.placeholder, str(index))
 
@@ -26,22 +32,6 @@ class _InterfileKeys:
 
 InterfileKeys = _InterfileKeys()
 
-"""
-Several keys defined in INTERFILE headers.
-"""
-KEY_DIM_1 = "matrix size [1]"
-KEY_DIM_2 = "matrix size [2]"
-KEY_DIM_3 = "matrix size [3]"
-KEY_SPACING_1 = "scaling factor (mm/pixel) [1]"
-KEY_SPACING_2 = "scaling factor (mm/pixel) [2]"
-KEY_SPACING_3 = "scaling factor (mm/pixel) [2]"
-KEY_NPROJECTIONS = "number of projections"
-
-KEY_DATA_FILE = "name of data file"
-
-KEY_BYTES_PER_PIXEL = "number of bytes per pixel"
-KEY_NUMBER_FORMAT = "number format"
-
 
 """
 A template of an INTERFILE header.
@@ -55,6 +45,8 @@ GENERAL DATA :=
 GENERAL IMAGE DATA :=
 type of data := Tomographic
 imagedata byte order := LITTLEENDIAN
+patient orientation := feet_in
+patient rotation := supine
 isotope name := ^99m^Technetium
 SPECT STUDY (General) :=
 process status := Reconstructed
@@ -131,15 +123,15 @@ def load_interfile(filename: str) -> Tuple[Dict[str, str], np.ndarray]:
     """
     header = parse_interfile_header(filename)
 
-    dim_x = int(header[KEY_DIM_1])
-    dim_y = int(header[KEY_DIM_2])
-    dim_z = int(header[KEY_DIM_3]) if KEY_DIM_3 in header else int(header[KEY_NPROJECTIONS])
+    dim_x = int(header[InterfileKeys.dim(1)])
+    dim_y = int(header[InterfileKeys.dim(2)])
+    dim_z = int(header[InterfileKeys.dim(3)]) if InterfileKeys.dim(3) in header else int(header[InterfileKeys.n_projections])
 
-    bytes_per_pixel = int(header[KEY_BYTES_PER_PIXEL])
-    num_format = header[KEY_NUMBER_FORMAT]
+    bytes_per_pixel = int(header[InterfileKeys.bytes_per_pixel])
+    num_format = header[InterfileKeys.number_format]
     dtype = type_by_format(num_format, bytes_per_pixel)
 
-    data_file = os.path.join(os.path.dirname(filename), header[KEY_DATA_FILE])
+    data_file = os.path.join(os.path.dirname(filename), header[InterfileKeys.data_file])
     with open(data_file, mode="rb") as f:
         image = np.frombuffer(f.read(), dtype)
     image = image.reshape((dim_z, dim_y, dim_x))
@@ -162,10 +154,10 @@ def write_interfile(filename: str, header: Dict[str, str], image: np.ndarray):
     filename_data = f"{filename}.v"
     filename_header = f"{filename}.hv"
 
-    header[KEY_DATA_FILE] = os.path.basename(filename_data)
-    header[KEY_DIM_3] = str(image.shape[0])
-    header[KEY_DIM_2] = str(image.shape[1])
-    header[KEY_DIM_1] = str(image.shape[2])
+    header[InterfileKeys.data_file] = os.path.basename(filename_data)
+    header[InterfileKeys.dim(3)] = str(image.shape[0])
+    header[InterfileKeys.dim(2)] = str(image.shape[1])
+    header[InterfileKeys.dim(1)] = str(image.shape[2])
 
     image = image.astype(np.float32)
 
diff --git a/mu_map/recon/osem.py b/mu_map/recon/osem.py
index d4abd98a74695ce847e85a3c07e7462e4080cf1c..ef47a18ec186cc4af15ad44de6a7bb8df14d789e 100644
--- a/mu_map/recon/osem.py
+++ b/mu_map/recon/osem.py
@@ -10,10 +10,7 @@ from mu_map.file.interfile import (
     parse_interfile_header_str,
     load_interfile,
     write_interfile,
-    KEY_DIM_1,
-    KEY_DIM_2,
-    KEY_SPACING_1,
-    KEY_SPACING_2,
+    InterfileKeys,
     TEMPLATE_HEADER_IMAGE,
 )
 from mu_map.recon.filter import GaussianFilter
@@ -77,13 +74,13 @@ def uniform_estimate(projection: Tuple[Dict[str, str], np.ndarray]):
         (image_proj.shape[1], image_proj.shape[2], image_proj.shape[2]), np.float32
     )
 
-    offset = -0.5 * image_proj.shape[2] * float(header_proj[KEY_SPACING_1])
+    offset = -0.5 * image_proj.shape[2] * float(header_proj[InterfileKeys.spacing(1)])
     header = TEMPLATE_HEADER_IMAGE.replace("{ROWS}", str(image.shape[2]))
     header = header.replace("{COLUMNS}", str(image.shape[1]))
     header = header.replace("{SLICES}", str(image.shape[0]))
-    header = header.replace("{SPACING_X}", header_proj[KEY_SPACING_1])
-    header = header.replace("{SPACING_Y}", header_proj[KEY_SPACING_1])
-    header = header.replace("{SPACING_Z}", header_proj[KEY_SPACING_2])
+    header = header.replace("{SPACING_X}", header_proj[InterfileKeys.spacing(1)])
+    header = header.replace("{SPACING_Y}", header_proj[InterfileKeys.spacing(1)])
+    header = header.replace("{SPACING_Z}", header_proj[InterfileKeys.spacing(1)])
     header = header.replace("{OFFSET_X}", f"{offset:.4f}")
     header = header.replace("{OFFSET_Y}", f"{offset:.4f}")
     header = parse_interfile_header_str(header)
diff --git a/mu_map/recon/project.py b/mu_map/recon/project.py
index 5aaaa900a69717e7fd6bb9c206a46f212aab0e46..05cce8e703655bc36d27d6d006cc4d91ced820e3 100644
--- a/mu_map/recon/project.py
+++ b/mu_map/recon/project.py
@@ -9,8 +9,7 @@ import stir
 from mu_map.file.interfile import (
     load_interfile,
     write_interfile,
-    KEY_SPACING_1,
-    KEY_SPACING_3,
+    InterfileKeys,
 )
 
 
@@ -28,6 +27,8 @@ type of data := Tomographic
 imagedata byte order := LITTLEENDIAN
 number format := float
 number of bytes per pixel := 4
+patient orientation := feet_in
+patient rotation := supine
 
 SPECT STUDY (General) := 
 ;matrix axis label [2] := axial coordinate
@@ -88,8 +89,8 @@ def forward_project(
     header_proj = TEMPLATE_HEADER_PROJ.replace("{DATA_FILE}", filename_proj_data)
     header_proj = header_proj.replace("{SLICES}", str(n_slices))
     header_proj = header_proj.replace("{BINS}", str(n_bins))
-    header_proj = header_proj.replace("{SPACING_SLICES}", recon_header[KEY_SPACING_3])
-    header_proj = header_proj.replace("{SPACING_BINS}", recon_header[KEY_SPACING_1])
+    header_proj = header_proj.replace("{SPACING_SLICES}", recon_header[InterfileKeys.spacing(3)])
+    header_proj = header_proj.replace("{SPACING_BINS}", recon_header[InterfileKeys.spacing(1)])
     header_proj = header_proj.replace("{N_PROJECTIONS}", str(n_projections))
     header_proj = header_proj.replace("{ROTATION}", str(rotation))
     header_proj = header_proj.replace("{START_ANGLE}", str(start_angle))
diff --git a/mu_map/training/distance.py b/mu_map/training/distance.py
index d40ea9982fa85c01f1d384abcc65af7bd9e35078..3fdb79cf33d9e02918281ab874f08c51e89e653b 100644
--- a/mu_map/training/distance.py
+++ b/mu_map/training/distance.py
@@ -53,7 +53,6 @@ if __name__ == "__main__":
         MaxNormTransform,
         GaussianNormTransform,
     )
-    from mu_map.dataset.transform import ScaleTransform
     from mu_map.logging import add_logging_args, get_logger_by_args
     from mu_map.models.unet import UNet
 
diff --git a/mu_map/vis/slices.py b/mu_map/vis/slices.py
index 15c89159bfad693f75ac96491784aaeffbfcb7ab..ce36eaa715551574fab2fb73d6974a387ce08eb1 100644
--- a/mu_map/vis/slices.py
+++ b/mu_map/vis/slices.py
@@ -2,6 +2,7 @@ from typing import List
 
 import numpy as np
 
+from mu_map.dataset.util import align_images
 from mu_map.file.dicom import load_dcm_img
 from mu_map.file.interfile import load_interfile_img
 
@@ -27,20 +28,23 @@ def load_image(filename: str) -> np.ndarray:
     raise ValueError(f"Could not load {filename} as DICOM or INTERFILE image!")
 
 
-def join_images(images: List[np.ndarray], separator: np.ndarray) -> List[np.ndarray]:
+def join_images(
+    images: List[np.ndarray], separator: np.ndarray, vertical: bool = False
+) -> List[np.ndarray]:
     """
     Create a new image by joining all input images along a separator image.
-    Note that the images are joined horizontally. Thus, their shape along
-    the first axis must be equal.
+    If joining horizontally, their shape must be equal along their first axis.
+    If joining vertically, their shape must be equal along their second axis.
 
     :param images: a list of images to join
     :param separator: a separator image inserted between all images
+    :param vertical: join images vertically instead of horizontally
     :return: a new image joined as described above
     """
     res = []
     for image in images:
         res += [image, separator]
-    return np.hstack(res[:-1])
+    return np.vstack(res[:-1]) if vertical else np.hstack(res[:-1])
 
 
 if __name__ == "__main__":
@@ -51,14 +55,26 @@ if __name__ == "__main__":
     from mu_map.util import to_grayscale, COLOR_WHITE
 
     parser = argparse.ArgumentParser(
-        description="Visualize 3D Volumes as a video of their slices"
+        description="Visualize 3D Volumes as a video of their slices",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
     parser.add_argument("images", type=str, nargs="+", help="the images to visualize")
     parser.add_argument(
         "--resize", type=int, default=512, help="resize images to this size"
     )
     parser.add_argument(
-        "--fps", type=int, default=25, help="frames (slices) to show per second"
+        "--fps", type=int, default=10, help="frames (slices) to show per second"
+    )
+    parser.add_argument("--align", action="store_true", help="center align images")
+    parser.add_argument(
+        "--shared_range",
+        action="store_true",
+        help="normalize all images to the same value range",
+    )
+    parser.add_argument(
+        "--vertical",
+        action="store_true",
+        help="join images vertically instead of horizontally",
     )
     parser.add_argument(
         "--window_name", type=str, default="Slices", help="name of the displayed window"
@@ -66,9 +82,33 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     images = list(map(load_image, args.images))
-    scales = list(map(lambda image: args.resize / image.shape[1], images))
+
+    if args.align:
+        image_with_least_slices = sorted(images, key=lambda image: image.shape[0])[0]
+        images = list(
+            map(lambda image: align_images(image, image_with_least_slices)[0], images)
+        )
+
+    scales = list(
+        map(
+            lambda image: args.resize / image.shape[2]
+            if args.vertical
+            else args.resize / image.shape[1],
+            images,
+        )
+    )
     slices = [0] * len(images)
-    space = np.full((args.resize, 10), 239, np.uint8)
+    space = (
+        np.full((10, args.resize), 239, np.uint8)
+        if args.vertical
+        else np.full((args.resize, 10), 239, np.uint8)
+    )
+
+    min_vals = list(map(lambda image: image.min(), images))
+    max_vals = list(map(lambda image: image.max(), images))
+    if args.shared_range:
+        min_vals = [min(min_vals)] * len(images)
+        max_vals = [max(max_vals)] * len(images)
 
     cv.namedWindow(args.window_name, cv.WINDOW_NORMAL)
     cv.resizeWindow(args.window_name, 1600, 900)
@@ -77,16 +117,24 @@ if __name__ == "__main__":
     current_timeout = timeout
     while True:
         _images = []
-        for i, (image, _slice, scale) in enumerate(zip(images, slices, scales)):
-            _image = to_grayscale(image[_slice], max_val=image.max(), min_val=image.min())
+        for i, (image, _slice, scale, min_val, max_val) in enumerate(
+            zip(images, slices, scales, min_vals, max_vals)
+        ):
+            _image = to_grayscale(image[_slice], min_val=min_val, max_val=max_val)
             _image = cv.resize(_image, None, fx=scale, fy=scale)
             _image = cv.putText(
-                _image, str(_slice + 1), (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
+                _image,
+                str(_slice + 1),
+                (0, 30),
+                cv.FONT_HERSHEY_SIMPLEX,
+                1,
+                COLOR_WHITE,
+                3,
             )
             _images.append(_image)
 
             slices[i] = (_slice + 1) % image.shape[0]
-        image = join_images(_images, space)
+        image = join_images(_images, space, vertical=args.vertical)
 
         cv.imshow(args.window_name, image)
         key = cv.waitKey(current_timeout)