diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py
index 9aade9409018913fedefb1093bc6e4f383ee8497..42a975aa23ea3d0e7374eed652494e29fced18c9 100644
--- a/mu_map/dataset/patches.py
+++ b/mu_map/dataset/patches.py
@@ -8,11 +8,27 @@ from mu_map.dataset.default import MuMapDataset
 
 
 class MuMapPatchDataset(MuMapDataset):
+    """
+    A wrapper around the MuMapDataset that computes patches for each reconstruction-μ-map pair.
+
+    :param dataset_dir: the directory containing the dataset - is passed to MuMapDataset
+    :param patches_per_image: the amount of patches to randomly generate for each image
+    :param patch_size: the size of patches in x- and y-direction
+    :param patch_size_z: the size of patches in z-direction --- it is a separate parameter because
+                         images are typically shorter in this direction
+    :param patch_offset: offset of generated patches to the border of images --- this space will
+                         then not appear in patches because it is often empty
+    :param shuffle: shuffle the patches so that patches of image pairs are mixed
+    :param **kwargs: remaining parameters passed to MuMapDataset
+    """
+
     def __init__(
         self,
         dataset_dir: str,
         patches_per_image: int = 100,
         patch_size: int = 32,
+        patch_size_z: int = 32,
+        patch_offset: int = 20,
         shuffle: bool = True,
         **kwargs,
     ):
@@ -20,12 +36,16 @@ class MuMapPatchDataset(MuMapDataset):
 
         self.patches_per_image = patches_per_image
         self.patch_size = patch_size
+        self.patch_size_z = patch_size_z
         self.shuffle = shuffle
 
         self.patches = []
         self.generate_patches()
 
     def generate_patches(self):
+        """
+        Pre-compute patches for each image.
+        """
         for _id in self.reconstructions:
             recon = self.reconstructions[_id].squeeze()
             mu_map = self.mu_maps[_id].squeeze()
@@ -34,15 +54,16 @@ class MuMapPatchDataset(MuMapDataset):
                 recon.shape[0] == mu_map.shape[0]
             ), f"Reconstruction and MuMap were not aligned for patch dataset"
 
-            z_range = (0, max(recon.shape[0] - self.patch_size, 0))
+            z_range = (0, max(recon.shape[0] - self.patch_size_z, 0))
             # sometimes the mu_maps have fewer than 32 slices
             # in this case the z-axis will be padded to the patch size, but this means we only have a single option for z
             y_range = (20, recon.shape[1] - self.patch_size - 20)
             x_range = (20, recon.shape[2] - self.patch_size - 20)
 
+            # compute padding for z axis
             padding = [0, 0, 0, 0, 0, 0, 0, 0]
-            if recon.shape[0] < self.patch_size:
-                diff = self.patch_size - recon.shape[0]
+            if recon.shape[0] < self.patch_size_z:
+                diff = self.patch_size_z - recon.shape[0]
                 padding[4] = math.ceil(diff / 2)
                 padding[5] = math.floor(diff / 2)
 
@@ -54,7 +75,8 @@ class MuMapPatchDataset(MuMapDataset):
 
     def __getitem__(self, index: int):
         _id, z, y, x, padding = self.patches[index]
-        s = self.patch_size
+        ps = self.patch_size
+        ps_z = self.patch_size_z
 
         recon = self.reconstructions[_id]
         mu_map = self.mu_maps[_id]
@@ -62,8 +84,8 @@ class MuMapPatchDataset(MuMapDataset):
         recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
         mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
 
-        recon = recon[:, z : z + s, y : y + s, x : x + s]
-        mu_map = mu_map[:, z : z + s, y : y + s, x : x + s]
+        recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps]
+        mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps]
 
         recon, mu_map = self.transform_augmentation(recon, mu_map)
 
@@ -74,10 +96,29 @@ class MuMapPatchDataset(MuMapDataset):
 
 
 if __name__ == "__main__":
+    import argparse
+
     import cv2 as cv
 
     from mu_map.util import to_grayscale, grayscale_to_rgb
 
+    param_keys = list(MuMapPatchDataset.__init__.__annotations__.keys())[1:]
+    param_defaults = MuMapPatchDataset.__init__.__defaults__[1:]
+
+    parser = argparse.ArgumentParser(
+        help="Visualize the patches in a MuMapPatchDataset",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "dataset_dir",
+        type=str,
+        defaul="data/initial/",
+        help="the directory of the dataset",
+    )
+    for key, _default in zip(param_keys, param_defaults):
+        parser.add_argument(f"--{key}", type=type(_default), default=_default)
+    args = parser.parse_args()
+
     wname = "Dataset"
     cv.namedWindow(wname, cv.WINDOW_NORMAL)
     cv.resizeWindow(wname, 1600, 900)