diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py
index df9f2c971b991c2278a7e4af9a6376c601af6d71..530e61abff3d658608e2a0557b2d34abf80e0036 100644
--- a/mu_map/dataset/patches.py
+++ b/mu_map/dataset/patches.py
@@ -8,8 +8,15 @@ from mu_map.dataset.default import MuMapDataset
 
 
 class MuMapPatchDataset(MuMapDataset):
-    def __init__(self, dataset_dir, patches_per_image=100, patch_size=32, shuffle=True):
-        super().__init__(dataset_dir)
+    def __init__(
+        self,
+        dataset_dir: str,
+        patches_per_image: int = 100,
+        patch_size: int = 32,
+        shuffle: bool = True,
+        **kwargs,
+    ):
+        super().__init__(dataset_dir, **kwargs)
 
         self.patches_per_image = patches_per_image
         self.patch_size = patch_size
@@ -84,7 +91,9 @@ if __name__ == "__main__":
         _recon_orig = recon_orig[_slice]
         _recon_orig = to_grayscale(_recon_orig)
         _recon_orig = grayscale_to_rgb(_recon_orig)
-        _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1)
+        _recon_orig = cv.rectangle(
+            _recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1
+        )
         _recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA)
 
         _recon = recon[_slice]
@@ -113,7 +122,9 @@ if __name__ == "__main__":
         mu_map = mu_map.squeeze().numpy()
 
         recon_orig = dataset.reconstructions[_id]
-        recon_orig = torch.nn.functional.pad(recon_orig, padding, mode="constant", value=0)
+        recon_orig = torch.nn.functional.pad(
+            recon_orig, padding, mode="constant", value=0
+        )
         recon_orig = recon_orig.squeeze().numpy()
 
         cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))