diff --git a/mu_map/data/patch_dataset.py b/mu_map/data/patch_dataset.py index 6faf07a6a34e9796d4fb0d066e422ab66edd8613..934ce823e6996db784ac424f5c46fe13a0fcbd8d 100644 --- a/mu_map/data/patch_dataset.py +++ b/mu_map/data/patch_dataset.py @@ -2,6 +2,7 @@ import math import random import numpy as np +import torch from mu_map.data.datasets import MuMapDataset @@ -19,8 +20,8 @@ class MuMapPatchDataset(MuMapDataset): def generate_patches(self): for _id in self.reconstructions: - recon = self.reconstructions[_id] - mu_map = self.mu_maps[_id] + recon = self.reconstructions[_id].squeeze() + mu_map = self.mu_maps[_id].squeeze() assert ( recon.shape[0] == mu_map.shape[0] @@ -32,12 +33,11 @@ class MuMapPatchDataset(MuMapDataset): y_range = (20, recon.shape[1] - self.patch_size - 20) x_range = (20, recon.shape[2] - self.patch_size - 20) - padding = [(0, 0), (0, 0), (0, 0)] + padding = [0, 0, 0, 0, 0, 0, 0, 0] if recon.shape[0] < self.patch_size: diff = self.patch_size - recon.shape[0] - padding_bef = math.ceil(diff / 2) - padding_aft = math.floor(diff / 2) - padding[0] = (padding_bef, padding_aft) + padding[4] = math.ceil(diff / 2) + padding[5] = math.floor(diff / 2) for j in range(self.patches_per_image): z = random.randint(*z_range) @@ -52,11 +52,11 @@ class MuMapPatchDataset(MuMapDataset): recon = self.reconstructions[_id] mu_map = self.mu_maps[_id] - recon = np.pad(recon, padding, mode="constant", constant_values=0) - mu_map = np.pad(mu_map, padding, mode="constant", constant_values=0) + recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0) + mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0) - recon = recon[z : z + s, y : y + s, x : x + s] - mu_map = mu_map[z : z + s, y : y + s, x : x + s] + recon = recon[:, z : z + s, y : y + s, x : x + s] + mu_map = mu_map[:, z : z + s, y : y + s, x : x + s] return recon, mu_map @@ -81,8 +81,7 @@ if __name__ == "__main__": s = dataset.patch_size _id, _, y, x, padding = patch - _recon_orig = np.pad(recon_orig, padding, mode="constant", constant_values=0) - _recon_orig = _recon_orig[_slice] + _recon_orig = recon_orig[_slice] _recon_orig = to_grayscale(_recon_orig) _recon_orig = grayscale_to_rgb(_recon_orig) _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1) @@ -107,10 +106,15 @@ if __name__ == "__main__": patch = dataset.patches[i] _id, z, y, x, padding = patch print( - f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[0][0], padding[0][0]}]" + f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[5], padding[6]}]" ) recon, mu_map = dataset[i] + recon = recon.squeeze().numpy() + mu_map = mu_map.squeeze().numpy() + recon_orig = dataset.reconstructions[_id] + recon_orig = torch.nn.functional.pad(recon_orig, padding, mode="constant", value=0) + recon_orig = recon_orig.squeeze().numpy() cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i)) key = cv.waitKey(100)