Skip to content
Snippets Groups Projects
Commit eeafcf78 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

fix patch dataset

parent f08f9c84
No related branches found
No related tags found
No related merge requests found
import math import math
import random import random
import numpy as np
from mu_map.data.datasets import MuMapDataset from mu_map.data.datasets import MuMapDataset
...@@ -16,18 +18,19 @@ class MuMapPatchDataset(MuMapDataset): ...@@ -16,18 +18,19 @@ class MuMapPatchDataset(MuMapDataset):
self.generate_patches() self.generate_patches()
def generate_patches(self): def generate_patches(self):
for i, (recon, mu_map) in enumerate(zip(self.reconstructions, self.mu_maps)): for _id in self.reconstructions:
recon = self.reconstructions[_id]
mu_map = self.mu_maps[_id]
assert ( assert (
recon.shape[0] == mu_map.shape[0] recon.shape[0] == mu_map.shape[0]
), f"Reconstruction and MuMap were not aligned for patch dataset" ), f"Reconstruction and MuMap were not aligned for patch dataset"
_id = self.table.iloc[i]["id"]
z_range = (0, max(recon.shape[0] - self.patch_size, 0)) z_range = (0, max(recon.shape[0] - self.patch_size, 0))
# sometimes the mu_maps have fewer than 32 slices # sometimes the mu_maps have fewer than 32 slices
# in this case the z-axis will be padded to the patch size, but this means we only have a single option for z # in this case the z-axis will be padded to the patch size, but this means we only have a single option for z
y_range = (0, recon.shape[1] - self.patch_size) y_range = (20, recon.shape[1] - self.patch_size - 20)
x_range = (0, recon.shape[2] - self.patch_size) x_range = (20, recon.shape[2] - self.patch_size - 20)
padding = [(0, 0), (0, 0), (0, 0)] padding = [(0, 0), (0, 0), (0, 0)]
if recon.shape[0] < self.patch_size: if recon.shape[0] < self.patch_size:
...@@ -40,11 +43,11 @@ class MuMapPatchDataset(MuMapDataset): ...@@ -40,11 +43,11 @@ class MuMapPatchDataset(MuMapDataset):
z = random.randint(*z_range) z = random.randint(*z_range)
y = random.randint(*y_range) y = random.randint(*y_range)
x = random.randint(*x_range) x = random.randint(*x_range)
self.patches.append(_id, z, y, x) self.patches.append((_id, z, y, x, padding))
def __getitem___(self, index: int): def __getitem__(self, index: int):
_id, z, y, x, padding = self.patches[index] _id, z, y, x, padding = self.patches[index]
s = self.patches s = self.patch_size
recon = self.reconstructions[_id] recon = self.reconstructions[_id]
mu_map = self.mu_maps[_id] mu_map = self.mu_maps[_id]
...@@ -70,19 +73,19 @@ if __name__ == "__main__": ...@@ -70,19 +73,19 @@ if __name__ == "__main__":
cv.namedWindow(wname, cv.WINDOW_NORMAL) cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900) cv.resizeWindow(wname, 1600, 900)
dataset = MuMapPatchDataset("data/initial/", patches_per_image=5) dataset = MuMapPatchDataset("data/initial/", patches_per_image=1)
print(f"Images (Patches) in the dataset {len(dataset)}") print(f"Images (Patches) in the dataset {len(dataset)}")
def create_image(recon, mu_map, recon_orig, patch, _slice): def create_image(recon, mu_map, recon_orig, patch, _slice):
s = recon.shape[0] s = dataset.patch_size
_id, _, y, x, padding = patch _id, _, y, x, padding = patch
_recon_orig = np.pad(recon_orig, patch, mode="constant", constant_values=0) _recon_orig = np.pad(recon_orig, padding, mode="constant", constant_values=0)
_recon_orig = recon_orig[_slice] _recon_orig = _recon_orig[_slice]
_recon_orig = to_grayscale(_recon_orig) _recon_orig = to_grayscale(_recon_orig)
_recon_orig = grayscale_to_rgb(_recon_orig) _recon_orig = grayscale_to_rgb(_recon_orig)
_recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), thickness=1) _recon_orig = cv.rectangle(_recon_orig, (x, y), (x + s, y + s), color=(255, 0, 0), thickness=1)
_recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA) _recon_orig = cv.resize(_recon_orig, (512, 512), cv.INTER_AREA)
_recon = recon[_slice] _recon = recon[_slice]
...@@ -95,7 +98,7 @@ if __name__ == "__main__": ...@@ -95,7 +98,7 @@ if __name__ == "__main__":
_mu_map = cv.resize(_mu_map, (512, 512), cv.INTER_AREA) _mu_map = cv.resize(_mu_map, (512, 512), cv.INTER_AREA)
_mu_map = grayscale_to_rgb(_mu_map) _mu_map = grayscale_to_rgb(_mu_map)
space = np.full((3, 512, 10), 239, np.uint8) space = np.full((512, 10, 3), 239, np.uint8)
return np.hstack((_recon, space, _mu_map, space, _recon_orig)) return np.hstack((_recon, space, _mu_map, space, _recon_orig))
for i in range(len(dataset)): for i in range(len(dataset)):
...@@ -104,18 +107,19 @@ if __name__ == "__main__": ...@@ -104,18 +107,19 @@ if __name__ == "__main__":
patch = dataset.patches[i] patch = dataset.patches[i]
_id, z, y, x, padding = patch _id, z, y, x, padding = patch
print( print(
"Patch {str(i+1):>len(str(len(dataset)))}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[0][0], padding[0][0]}]" f"Patch {str(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - Location [{z:02d}, {y:02d}, {x:02d}] - Padding [{padding[0][0], padding[0][0]}]"
) )
recon, mu_map = dataset[i] recon, mu_map = dataset[i]
recon_orig = dataset.reconstructions[_id] recon_orig = dataset.reconstructions[_id]
cv.imshow(combine_images(recon, mu_map, recon_orig, patch, _i)) cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
key = cv.waitKey(100) key = cv.waitKey(100)
while True: while True:
_i = (_i + 1) % recon.shape[0] _i = (_i + 1) % recon.shape[0]
cv.imshow(combine_images(recon, mu_map, recon_orig, patch, _i)) cv.imshow(wname, create_image(recon, mu_map, recon_orig, patch, _i))
key = cv.waitKey(100)
if key == ord("n"): if key == ord("n"):
break break
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment