diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index 1568919a73905899e55062221701388ec0b71eb1..5c52917c3a1bd22a78a2c46b5a0998137bac339d 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -106,6 +106,10 @@ class MuMapDataset(Dataset):
         self.mu_maps[_id] = mu_map
         self.reconstructions[_id] = recon
 
+    def pre_load_images(self):
+        for _id in self.table[headers.id]:
+            self.load_image(_id)
+
     def __getitem__(self, index: int):
         row = self.table.iloc[index]
         _id = row[headers.id]
diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py
index d007d98b3906f5439352e8086444ac6b23694733..d1d9a11237b868888c67d2396b9562f6f6a03c8f 100644
--- a/mu_map/dataset/patches.py
+++ b/mu_map/dataset/patches.py
@@ -39,6 +39,7 @@ class MuMapPatchDataset(MuMapDataset):
         self.patch_size_z = patch_size_z
         self.patch_offset = patch_offset
         self.shuffle = shuffle
+        super().pre_load_images()
 
         self.patches = []
         self.generate_patches()
@@ -85,8 +86,7 @@ class MuMapPatchDataset(MuMapDataset):
         ps = self.patch_size
         ps_z = self.patch_size_z
 
-        recon = self.reconstructions[_id]
-        mu_map = self.mu_maps[_id]
+        recon, mu_map = super().__getitem__(index)
 
         recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
         mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
@@ -94,8 +94,6 @@ class MuMapPatchDataset(MuMapDataset):
         recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps]
         mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps]
 
-        recon, mu_map = self.transform_augmentation(recon, mu_map)
-
         return recon, mu_map
 
     def __len__(self):