diff --git a/mu_map/data/mock.py b/mu_map/data/mock.py new file mode 100644 index 0000000000000000000000000000000000000000..2bb2b9ae60b3506232b5a79061fe6ea3c7071fa6 --- /dev/null +++ b/mu_map/data/mock.py @@ -0,0 +1,82 @@ +from mu_map.data.datasets import MuMapDataset + + +class MuMapMockDataset(MuMapDataset): + def __init__(self, dataset_dir: str = "data/initial/", num_images: int = 16): + super().__init__(dataset_dir=dataset_dir) + self.len = num_images + + def __getitem__(self, index: int): + recon, mu_map = super().__getitem__(0) + recon = recon[:, :32, :, :] + mu_map = mu_map[:, :32, :, :] + return recon, mu_map + + def __len__(self): + return self.len + + +if __name__ == "__main__": + import cv2 as cv + import numpy as np + + from mu_map.util import to_grayscale, COLOR_WHITE + + dataset = MuMapMockDataset() + + wname = "Dataset" + cv.namedWindow(wname, cv.WINDOW_NORMAL) + cv.resizeWindow(wname, 1600, 900) + space = np.full((1024, 10), 239, np.uint8) + + timeout = 100 + + def to_display_image(image, _slice): + _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max()) + _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA) + _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}" + _image = cv.putText( + _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3 + ) + return _image + + def combine_images(images, slices): + image_1 = to_display_image(images[0], slices[0]) + image_2 = to_display_image(images[1], slices[1]) + space = np.full((image_1.shape[0], 10), 239, np.uint8) + return np.hstack((image_1, space, image_2)) + + for i in range(len(dataset)): + ir = 0 + im = 0 + + recon, mu_map = dataset[i] + print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r") + + cv.imshow(wname, combine_images((recon, mu_map), (ir, im))) + key = cv.waitKey(timeout) + + running = 0 + while True: + ir = (ir + 1) % recon.shape[0] + im = (im + 1) % mu_map.shape[0] + + to_show = combine_images((recon, mu_map), (ir, im)) + cv.imshow(wname, to_show) + + key = cv.waitKey(timeout) + + if key == ord("n"): + break + elif key == ord("q"): + exit(0) + elif key == ord("p"): + timeout = 0 if timeout > 0 else 100 + elif key == 83: # right arrow key + continue + elif key == 81: # left arrow key + ir = max(ir - 2, 0) + im = max(im - 2, 0) + elif key == ord("s"): + cv.imwrite(f"{running:03d}.png", to_show) + running += 1