From d66fd526237e04edd238979332ace4ddb08c21e1 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 27 Sep 2022 13:35:54 +0200 Subject: [PATCH] implement visualization of results in test script --- mu_map/test.py | 151 +++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 134 insertions(+), 17 deletions(-) diff --git a/mu_map/test.py b/mu_map/test.py index 1da35df..33c04b2 100644 --- a/mu_map/test.py +++ b/mu_map/test.py @@ -1,26 +1,143 @@ +import cv2 as cv +import matplotlib.pyplot as plt +import numpy as np import torch -from .data.preprocessing import * +from mu_map.dataset.default import MuMapDataset +from mu_map.dataset.mock import MuMapMockDataset +from mu_map.dataset.normalization import norm_max +from mu_map.models.unet import UNet +from mu_map.util import to_grayscale, COLOR_WHITE -means = torch.full((10, 10, 10), 5.0) -stds = torch.full((10, 10, 10), 10.0) -x = torch.normal(means, stds) +torch.set_grad_enabled(False) -print(f"Before: mean={x.mean():.3f} std={x.std():.3f}") +dataset = MuMapMockDataset("data/initial/") -y = norm_gaussian(x) -print(f" After: mean={y.mean():.3f} std={y.std():.3f}") -y = GaussianNorm()(x) -print(f" After: mean={y.mean():.3f} std={y.std():.3f}") +model = UNet(in_channels=1, features=[8, 16]) +device = torch.device("cpu") +weights = torch.load("tmp/10.pth", map_location=device) +model.load_state_dict(weights) +model = model.eval() +recon, mu_map = dataset[0] +recon = recon.unsqueeze(dim=0) +recon = norm_max(recon) -import cv2 as cv -import numpy as np +output = model(recon) +output = output * 40206.0 + +diff = ((mu_map - output) ** 2).mean() +print(f"Diff: {diff:.3f}") + +output = output.squeeze().numpy() +mu_map = mu_map.squeeze().numpy() + +assert output.shape[0] == mu_map.shape[0] + +wname = "Dataset" +cv.namedWindow(wname, cv.WINDOW_NORMAL) +cv.resizeWindow(wname, 1600, 900) +space = np.full((1024, 10), 239, np.uint8) + +def to_display_image(image, _slice): + _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max()) + _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA) + _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}" + _image = cv.putText( + _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3 + ) + return _image + +def com(image1, image2, _slice): + image1 = to_display_image(image1, _slice) + image2 = to_display_image(image2, _slice) + space = np.full((image1.shape[0], 10), 239, np.uint8) + return np.hstack((image1, space, image2)) + + +i = 0 +while True: + x = com(output, mu_map, i) + cv.imshow(wname, x) + key = cv.waitKey(100) + + if key == ord("q"): + break + + i = (i + 1) % output.shape[0] + + + + +# dataset = MuMapDataset("data/initial") + +# # print(" Recon || MuMap") +# # print(" Min | Max | Average || Min | Max | Average") +# r_max = [] +# r_avg = [] + +# r_max_p = [] +# r_avg_p = [] + +# r_avg_x = [] + +# m_max = [] +# for recon, mu_map in dataset: + # r_max.append(recon.max()) + # r_avg.append(recon.mean()) + + # recon_p = recon[:, :, 16:112, 16:112] + # r_max_p.append(recon_p.max()) + # r_avg_p.append(recon_p.mean()) + + # r_avg_x.append(recon.sum() / (recon > 0.0).sum()) + # # r_min = f"{recon.min():5.3f}" + # # r_max = f"{recon.max():5.3f}" + # # r_avg = f"{recon.mean():5.3f}" + # # m_min = f"{mu_map.min():5.3f}" + # # m_max = f"{mu_map.max():5.3f}" + # # m_avg = f"{mu_map.mean():5.3f}" + # # print(f"{r_min} | {r_max} | {r_avg} || {m_min} | {m_max} | {m_avg}") + # m_max.append(mu_map.max()) + # # print(mu_map.max()) +# r_max = np.array(r_max) +# r_avg = np.array(r_avg) + +# r_max_p = np.array(r_max_p) +# r_avg_p = np.array(r_avg_p) + +# r_avg_x = np.array(r_avg_x) + +# m_max = np.array(m_max) +# fig, ax = plt.subplots() +# ax.scatter(r_max, m_max) + +# # fig, axs = plt.subplots(4, 3, figsize=(16, 12)) +# # axs[0, 0].hist(r_max) +# # axs[0, 0].set_title("Max") +# # axs[1, 0].hist(r_avg) +# # axs[1, 0].set_title("Mean") +# # axs[2, 0].hist(r_max / r_avg) +# # axs[2, 0].set_title("Max / Mean") +# # axs[3, 0].hist(recon.flatten()) +# # axs[3, 0].set_title("Example Reconstruction") + +# # axs[0, 1].hist(r_max_p) +# # axs[0, 1].set_title("Max") +# # axs[1, 1].hist(r_avg_p) +# # axs[1, 1].set_title("Mean") +# # axs[2, 1].hist(r_max_p / r_avg_p) +# # axs[2, 1].set_title("Max / Mean") +# # axs[3, 1].hist(recon_p.flatten()) +# # axs[3, 1].set_title("Example Reconstruction") -x = np.zeros((512, 512), np.uint8) -cv.imshow("X", x) -key = cv.waitKey(0) -while key != ord("q"): - print(key) - key = cv.waitKey(0) +# # axs[0, 2].hist(r_max_p) +# # axs[0, 2].set_title("Max") +# # axs[1, 2].hist(r_avg_x) +# # axs[1, 2].set_title("Mean") +# # axs[2, 2].hist(r_max_p / r_avg_x) +# # axs[2, 2].set_title("Max / Mean") +# # axs[3, 2].hist(torch.masked_select(recon, (recon > 0.0))) +# # axs[3, 2].set_title("Example Reconstruction") +# plt.show() -- GitLab