Skip to content
Snippets Groups Projects
Commit 593a7109 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

fix pre-loading for patches dataset

parent 0998d095
No related branches found
No related tags found
No related merge requests found
...@@ -106,6 +106,10 @@ class MuMapDataset(Dataset): ...@@ -106,6 +106,10 @@ class MuMapDataset(Dataset):
self.mu_maps[_id] = mu_map self.mu_maps[_id] = mu_map
self.reconstructions[_id] = recon self.reconstructions[_id] = recon
def pre_load_images(self):
for _id in self.table[headers.id]:
self.load_image(_id)
def __getitem__(self, index: int): def __getitem__(self, index: int):
row = self.table.iloc[index] row = self.table.iloc[index]
_id = row[headers.id] _id = row[headers.id]
......
...@@ -39,6 +39,7 @@ class MuMapPatchDataset(MuMapDataset): ...@@ -39,6 +39,7 @@ class MuMapPatchDataset(MuMapDataset):
self.patch_size_z = patch_size_z self.patch_size_z = patch_size_z
self.patch_offset = patch_offset self.patch_offset = patch_offset
self.shuffle = shuffle self.shuffle = shuffle
super().pre_load_images()
self.patches = [] self.patches = []
self.generate_patches() self.generate_patches()
...@@ -85,8 +86,7 @@ class MuMapPatchDataset(MuMapDataset): ...@@ -85,8 +86,7 @@ class MuMapPatchDataset(MuMapDataset):
ps = self.patch_size ps = self.patch_size
ps_z = self.patch_size_z ps_z = self.patch_size_z
recon = self.reconstructions[_id] recon, mu_map = super().__getitem__(index)
mu_map = self.mu_maps[_id]
recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0) recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0) mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
...@@ -94,8 +94,6 @@ class MuMapPatchDataset(MuMapDataset): ...@@ -94,8 +94,6 @@ class MuMapPatchDataset(MuMapDataset):
recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps] recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps]
mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps] mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps]
recon, mu_map = self.transform_augmentation(recon, mu_map)
return recon, mu_map return recon, mu_map
def __len__(self): def __len__(self):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment