diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py
index 9aade9409018913fedefb1093bc6e4f383ee8497..42a975aa23ea3d0e7374eed652494e29fced18c9 100644
--- a/mu_map/dataset/patches.py
+++ b/mu_map/dataset/patches.py
@@ -8,11 +8,27 @@ from mu_map.dataset.default import MuMapDataset
 
 
 class MuMapPatchDataset(MuMapDataset):
+    """
+    A wrapper around the MuMapDataset that computes patches for each reconstruction-μ-map pair.
+
+    :param dataset_dir: the directory containing the dataset - is passed to MuMapDataset
+    :param patches_per_image: the amount of patches to randomly generate for each image
+    :param patch_size: the size of patches in x- and y-direction
+    :param patch_size_z: the size of patches in z-direction --- it is a separate parameter because
+                         images are typically shorter in this direction
+    :param patch_offset: offset of generated patches to the border of images --- this space will
+                         then not appear in patches because it is often empty
+    :param shuffle: shuffle the patches so that patches of image pairs are mixed
+    :param **kwargs: remaining parameters passed to MuMapDataset
+    """
+
     def __init__(
         self,
         dataset_dir: str,
         patches_per_image: int = 100,
         patch_size: int = 32,
+        patch_size_z: int = 32,
+        patch_offset: int = 20,
         shuffle: bool = True,
         **kwargs,
     ):
@@ -20,12 +36,17 @@ class MuMapPatchDataset(MuMapDataset):
 
         self.patches_per_image = patches_per_image
         self.patch_size = patch_size
+        self.patch_size_z = patch_size_z
+        self.patch_offset = patch_offset
         self.shuffle = shuffle
 
         self.patches = []
         self.generate_patches()
 
     def generate_patches(self):
+        """
+        Pre-compute patches for each image.
+        """
         for _id in self.reconstructions:
             recon = self.reconstructions[_id].squeeze()
             mu_map = self.mu_maps[_id].squeeze()
@@ -34,15 +55,16 @@ class MuMapPatchDataset(MuMapDataset):
                 recon.shape[0] == mu_map.shape[0]
             ), f"Reconstruction and MuMap were not aligned for patch dataset"
 
-            z_range = (0, max(recon.shape[0] - self.patch_size, 0))
+            z_range = (0, max(recon.shape[0] - self.patch_size_z, 0))
             # sometimes the mu_maps have fewer than 32 slices
             # in this case the z-axis will be padded to the patch size, but this means we only have a single option for z
-            y_range = (20, recon.shape[1] - self.patch_size - 20)
-            x_range = (20, recon.shape[2] - self.patch_size - 20)
+            y_range = (self.patch_offset, recon.shape[1] - self.patch_size - self.patch_offset)
+            x_range = (self.patch_offset, recon.shape[2] - self.patch_size - self.patch_offset)
 
+            # compute padding for z axis
             padding = [0, 0, 0, 0, 0, 0, 0, 0]
-            if recon.shape[0] < self.patch_size:
-                diff = self.patch_size - recon.shape[0]
+            if recon.shape[0] < self.patch_size_z:
+                diff = self.patch_size_z - recon.shape[0]
                 padding[4] = math.ceil(diff / 2)
                 padding[5] = math.floor(diff / 2)
 
@@ -54,7 +76,8 @@ class MuMapPatchDataset(MuMapDataset):
 
     def __getitem__(self, index: int):
         _id, z, y, x, padding = self.patches[index]
-        s = self.patch_size
+        ps = self.patch_size
+        ps_z = self.patch_size_z
 
         recon = self.reconstructions[_id]
         mu_map = self.mu_maps[_id]
@@ -62,8 +85,8 @@ class MuMapPatchDataset(MuMapDataset):
         recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
         mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
 
-        recon = recon[:, z : z + s, y : y + s, x : x + s]
-        mu_map = mu_map[:, z : z + s, y : y + s, x : x + s]
+        recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps]
+        mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps]
 
         recon, mu_map = self.transform_augmentation(recon, mu_map)
 
@@ -74,10 +97,34 @@ class MuMapPatchDataset(MuMapDataset):
 
 
 if __name__ == "__main__":
+    import argparse
+
     import cv2 as cv
 
     from mu_map.util import to_grayscale, grayscale_to_rgb
 
+    # expose the keyword parameters of MuMapPatchDataset.__init__ as options;
+    # dataset_dir has no default, so only the annotations are sliced ---
+    # __defaults__ already starts at patches_per_image
+    param_keys = list(MuMapPatchDataset.__init__.__annotations__.keys())[1:]
+    param_defaults = MuMapPatchDataset.__init__.__defaults__
+
+    parser = argparse.ArgumentParser(
+        description="Visualize the patches in a MuMapPatchDataset",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "dataset_dir",
+        type=str,
+        # a positional only falls back to its default when it is optional
+        nargs="?",
+        default="data/initial/",
+        help="the directory of the dataset",
+    )
+    for key, _default in zip(param_keys, param_defaults):
+        parser.add_argument(f"--{key}", type=type(_default), default=_default)
+    args = parser.parse_args()
+
     wname = "Dataset"
     cv.namedWindow(wname, cv.WINDOW_NORMAL)
     cv.resizeWindow(wname, 1600, 900)