Skip to content
Snippets Groups Projects
patches.py 7.44 KiB
Newer Older
  • Learn to ignore specific revisions
    import math
    import random

    import numpy as np
    import torch

    from mu_map.dataset.default import MuMapDataset
    
    class MuMapPatchDataset(MuMapDataset):
        """
        A wrapper around the MuMapDataset that computes patches for each reconstruction-μ-map pair.

        Parameters
        ----------
        dataset_dir: str
            the directory containing the dataset - is passed to MuMapDataset
        patches_per_image: int
            the amount of patches to randomly generate for each image
        patch_size: int
            the size of patches in x- and y-direction
        patch_size_z: int
            the size of patches in z-direction --- it is a separate parameter because
            images are typically shorter in this direction
        patch_offset: int
            offset of generated patches to the border of images --- this space will
            then not appear in patches because it is often empty
        shuffle: bool
            shuffle the patches so that patches of image pairs are mixed
        **kwargs:
            remaining parameters passed to MuMapDataset
        """

        def __init__(
            self,
            dataset_dir: str,
            patches_per_image: int = 100,
            patch_size: int = 32,
            patch_size_z: int = 32,
            patch_offset: int = 20,
            shuffle: bool = True,
            **kwargs,
        ):
            super().__init__(dataset_dir, **kwargs)

            # Remembered so that `copy` can re-create the dataset with the same base options.
            self.kwargs = kwargs

            self.patches_per_image = patches_per_image
            self.patch_size = patch_size
            self.patch_size_z = patch_size_z
            self.patch_offset = patch_offset
            self.shuffle = shuffle

            # Patch locations are derived from the image shapes, so the images are
            # loaded up front.
            super().pre_load_images()

            # Each entry is a tuple (id, z, y, x, padding) -- see `generate_patches`.
            self.patches = []
            self.generate_patches()

        def copy(self, split_name: str):
            """
            Create a new patch dataset with the same configuration for another split.

            Parameters
            ----------
            split_name: str
                the split of the new dataset, forwarded to MuMapDataset as `split_name`

            Returns
            -------
            MuMapPatchDataset
                a new dataset; its patch locations are drawn anew
            """
            kwargs = self.kwargs.copy()
            kwargs["split_name"] = split_name

            # NOTE(review): `self.dir` is presumably set by MuMapDataset.__init__ -- confirm.
            return MuMapPatchDataset(
                dataset_dir=self.dir,
                patches_per_image=self.patches_per_image,
                patch_size=self.patch_size,
                patch_size_z=self.patch_size_z,
                patch_offset=self.patch_offset,
                shuffle=self.shuffle,
                **kwargs,
            )

        def generate_patches(self):
            """
            Pre-compute patches for each image.

            For every reconstruction/μ-map pair, `patches_per_image` random patch
            corners (z, y, x) are drawn and appended to `self.patches` together with
            the padding needed for volumes that have fewer than `patch_size_z` slices.
            If `shuffle` is set, the patch list is shuffled afterwards so that patches
            of different images are mixed.

            Raises
            ------
            ValueError
                if an image is too small in y or x for `patch_size` and `patch_offset`
            """
            for _id in self.reconstructions:
                recon = self.reconstructions[_id].squeeze()
                mu_map = self.mu_maps[_id].squeeze()

                assert (
                    recon.shape[0] == mu_map.shape[0]
                ), "Reconstruction and MuMap were not aligned for patch dataset"

                # Sometimes the volumes have fewer slices than `patch_size_z`. In that
                # case the z-axis is padded up to the patch size (see below), which
                # leaves a single valid z start position: 0.
                z_range = (0, max(recon.shape[0] - self.patch_size_z, 0))
                y_range = (
                    self.patch_offset,
                    recon.shape[1] - self.patch_size - self.patch_offset,
                )
                x_range = (
                    self.patch_offset,
                    recon.shape[2] - self.patch_size - self.patch_offset,
                )
                if y_range[1] < y_range[0] or x_range[1] < x_range[0]:
                    # random.randint would otherwise fail with an obscure "empty range" error
                    raise ValueError(
                        f"Image {_id} with shape {tuple(recon.shape)} is too small for "
                        f"patches of size {self.patch_size} with offset {self.patch_offset}"
                    )

                # Padding in the format of torch.nn.functional.pad, i.e. (before, after)
                # pairs starting at the last dimension: x, y, z, then the leading one.
                # Only z (indices 4 and 5) is ever padded.
                padding = [0, 0, 0, 0, 0, 0, 0, 0]
                if recon.shape[0] < self.patch_size_z:
                    diff = self.patch_size_z - recon.shape[0]
                    padding[4] = math.ceil(diff / 2)
                    padding[5] = math.floor(diff / 2)

                for _ in range(self.patches_per_image):
                    z = random.randint(*z_range)
                    y = random.randint(*y_range)
                    x = random.randint(*x_range)
                    self.patches.append((_id, z, y, x, padding))

            if self.shuffle:
                random.shuffle(self.patches)

        def __getitem__(self, index: int):
            """
            Return the reconstruction and μ-map patch at `index`.

            The full volumes are loaded by id, padded in z if required and then
            cropped to a `patch_size_z` x `patch_size` x `patch_size` patch in the
            last three dimensions (z, y, x).
            """
            _id, z, y, x, padding = self.patches[index]

            ps = self.patch_size
            ps_z = self.patch_size_z

            recon, mu_map = super().get_item_by_id(_id)

            recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
            mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)

            recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps]
            mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps]

            return recon, mu_map

        def __len__(self):
            """Return the number of pre-computed patches."""
            return len(self.patches)
    
        import cv2 as cv
    
        from mu_map.util import to_grayscale, grayscale_to_rgb
    
    
        param_keys = list(MuMapPatchDataset.__init__.__annotations__.keys())[1:]
    
        param_defaults = MuMapPatchDataset.__init__.__defaults__
        param_help = [
            "number of patches for each image",
            "patch size in x- and y-direction",
            "patch size in z-direction",
            "offset to ignore image borders",
            "shuffle the dataset",
        ]
    
    
        parser = argparse.ArgumentParser(
    
            description="Visualize the patches in a MuMapPatchDataset",
    
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
    
            "--dataset_dir",
    
            default="data/second/",
    
            help="the directory of the dataset",
        )
    
        for key, _default, _help in zip(param_keys, param_defaults, param_help):
            parser.add_argument(
                f"--{key}", type=type(_default), default=_default, help=_help
            )
    
        args = parser.parse_args()
    
        wname = "Dataset"
        cv.namedWindow(wname, cv.WINDOW_NORMAL)
        cv.resizeWindow(wname, 1600, 900)
    
    
        dataset = MuMapPatchDataset(
            "data/initial/",
            patches_per_image=args.patches_per_image,
            patch_size=args.patch_size,
            patch_size_z=args.patch_size_z,
            patch_offset=args.patch_offset,
            shuffle=args.shuffle,
        )
    
    
        print(f"Images (Patches) in the dataset {len(dataset)}")
    
        def create_image(recon, mu_map, recon_orig, patch, _slice):
    
    Tamino Huxohl's avatar
    Tamino Huxohl committed
            s = dataset.patch_size
    
            _id, _, y, x, padding = patch
    
    
            _recon_orig = recon_orig[_slice]
    
            _recon_orig = to_grayscale(_recon_orig)
            _recon_orig = grayscale_to_rgb(_recon_orig)
    
            _recon_orig = cv.rectangle(
                _recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1
            )
    
            _recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA)
    
            _recon = recon[_slice]
            _recon = to_grayscale(_recon)
            _recon = cv.resize(_recon, (512, 512), cv.INTER_AREA)
            _recon = grayscale_to_rgb(_recon)
    
            _mu_map = mu_map[_slice]
            _mu_map = to_grayscale(_mu_map)
            _mu_map = cv.resize(_mu_map, (512, 512), cv.INTER_AREA)
            _mu_map = grayscale_to_rgb(_mu_map)
    
    
    Tamino Huxohl's avatar
    Tamino Huxohl committed
            space = np.full((512, 10, 3), 239, np.uint8)
    
            return np.hstack((_recon, space, _mu_map, space, _recon_orig))
    
        for i in range(len(dataset)):
            _i = 0
    
            patch = dataset.patches[i]
            _id, z, y, x, padding = patch
            print(
    
                f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[5], padding[6]}]"
    
            )
            recon, mu_map = dataset[i]
    
            recon = recon.squeeze().numpy()
            mu_map = mu_map.squeeze().numpy()
    
    
            recon_orig = dataset.reconstructions[_id]
    
            recon_orig = torch.nn.functional.pad(
                recon_orig, padding, mode="constant", value=0
            )
    
            recon_orig = recon_orig.squeeze().numpy()
    
    Tamino Huxohl's avatar
    Tamino Huxohl committed
            cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
    
            key = cv.waitKey(100)
    
            while True:
                _i = (_i + 1) % recon.shape[0]
    
    
    Tamino Huxohl's avatar
    Tamino Huxohl committed
                cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
                key = cv.waitKey(100)
    
                if key == ord("n"):
                    break
                elif key == ord("q"):
                    exit(0)