From ac333689ee10a16694f2ad68d7fe701650787594 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Wed, 31 Aug 2022 09:56:42 +0200
Subject: [PATCH] update visualization of mu map dataset

---
 mu_map/data/datasets.py | 79 ++++++++++++++++++++++-------------------
 1 file changed, 43 insertions(+), 36 deletions(-)

diff --git a/mu_map/data/datasets.py b/mu_map/data/datasets.py
index 05e9e6c..752ac6a 100644
--- a/mu_map/data/datasets.py
+++ b/mu_map/data/datasets.py
@@ -44,6 +44,7 @@ class MuMapDataset(Dataset):
         images_dir: str = "images",
         bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
         discard_mu_map_slices: bool = True,
+        align: bool = True,
     ):
         super().__init__()
 
@@ -63,6 +64,7 @@ class MuMapDataset(Dataset):
         self.table["id"] = self.table["id"].apply(int)
 
         self.discard_mu_map_slices = discard_mu_map_slices
+        self.align = align
 
     def __getitem__(self, index: int):
         row = self.table.iloc[index]
@@ -81,7 +83,8 @@ class MuMapDataset(Dataset):
             for i in range(mu_map.shape[0]):
                 mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)
 
-        recon = align_images(recon, mu_map)
+        if self.align:
+            recon = align_images(recon, mu_map)
 
         return recon, mu_map
 
@@ -92,61 +95,65 @@ class MuMapDataset(Dataset):
 __all__ = [MuMapDataset.__name__]
 
 if __name__ == "__main__":
-    dataset = MuMapDataset("data/tmp")
+    import argparse
 
     import cv2 as cv
+    
+    from mu_map.util import to_grayscale, COLOR_WHITE
 
-    wname = "Images"
+    parser = argparse.ArgumentParser(description="Visualize the images of a MuMapDataset", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument("dataset_dir", type=str, help="the directory from which the dataset is loaded")
+    parser.add_argument("--unaligned", action="store_true", help="do not perform center alignment of reconstruction an mu-map slices")
+    args = parser.parse_args()
+
+    dataset = MuMapDataset(args.dataset_dir, align=not args.unaligned)
+
+    wname = "Dataset"
     cv.namedWindow(wname, cv.WINDOW_NORMAL)
-    cv.resizeWindow(wname, 1024, 512)
-    space = np.full((128, 10), 239, np.uint8)
+    cv.resizeWindow(wname, 1600, 900)
+    space = np.full((1024, 10), 239, np.uint8)
 
-    def to_grayscale(img: np.ndarray, min_val=None, max_val=None):
-        if min_val is None:
-            min_val = img.min()
+    timeout = 100
 
-        if max_val is None:
-            max_val = img.max()
+    def to_display_image(image, _slice):
+        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
+        _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
+        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
+        _image = cv.putText(_image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3)
+        return _image
 
-        _img = (img - min_val) / (max_val - min_val)
-        _img = (_img * 255).astype(np.uint8)
-        return _img
+    def combine_images(images, slices):
+        image_1 = to_display_image(images[0], slices[0])
+        image_2 = to_display_image(images[1], slices[1])
+        space = np.full((image_1.shape[0], 10), 239, np.uint8)
+        return np.hstack((image_1, space, image_2))
 
     for i in range(len(dataset)):
         ir = 0
         im = 0
 
         recon, mu_map = dataset[i]
-        print(f"{i+1}/{len(dataset)} - {recon.shape} - {mu_map.shape}")
-
-        to_show = np.hstack(
-            (
-                to_grayscale(recon[ir], min_val=recon.min(), max_val=recon.max()),
-                space,
-                to_grayscale(mu_map[im], min_val=mu_map.min(), max_val=mu_map.max()),
-            )
-        )
-        cv.imshow(wname, to_show)
-        key = cv.waitKey(100)
+        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")
+
+        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
+        key = cv.waitKey(timeout)
 
         while True:
             ir = (ir + 1) % recon.shape[0]
             im = (im + 1) % mu_map.shape[0]
 
-            to_show = np.hstack(
-                (
-                    to_grayscale(recon[ir], min_val=recon.min(), max_val=recon.max()),
-                    space,
-                    to_grayscale(
-                        mu_map[im], min_val=mu_map.min(), max_val=mu_map.max()
-                    ),
-                )
-            )
-            cv.imshow(wname, to_show)
+            cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
 
-            key = cv.waitKey(100)
+            key = cv.waitKey(timeout)
 
             if key == ord("n"):
                 break
-            if key == ord("q"):
+            elif key == ord("q"):
                 exit(0)
+            elif key == ord("p"):
+                timeout = 0 if timeout > 0 else 100
+            elif key == 83: # right arrow key
+                continue
+            elif key == 81: # left arrow key
+                ir = max(ir - 2, 0)
+                im = max(im - 2, 0)
-- 
GitLab