From d66fd526237e04edd238979332ace4ddb08c21e1 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 27 Sep 2022 13:35:54 +0200
Subject: [PATCH] implement visualization of results in test script

---
 mu_map/test.py | 151 +++++++++++++++++++++++++++++++++++++++++++------
 1 file changed, 134 insertions(+), 17 deletions(-)

diff --git a/mu_map/test.py b/mu_map/test.py
index 1da35df..33c04b2 100644
--- a/mu_map/test.py
+++ b/mu_map/test.py
@@ -1,26 +1,143 @@
+import cv2 as cv
+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 
-from .data.preprocessing import *
+from mu_map.dataset.default import MuMapDataset
+from mu_map.dataset.mock import MuMapMockDataset
+from mu_map.dataset.normalization import norm_max
+from mu_map.models.unet import UNet
+from mu_map.util import to_grayscale, COLOR_WHITE
 
-means = torch.full((10, 10, 10), 5.0)
-stds = torch.full((10, 10, 10), 10.0)
-x = torch.normal(means, stds)
+torch.set_grad_enabled(False)
 
-print(f"Before: mean={x.mean():.3f} std={x.std():.3f}")
+dataset = MuMapMockDataset("data/initial/")
 
-y = norm_gaussian(x)
-print(f" After: mean={y.mean():.3f} std={y.std():.3f}")
-y = GaussianNorm()(x)
-print(f" After: mean={y.mean():.3f} std={y.std():.3f}")
+model = UNet(in_channels=1, features=[8, 16])
+device = torch.device("cpu")
+weights = torch.load("tmp/10.pth", map_location=device)
+model.load_state_dict(weights)
+model = model.eval()
 
+recon, mu_map = dataset[0]
+recon = recon.unsqueeze(dim=0)
+recon = norm_max(recon)
 
-import cv2 as cv
-import numpy as np
+output = model(recon)
+output = output * 40206.0
+
+diff = ((mu_map - output) ** 2).mean()
+print(f"Diff: {diff:.3f}")
+
+output = output.squeeze().numpy()
+mu_map = mu_map.squeeze().numpy()
+
+assert output.shape[0] == mu_map.shape[0]
+
+wname = "Dataset"
+cv.namedWindow(wname, cv.WINDOW_NORMAL)
+cv.resizeWindow(wname, 1600, 900)
+space = np.full((1024, 10), 239, np.uint8)
+
+def to_display_image(image, _slice):
+    _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
+    _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
+    _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
+    _image = cv.putText(
+        _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
+    )
+    return _image
+
+def com(image1, image2, _slice):
+    image1 = to_display_image(image1, _slice)
+    image2 = to_display_image(image2, _slice)
+    space = np.full((image1.shape[0], 10), 239, np.uint8)
+    return np.hstack((image1, space, image2))
+
+
+i = 0
+while True:
+    x = com(output, mu_map, i)
+    cv.imshow(wname, x)
+    key = cv.waitKey(100)
+
+    if key == ord("q"):
+        break
+
+    i = (i + 1) % output.shape[0]
+
+
+
+
+# dataset = MuMapDataset("data/initial")
+
+# # print("                Recon ||                MuMap")
+# # print("     Min |      Max | Average ||     Min |      Max | Average")
+# r_max = []
+# r_avg = []
+
+# r_max_p = []
+# r_avg_p = []
+
+# r_avg_x = []
+
+# m_max = []
+# for recon, mu_map in dataset:
+    # r_max.append(recon.max())
+    # r_avg.append(recon.mean())
+
+    # recon_p = recon[:, :, 16:112, 16:112]
+    # r_max_p.append(recon_p.max())
+    # r_avg_p.append(recon_p.mean())
+
+    # r_avg_x.append(recon.sum() / (recon > 0.0).sum())
+    # # r_min = f"{recon.min():5.3f}"
+    # # r_max = f"{recon.max():5.3f}"
+    # # r_avg = f"{recon.mean():5.3f}"
+    # # m_min = f"{mu_map.min():5.3f}"
+    # # m_max = f"{mu_map.max():5.3f}"
+    # # m_avg = f"{mu_map.mean():5.3f}"
+    # # print(f"{r_min} | {r_max} |    {r_avg} || {m_min} | {m_max} |    {m_avg}")
+    # m_max.append(mu_map.max())
+    # # print(mu_map.max())
+# r_max = np.array(r_max)
+# r_avg = np.array(r_avg)
+
+# r_max_p = np.array(r_max_p)
+# r_avg_p = np.array(r_avg_p)
+
+# r_avg_x = np.array(r_avg_x)
+
+# m_max = np.array(m_max)
+# fig, ax = plt.subplots()
+# ax.scatter(r_max, m_max)
+
+# # fig, axs = plt.subplots(4, 3, figsize=(16, 12))
+# # axs[0, 0].hist(r_max)
+# # axs[0, 0].set_title("Max")
+# # axs[1, 0].hist(r_avg)
+# # axs[1, 0].set_title("Mean")
+# # axs[2, 0].hist(r_max / r_avg)
+# # axs[2, 0].set_title("Max / Mean")
+# # axs[3, 0].hist(recon.flatten())
+# # axs[3, 0].set_title("Example Reconstruction")
+
+# # axs[0, 1].hist(r_max_p)
+# # axs[0, 1].set_title("Max")
+# # axs[1, 1].hist(r_avg_p)
+# # axs[1, 1].set_title("Mean")
+# # axs[2, 1].hist(r_max_p / r_avg_p)
+# # axs[2, 1].set_title("Max / Mean")
+# # axs[3, 1].hist(recon_p.flatten())
+# # axs[3, 1].set_title("Example Reconstruction")
 
-x = np.zeros((512, 512), np.uint8)
-cv.imshow("X", x)
-key = cv.waitKey(0)
-while key != ord("q"):
-    print(key)
-    key = cv.waitKey(0)
+# # axs[0, 2].hist(r_max_p)
+# # axs[0, 2].set_title("Max")
+# # axs[1, 2].hist(r_avg_x)
+# # axs[1, 2].set_title("Mean")
+# # axs[2, 2].hist(r_max_p / r_avg_x)
+# # axs[2, 2].set_title("Max / Mean")
+# # axs[3, 2].hist(torch.masked_select(recon, (recon > 0.0)))
+# # axs[3, 2].set_title("Example Reconstruction")
 
+# plt.show()
-- 
GitLab