diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py index 1568919a73905899e55062221701388ec0b71eb1..5c52917c3a1bd22a78a2c46b5a0998137bac339d 100644 --- a/mu_map/dataset/default.py +++ b/mu_map/dataset/default.py @@ -106,6 +106,10 @@ class MuMapDataset(Dataset): self.mu_maps[_id] = mu_map self.reconstructions[_id] = recon + def pre_load_images(self): + for _id in self.table[headers.id]: + self.load_image(_id) + def __getitem__(self, index: int): row = self.table.iloc[index] _id = row[headers.id] diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py index d007d98b3906f5439352e8086444ac6b23694733..d1d9a11237b868888c67d2396b9562f6f6a03c8f 100644 --- a/mu_map/dataset/patches.py +++ b/mu_map/dataset/patches.py @@ -39,6 +39,7 @@ class MuMapPatchDataset(MuMapDataset): self.patch_size_z = patch_size_z self.patch_offset = patch_offset self.shuffle = shuffle + super().pre_load_images() self.patches = [] self.generate_patches() @@ -85,8 +86,7 @@ class MuMapPatchDataset(MuMapDataset): ps = self.patch_size ps_z = self.patch_size_z - recon = self.reconstructions[_id] - mu_map = self.mu_maps[_id] + recon, mu_map = super().__getitem__(index) recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0) mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0) @@ -94,8 +94,6 @@ class MuMapPatchDataset(MuMapDataset): recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps] mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps] - recon, mu_map = self.transform_augmentation(recon, mu_map) - return recon, mu_map def __len__(self):