import math
import random
import numpy as np
import torch
from mu_map.dataset.default import MuMapDataset


class MuMapPatchDataset(MuMapDataset):
    """
    A wrapper around MuMapDataset that computes patches for each reconstruction-μ-map pair.

    :param dataset_dir: the directory containing the dataset - passed on to MuMapDataset
    :param patches_per_image: the number of patches to randomly generate for each image
    :param patch_size: the size of patches in x- and y-direction
    :param patch_size_z: the size of patches in z-direction; a separate parameter because
                         images are typically shorter in this direction
    :param patch_offset: offset of generated patches from the image borders; this space
                         does not appear in patches because it is often empty
    :param shuffle: shuffle the patches so that patches of different image pairs are mixed
    :param **kwargs: remaining parameters passed to MuMapDataset
    """

    def __init__(
        self,
        dataset_dir: str,
        patches_per_image: int = 100,
        patch_size: int = 32,
        patch_size_z: int = 32,
        patch_offset: int = 20,
        shuffle: bool = True,
        **kwargs,
    ):
        super().__init__(dataset_dir, **kwargs)

        self.patches_per_image = patches_per_image
        self.patch_size = patch_size
        self.patch_size_z = patch_size_z
        self.patch_offset = patch_offset
        self.shuffle = shuffle

        super().pre_load_images()

        self.patches = []
        self.generate_patches()

    def generate_patches(self):
        """
        Pre-compute patches for each image.
        """
        for _id in self.reconstructions:
            recon = self.reconstructions[_id].squeeze()
            mu_map = self.mu_maps[_id].squeeze()
            assert (
                recon.shape[0] == mu_map.shape[0]
            ), f"Reconstruction and μ-map of {_id} are not aligned for the patch dataset"

            # sometimes the μ-maps have fewer than patch_size_z slices; in that case the
            # z-axis is padded to the patch size, which leaves only a single option for z
            z_range = (0, max(recon.shape[0] - self.patch_size_z, 0))
            # e.g. for a 128x128 slice with patch_size=32 and patch_offset=20,
            # y and x are drawn from [20, 76]
            y_range = (
                self.patch_offset,
                recon.shape[1] - self.patch_size - self.patch_offset,
            )
            x_range = (
                self.patch_offset,
                recon.shape[2] - self.patch_size - self.patch_offset,
            )

            # compute padding for the z-axis
            padding = [0, 0, 0, 0, 0, 0, 0, 0]
            if recon.shape[0] < self.patch_size_z:
                diff = self.patch_size_z - recon.shape[0]
                padding[4] = math.ceil(diff / 2)
                padding[5] = math.floor(diff / 2)

            for _ in range(self.patches_per_image):
                z = random.randint(*z_range)
                y = random.randint(*y_range)
                x = random.randint(*x_range)
                self.patches.append((_id, z, y, x, padding))

        if self.shuffle:
            random.shuffle(self.patches)
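
    # The padding stored with each patch is consumed by torch.nn.functional.pad in
    # __getitem__ below. pad expects widths in reverse axis order, two values per axis:
    # for the (channel, z, y, x) volumes used here, entries 0-1 pad x, 2-3 pad y, 4-5
    # pad z and 6-7 pad the channel axis, which is why only indices 4 and 5 are filled
    # when a volume has fewer slices than patch_size_z.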

    def __getitem__(self, index: int):
        _id, z, y, x, padding = self.patches[index]
        ps = self.patch_size
        ps_z = self.patch_size_z

        recon, mu_map = super().__getitem__(index)

        recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
        mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)

        recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps]
        mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps]

        return recon, mu_map

    def __len__(self):
        return len(self.patches)
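

# Usage sketch: how the patch dataset could feed a PyTorch training loop. The function
# name, batch size and patches_per_image below are illustrative placeholders rather
# than values prescribed by the project.
def _example_training_loader(dataset_dir: str = "data/initial/", batch_size: int = 16):
    from torch.utils.data import DataLoader

    # patches_per_image=50 is an arbitrary example value
    dataset = MuMapPatchDataset(dataset_dir, patches_per_image=50)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # each batch is a pair of tensors of shape
    # (batch_size, 1, patch_size_z, patch_size, patch_size)
    recon_patches, mu_map_patches = next(iter(loader))
    return recon_patches, mu_map_patches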


if __name__ == "__main__":
    import argparse

    import cv2 as cv

    from mu_map.util import to_grayscale, grayscale_to_rgb

    # derive the CLI parameters from the constructor's annotations and defaults
    param_keys = list(MuMapPatchDataset.__init__.__annotations__.keys())[1:]
    param_defaults = MuMapPatchDataset.__init__.__defaults__
    param_help = [
        "number of patches for each image",
        "patch size in x- and y-direction",
        "patch size in z-direction",
        "offset to ignore image borders",
        "shuffle the dataset",
    ]

    parser = argparse.ArgumentParser(
        description="Visualize the patches in a MuMapPatchDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/initial/",
        help="the directory of the dataset",
    )
    for key, _default, _help in zip(param_keys, param_defaults, param_help):
        parser.add_argument(
            f"--{key}", type=type(_default), default=_default, help=_help
        )
    args = parser.parse_args()
    print(args)

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    dataset = MuMapPatchDataset(
        args.dataset_dir,
        patches_per_image=args.patches_per_image,
        patch_size=args.patch_size,
        patch_size_z=args.patch_size_z,
        patch_offset=args.patch_offset,
        shuffle=args.shuffle,
    )
    print(f"Patches in the dataset: {len(dataset)}")

    def create_image(recon, mu_map, recon_orig, patch, _slice):
        s = dataset.patch_size
        _id, _, y, x, padding = patch

        # mark the patch location on the original reconstruction
        _recon_orig = recon_orig[_slice]
        _recon_orig = to_grayscale(_recon_orig)
        _recon_orig = grayscale_to_rgb(_recon_orig)
        _recon_orig = cv.rectangle(
            _recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1
        )
        _recon_orig = cv.resize(_recon_orig, (512, 512), interpolation=cv.INTER_AREA)

        _recon = recon[_slice]
        _recon = to_grayscale(_recon)
        _recon = cv.resize(_recon, (512, 512), interpolation=cv.INTER_AREA)
        _recon = grayscale_to_rgb(_recon)

        _mu_map = mu_map[_slice]
        _mu_map = to_grayscale(_mu_map)
        _mu_map = cv.resize(_mu_map, (512, 512), interpolation=cv.INTER_AREA)
        _mu_map = grayscale_to_rgb(_mu_map)

        space = np.full((512, 10, 3), 239, np.uint8)
        return np.hstack((_recon, space, _mu_map, space, _recon_orig))

    for i in range(len(dataset)):
        _i = 0

        patch = dataset.patches[i]
        _id, z, y, x, padding = patch
        print(
            f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[4]}, {padding[5]}]"
        )

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        recon_orig = dataset.reconstructions[_id]
        recon_orig = torch.nn.functional.pad(
            recon_orig, padding, mode="constant", value=0
        )
        recon_orig = recon_orig.squeeze().numpy()

        cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
        key = cv.waitKey(100)

        while True:
            _i = (_i + 1) % recon.shape[0]

            cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
            key = cv.waitKey(100)

            if key == ord("n"):
                break
            elif key == ord("q"):
                exit(0)