diff --git a/mu_map/data/patch_dataset.py b/mu_map/data/patch_dataset.py index 46caba3af8e969e7f5eaa3af7b484064eee6bf11..cf6cb786551095c314e88f45ce40818488847f83 100644 --- a/mu_map/data/patch_dataset.py +++ b/mu_map/data/patch_dataset.py @@ -1,21 +1,123 @@ +import math +import random + from mu_map.data.datasets import MuMapDataset -class MuMapPatchDataset(MuMapDataset): - def __init__(self, dataset_dir, patches_per_image=100, patch_size=32): +class MuMapPatchDataset(MuMapDataset): + def __init__(self, dataset_dir, patches_per_image=100, patch_size=32, shuffle=True): super().__init__(dataset_dir) - self.patches_per_image=patches_per_image - self.patch_size=patch_size + self.patches_per_image = patches_per_image + self.patch_size = patch_size + self.shuffle = shuffle + + self.patches = [] + self.generate_patches() + + def generate_patches(self): + for i, (recon, mu_map) in enumerate(zip(self.reconstructions, self.mu_maps)): + assert ( + recon.shape[0] == mu_map.shape[0] + ), f"Reconstruction and MuMap were not aligned for patch dataset" + + _id = self.table.iloc[i]["id"] + + z_range = (0, max(recon.shape[0] - self.patch_size, 0)) + # sometimes the mu_maps have fewer than 32 slices + # in this case the z-axis will be padded to the patch size, but this means we only have a single option for z + y_range = (0, recon.shape[1] - self.patch_size) + x_range = (0, recon.shape[2] - self.patch_size) - def __getitem___(self, index:int): - return super()[index] + padding = [(0, 0), (0, 0), (0, 0)] + if recon.shape[0] < self.patch_size: + diff = self.patch_size - recon.shape[0] + padding_bef = math.ceil(diff / 2) + padding_aft = math.floor(diff / 2) + padding[0] = (padding_bef, padding_aft) + + for j in range(self.patches_per_image): + z = random.randint(*z_range) + y = random.randint(*y_range) + x = random.randint(*x_range) + self.patches.append(_id, z, y, x) + + def __getitem___(self, index: int): + _id, z, y, x, padding = self.patches[index] + s = self.patches + + recon = self.reconstructions[_id] + mu_map = self.mu_maps[_id] + + recon = np.pad(recon, padding, mode="constant", constant_values=0) + mu_map = np.pad(mu_map, padding, mode="constant", constant_values=0) + + recon = recon[z : z + s, y : y + s, x : x + s] + mu_map = mu_map[z : z + s, y : y + s, x : x + s] + + return recon, mu_map def __len__(self): - return super().__len__() * self.patches_per_image + return len(self.patches) if __name__ == "__main__": - dataset = MuMapPatchDataset("data/initial/") + import cv2 as cv + + from mu_map.util import to_grayscale, grayscale_to_rgb + + wname = "Dataset" + cv.namedWindow(wname, cv.WINDOW_NORMAL) + cv.resizeWindow(wname, 1600, 900) + + dataset = MuMapPatchDataset("data/initial/", patches_per_image=5) + + print(f"Images (Patches) in the dataset {len(dataset)}") + + def create_image(recon, mu_map, recon_orig, patch, _slice): + s = recon.shape[0] + _id, _, y, x, padding = patch + + _recon_orig = np.pad(recon_orig, patch, mode="constant", constant_values=0) + _recon_orig = recon_orig[_slice] + _recon_orig = to_grayscale(_recon_orig) + _recon_orig = grayscale_to_rgb(_recon_orig) + _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), thickness=1) + _recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA) + + _recon = recon[_slice] + _recon = to_grayscale(_recon) + _recon = cv.resize(_recon, (512, 512), cv.INTER_AREA) + _recon = grayscale_to_rgb(_recon) + + _mu_map = mu_map[_slice] + _mu_map = to_grayscale(_mu_map) + _mu_map = cv.resize(_mu_map, (512, 512), cv.INTER_AREA) + _mu_map = grayscale_to_rgb(_mu_map) + + space = np.full((3, 512, 10), 239, np.uint8) + return np.hstack((_recon, space, _mu_map, space, _recon_orig)) + + for i in range(len(dataset)): + _i = 0 + + patch = dataset.patches[i] + _id, z, y, x, padding = patch + print( + "Patch {str(i+1):>len(str(len(dataset)))}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[0][0], padding[0][0]}]" + ) + recon, mu_map = dataset[i] + recon_orig = dataset.reconstructions[_id] + + cv.imshow(combine_images(recon, mu_map, recon_orig, patch, _i)) + key = cv.waitKey(100) + + while True: + _i = (_i + 1) % recon.shape[0] + + cv.imshow(combine_images(recon, mu_map, recon_orig, patch, _i)) - print(f"Images {len(dataset)}") + if key == ord("n"): + break + elif key == ord("q"): + exit(0)