Skip to content
Snippets Groups Projects
Commit c4ac5562 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

update parameters to patches dataset

parent e8f15d80
No related branches found
No related tags found
No related merge requests found
......@@ -8,11 +8,27 @@ from mu_map.dataset.default import MuMapDataset
class MuMapPatchDataset(MuMapDataset):
"""
A wrapper around the MuMapDataset that computes patches for each reconstruction-μ-map pair.
:param dataset_dir: the directory containing the dataset - is passed to MuMapDataset
:param patches_per_image: the amount of patches to randomly generate for each image
:param patch_size: the size of patches in x- and y-direction
:param patch_size_z: the size of patches in z-direction --- it is a separate parameter because
images are typically shorter in this direction
:param patch_offset: offset of generated patches to the border of images --- this space will
then not appear in patches because it is often empty
:param shuffle: shuffle the patches so that patches of image pairs are mixed
:param **kwargs: remaining parameters passed to MuMapDataset
"""
def __init__(
self,
dataset_dir: str,
patches_per_image: int = 100,
patch_size: int = 32,
patch_size_z: int = 32,
patch_offset: int = 20,
shuffle: bool = True,
**kwargs,
):
......@@ -20,12 +36,16 @@ class MuMapPatchDataset(MuMapDataset):
self.patches_per_image = patches_per_image
self.patch_size = patch_size
self.patch_size_z = patch_size_z
self.shuffle = shuffle
self.patches = []
self.generate_patches()
def generate_patches(self):
"""
Pre-compute patches for each image.
"""
for _id in self.reconstructions:
recon = self.reconstructions[_id].squeeze()
mu_map = self.mu_maps[_id].squeeze()
......@@ -34,15 +54,16 @@ class MuMapPatchDataset(MuMapDataset):
recon.shape[0] == mu_map.shape[0]
), f"Reconstruction and MuMap were not aligned for patch dataset"
z_range = (0, max(recon.shape[0] - self.patch_size, 0))
z_range = (0, max(recon.shape[0] - self.patch_size_z, 0))
# sometimes the mu_maps have fewer than 32 slices
# in this case the z-axis will be padded to the patch size, but this means we only have a single option for z
y_range = (20, recon.shape[1] - self.patch_size - 20)
x_range = (20, recon.shape[2] - self.patch_size - 20)
# compute padding for z axis
padding = [0, 0, 0, 0, 0, 0, 0, 0]
if recon.shape[0] < self.patch_size:
diff = self.patch_size - recon.shape[0]
if recon.shape[0] < self.patch_size_z:
diff = self.patch_size_z - recon.shape[0]
padding[4] = math.ceil(diff / 2)
padding[5] = math.floor(diff / 2)
......@@ -54,7 +75,8 @@ class MuMapPatchDataset(MuMapDataset):
def __getitem__(self, index: int):
_id, z, y, x, padding = self.patches[index]
s = self.patch_size
ps = self.patch_size
ps_z = self.patch_size_z
recon = self.reconstructions[_id]
mu_map = self.mu_maps[_id]
......@@ -62,8 +84,8 @@ class MuMapPatchDataset(MuMapDataset):
recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
recon = recon[:, z : z + s, y : y + s, x : x + s]
mu_map = mu_map[:, z : z + s, y : y + s, x : x + s]
recon = recon[:, z : z + ps_z, y : y + ps, x : x + ps]
mu_map = mu_map[:, z : z + ps_z, y : y + ps, x : x + ps]
recon, mu_map = self.transform_augmentation(recon, mu_map)
......@@ -74,10 +96,29 @@ class MuMapPatchDataset(MuMapDataset):
if __name__ == "__main__":
import argparse
import cv2 as cv
from mu_map.util import to_grayscale, grayscale_to_rgb
param_keys = list(MuMapPatchDataset.__init__.__annotations__.keys())[1:]
param_defaults = MuMapPatchDataset.__init__.__defaults__[1:]
parser = argparse.ArgumentParser(
help="Visualize the patches in a MuMapPatchDataset",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"dataset_dir",
type=str,
defaul="data/initial/",
help="the directory of the dataset",
)
for key, _default in zip(param_keys, param_defaults):
parser.add_argument(f"--{key}", type=type(_default), default=_default)
args = parser.parse_args()
wname = "Dataset"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment