Skip to content
Snippets Groups Projects 6.96 KiB
import math
import random

import numpy as np
import torch

from mu_map.dataset.default import MuMapDataset

class MuMapPatchDataset(MuMapDataset):
    """
    A wrapper around the MuMapDataset that computes patches for each reconstruction-μ-map pair.

    :param dataset_dir: the directory containing the dataset - is passed to MuMapDataset
    :param patches_per_image: the number of patches to randomly generate for each image
    :param patch_size: the size of patches in x- and y-direction
    :param patch_size_z: the size of patches in z-direction --- it is a separate parameter because
                         images are typically shorter in this direction
    :param patch_offset: offset of generated patches to the border of images --- this space will
                         then not appear in patches because it is often empty
    :param shuffle: shuffle the patches so that patches of image pairs are mixed
    :param **kwargs: remaining parameters passed to MuMapDataset
    """

    # NOTE: the `__main__` CLI of this module builds its options from the annotations and
    # defaults of this signature, so keep them in sync with its help texts and do not annotate
    # `self`, `**kwargs` or the return value.
    def __init__(
        self,
        dataset_dir: str,
        patches_per_image: int = 100,
        patch_size: int = 32,
        patch_size_z: int = 32,
        patch_offset: int = 20,
        shuffle: bool = True,
        **kwargs,
    ):
        super().__init__(dataset_dir, **kwargs)

        self.patches_per_image = patches_per_image
        self.patch_size = patch_size
        self.patch_size_z = patch_size_z
        self.patch_offset = patch_offset
        self.shuffle = shuffle

        # one (id, z, y, x, padding) tuple per patch
        self.patches = []
        # pre-compute the patches so that the dataset is usable right after construction
        self.generate_patches()

    def generate_patches(self):
        """
        Pre-compute patches for each image.

        For every reconstruction/μ-map pair, `patches_per_image` random patch locations
        are drawn and stored in `self.patches` as `(id, z, y, x, padding)` tuples, where
        `(z, y, x)` is the start of the patch and `padding` is the argument later passed to
        `torch.nn.functional.pad`. In y- and x-direction, patches keep a distance of
        `patch_offset` to the image border. If an image has fewer slices than
        `patch_size_z`, it is zero-padded in z-direction and z always starts at 0.
        If `shuffle` is set, the patches of all images are shuffled afterwards.

        Previously generated patches are discarded, so calling this again draws new
        patch locations instead of appending duplicates.

        :raises ValueError: if an image is too small in y or x for `patch_size` and
                            `patch_offset`
        """
        self.patches = []

        for _id in self.reconstructions:
            recon = self.reconstructions[_id].squeeze()
            mu_map = self.mu_maps[_id].squeeze()

            assert (
                recon.shape[0] == mu_map.shape[0]
            ), f"Reconstruction and MuMap were not aligned for patch dataset (id {_id})"

            # sometimes the mu_maps have fewer slices than patch_size_z
            # in this case the z-axis will be padded to the patch size, but this means we only
            # have a single option for z
            z_range = (0, max(recon.shape[0] - self.patch_size_z, 0))
            # random.randint is inclusive, so y + patch_size <= shape[1] - patch_offset
            y_range = (
                self.patch_offset,
                recon.shape[1] - self.patch_size - self.patch_offset,
            )
            x_range = (
                self.patch_offset,
                recon.shape[2] - self.patch_size - self.patch_offset,
            )
            if y_range[0] > y_range[1] or x_range[0] > x_range[1]:
                raise ValueError(
                    f"Image {_id} with shape {tuple(recon.shape)} is too small for patches "
                    f"of size {self.patch_size} with an offset of {self.patch_offset}"
                )

            # compute padding for z axis
            # torch.nn.functional.pad takes (before, after) pairs starting at the LAST dimension,
            # i.e. [x, x, y, y, z, z, lead, lead] for the (lead, z, y, x) tensors cropped in
            # __getitem__, so entries 4 and 5 pad the z-axis
            padding = [0, 0, 0, 0, 0, 0, 0, 0]
            if recon.shape[0] < self.patch_size_z:
                diff = self.patch_size_z - recon.shape[0]
                padding[4] = math.ceil(diff / 2)
                padding[5] = math.floor(diff / 2)

            for _ in range(self.patches_per_image):
                z = random.randint(*z_range)
                y = random.randint(*y_range)
                x = random.randint(*x_range)
                self.patches.append((_id, z, y, x, padding))

        if self.shuffle:
            # mix the patches of different image pairs
            random.shuffle(self.patches)

    def __getitem__(self, index: int):
        """
        Get a patch as a pair of reconstruction and μ-map.

        The patch is cut from the image pair it was generated for (identified by the id
        stored with the patch) after applying the z-padding computed in `generate_patches`.

        :param index: index into `self.patches`
        :return: the reconstruction patch and the μ-map patch, each sized
                 `patch_size_z x patch_size x patch_size` in the last three dimensions
        """
        _id, z, y, x, padding = self.patches[index]
        ps = self.patch_size
        ps_z = self.patch_size_z

        # Look up the images by the patch's id: `index` is a patch index, not an image index,
        # and the patch location was computed for exactly this image.
        # NOTE(review): this bypasses MuMapDataset.__getitem__, so any per-sample transform
        # it may apply is not applied here — confirm against the base class.
        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
        mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)

        recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps]
        mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps]

        return recon, mu_map

    def __len__(self):
        """Return the number of pre-computed patches."""
        return len(self.patches)

