Skip to content
Snippets Groups Projects
Commit d66fd526 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

implement visualization of results in test script

parent 9beb28b1
No related branches found
No related tags found
No related merge requests found
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
from .data.preprocessing import *
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.mock import MuMapMockDataset
from mu_map.dataset.normalization import norm_max
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE
means = torch.full((10, 10, 10), 5.0)
stds = torch.full((10, 10, 10), 10.0)
x = torch.normal(means, stds)
torch.set_grad_enabled(False)
print(f"Before: mean={x.mean():.3f} std={x.std():.3f}")
dataset = MuMapMockDataset("data/initial/")
y = norm_gaussian(x)
print(f" After: mean={y.mean():.3f} std={y.std():.3f}")
y = GaussianNorm()(x)
print(f" After: mean={y.mean():.3f} std={y.std():.3f}")
model = UNet(in_channels=1, features=[8, 16])
device = torch.device("cpu")
weights = torch.load("tmp/10.pth", map_location=device)
model.load_state_dict(weights)
model = model.eval()
recon, mu_map = dataset[0]
recon = recon.unsqueeze(dim=0)
recon = norm_max(recon)
import cv2 as cv
import numpy as np
output = model(recon)
output = output * 40206.0
diff = ((mu_map - output) ** 2).mean()
print(f"Diff: {diff:.3f}")
output = output.squeeze().numpy()
mu_map = mu_map.squeeze().numpy()
assert output.shape[0] == mu_map.shape[0]
wname = "Dataset"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
space = np.full((1024, 10), 239, np.uint8)
def to_display_image(image, _slice):
_image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
_image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
_text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
_image = cv.putText(
_image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
)
return _image
def com(image1, image2, _slice):
image1 = to_display_image(image1, _slice)
image2 = to_display_image(image2, _slice)
space = np.full((image1.shape[0], 10), 239, np.uint8)
return np.hstack((image1, space, image2))
i = 0
while True:
x = com(output, mu_map, i)
cv.imshow(wname, x)
key = cv.waitKey(100)
if key == ord("q"):
break
i = (i + 1) % output.shape[0]
# dataset = MuMapDataset("data/initial")
# # print(" Recon || MuMap")
# # print(" Min | Max | Average || Min | Max | Average")
# r_max = []
# r_avg = []
# r_max_p = []
# r_avg_p = []
# r_avg_x = []
# m_max = []
# for recon, mu_map in dataset:
# r_max.append(recon.max())
# r_avg.append(recon.mean())
# recon_p = recon[:, :, 16:112, 16:112]
# r_max_p.append(recon_p.max())
# r_avg_p.append(recon_p.mean())
# r_avg_x.append(recon.sum() / (recon > 0.0).sum())
# # r_min = f"{recon.min():5.3f}"
# # r_max = f"{recon.max():5.3f}"
# # r_avg = f"{recon.mean():5.3f}"
# # m_min = f"{mu_map.min():5.3f}"
# # m_max = f"{mu_map.max():5.3f}"
# # m_avg = f"{mu_map.mean():5.3f}"
# # print(f"{r_min} | {r_max} | {r_avg} || {m_min} | {m_max} | {m_avg}")
# m_max.append(mu_map.max())
# # print(mu_map.max())
# r_max = np.array(r_max)
# r_avg = np.array(r_avg)
# r_max_p = np.array(r_max_p)
# r_avg_p = np.array(r_avg_p)
# r_avg_x = np.array(r_avg_x)
# m_max = np.array(m_max)
# fig, ax = plt.subplots()
# ax.scatter(r_max, m_max)
# # fig, axs = plt.subplots(4, 3, figsize=(16, 12))
# # axs[0, 0].hist(r_max)
# # axs[0, 0].set_title("Max")
# # axs[1, 0].hist(r_avg)
# # axs[1, 0].set_title("Mean")
# # axs[2, 0].hist(r_max / r_avg)
# # axs[2, 0].set_title("Max / Mean")
# # axs[3, 0].hist(recon.flatten())
# # axs[3, 0].set_title("Example Reconstruction")
# # axs[0, 1].hist(r_max_p)
# # axs[0, 1].set_title("Max")
# # axs[1, 1].hist(r_avg_p)
# # axs[1, 1].set_title("Mean")
# # axs[2, 1].hist(r_max_p / r_avg_p)
# # axs[2, 1].set_title("Max / Mean")
# # axs[3, 1].hist(recon_p.flatten())
# # axs[3, 1].set_title("Example Reconstruction")
x = np.zeros((512, 512), np.uint8)
cv.imshow("X", x)
key = cv.waitKey(0)
while key != ord("q"):
print(key)
key = cv.waitKey(0)
# # axs[0, 2].hist(r_max_p)
# # axs[0, 2].set_title("Max")
# # axs[1, 2].hist(r_avg_x)
# # axs[1, 2].set_title("Mean")
# # axs[2, 2].hist(r_max_p / r_avg_x)
# # axs[2, 2].set_title("Max / Mean")
# # axs[3, 2].hist(torch.masked_select(recon, (recon > 0.0)))
# # axs[3, 2].set_title("Example Reconstruction")
# plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment