diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py index 42a975aa23ea3d0e7374eed652494e29fced18c9..d007d98b3906f5439352e8086444ac6b23694733 100644 --- a/mu_map/dataset/patches.py +++ b/mu_map/dataset/patches.py @@ -37,6 +37,7 @@ class MuMapPatchDataset(MuMapDataset): self.patches_per_image = patches_per_image self.patch_size = patch_size self.patch_size_z = patch_size_z + self.patch_offset = patch_offset self.shuffle = shuffle self.patches = [] @@ -57,8 +58,14 @@ class MuMapPatchDataset(MuMapDataset): z_range = (0, max(recon.shape[0] - self.patch_size_z, 0)) # sometimes the mu_maps have fewer than 32 slices # in this case the z-axis will be padded to the patch size, but this means we only have a single option for z - y_range = (20, recon.shape[1] - self.patch_size - 20) - x_range = (20, recon.shape[2] - self.patch_size - 20) + y_range = ( + self.patch_offset, + recon.shape[1] - self.patch_size - self.patch_offset, + ) + x_range = ( + self.patch_offset, + recon.shape[2] - self.patch_size - self.patch_offset, + ) # compute padding for z axis padding = [0, 0, 0, 0, 0, 0, 0, 0] @@ -103,27 +110,44 @@ if __name__ == "__main__": from mu_map.util import to_grayscale, grayscale_to_rgb param_keys = list(MuMapPatchDataset.__init__.__annotations__.keys())[1:] - param_defaults = MuMapPatchDataset.__init__.__defaults__[1:] + param_defaults = MuMapPatchDataset.__init__.__defaults__ + param_help = [ + "number of patches for each image", + "patch size in x- and y-direction", + "patch size in z-direction", + "offset to ignore image borders", + "shuffle the dataset", + ] parser = argparse.ArgumentParser( - help="Visualize the patches in a MuMapPatchDataset", + description="Visualize the patches in a MuMapPatchDataset", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - "dataset_dir", + "--dataset_dir", type=str, - defaul="data/initial/", + default="data/initial/", help="the directory of the dataset", ) - for key, _default in zip(param_keys, param_defaults): - parser.add_argument(f"--{key}", type=type(_default), default=_default) + for key, _default, _help in zip(param_keys, param_defaults, param_help): + parser.add_argument( + f"--{key}", type=type(_default), default=_default, help=_help + ) args = parser.parse_args() + print(args) wname = "Dataset" cv.namedWindow(wname, cv.WINDOW_NORMAL) cv.resizeWindow(wname, 1600, 900) - dataset = MuMapPatchDataset("data/initial/", patches_per_image=1) + dataset = MuMapPatchDataset( + "data/initial/", + patches_per_image=args.patches_per_image, + patch_size=args.patch_size, + patch_size_z=args.patch_size_z, + patch_offset=args.patch_offset, + shuffle=args.shuffle, + ) print(f"Images (Patches) in the dataset {len(dataset)}")