if __name__ == "__main__":
    import argparse

    import cv2 as cv

    from mu_map.util import to_grayscale, grayscale_to_rgb

    param_keys = list(MuMapPatchDataset.__init__.__annotations__.keys())[1:]
    param_defaults = MuMapPatchDataset.__init__.__defaults__
    param_help = [
        "number of patches for each image",
        "patch size in x- and y-direction",
        "patch size in z-direction",
        "offset to ignore image borders",
        "shuffle the dataset",

    parser = argparse.ArgumentParser(
        description="Visualize the patches in a MuMapPatchDataset",
        help="the directory of the dataset",
    for key, _default, _help in zip(param_keys, param_defaults, param_help):
            f"--{key}", type=type(_default), default=_default, help=_help
    args = parser.parse_args()

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    dataset = MuMapPatchDataset(

    print(f"Images (Patches) in the dataset {len(dataset)}")

    def create_image(recon, mu_map, recon_orig, patch, _slice):
        s = dataset.patch_size
        _id, _, y, x, padding = patch

        _recon_orig = recon_orig[_slice]
        _recon_orig = to_grayscale(_recon_orig)
        _recon_orig = grayscale_to_rgb(_recon_orig)
        _recon_orig = cv.rectangle(
            _recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1
        _recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA)

        _recon = recon[_slice]
        _recon = to_grayscale(_recon)
        _recon = cv.resize(_recon, (512, 512), cv.INTER_AREA)
        _recon = grayscale_to_rgb(_recon)

        _mu_map = mu_map[_slice]
        _mu_map = to_grayscale(_mu_map)
        _mu_map = cv.resize(_mu_map, (512, 512), cv.INTER_AREA)
        _mu_map = grayscale_to_rgb(_mu_map)

        space = np.full((512, 10, 3), 239, np.uint8)
        return np.hstack((_recon, space, _mu_map, space, _recon_orig))

    for i in range(len(dataset)):
        _i = 0

        patch = dataset.patches[i]
        _id, z, y, x, padding = patch
            f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[5], padding[6]}]"
        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        recon_orig = dataset.reconstructions[_id]
        recon_orig = torch.nn.functional.pad(
            recon_orig, padding, mode="constant", value=0
        recon_orig = recon_orig.squeeze().numpy()

        cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
        key = cv.waitKey(100)

        while True:
            _i = (_i + 1) % recon.shape[0]

            cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
            key = cv.waitKey(100)

            if key == ord("n"):
            elif key == ord("q"):