"""Interactively compare a UNet prediction with the mu map of a dataset sample.

The script loads a checkpoint, runs the model on the first sample of the mock
dataset, prints the mean squared difference between the sample's mu map and
the scaled prediction, and then cycles through the slices of both volumes side
by side in an OpenCV window. Press "q" in the window to quit.
"""
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.mock import MuMapMockDataset
from mu_map.dataset.normalization import norm_max
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE

# Edge length (pixels) every slice is resized to for display.
DISPLAY_SIZE = 1024
# Width (pixels) and gray value of the spacer drawn between the two images.
SPACER_WIDTH = 10
SPACER_GRAY = 239
# Factor applied to the raw model output before it is compared to the mu map.
# NOTE(review): presumably reverts a normalization used in training - confirm.
OUTPUT_SCALE = 40206.0


def to_display_image(image, _slice):
    """Render one slice of a volume as a labelled, fixed-size display image.

    Args:
        image: Volume indexed by slice along its first axis.
        _slice: Index of the slice to render.

    Returns:
        The slice converted by ``to_grayscale``, resized to
        ``DISPLAY_SIZE x DISPLAY_SIZE`` and annotated with "<slice>/<total>"
        in its top-left corner.
    """
    # The min/max of the whole volume (not of the single slice) are handed to
    # to_grayscale, presumably so every slice shares one intensity mapping.
    _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())

    # The interpolation flag has to be passed by keyword: the third positional
    # parameter of cv.resize is `dst`, so a positional flag is not applied as
    # the interpolation mode.
    _image = cv.resize(
        _image, (DISPLAY_SIZE, DISPLAY_SIZE), interpolation=cv.INTER_AREA
    )

    # Right-align the slice index to the width of the slice count so that the
    # label keeps a constant length while cycling through the volume.
    total = str(image.shape[0])
    _text = f"{str(_slice):>{len(total)}}/{total}"
    _image = cv.putText(
        _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
    )
    return _image


def com(image1, image2, _slice):
    """Put the same slice of two volumes side by side for comparison.

    Args:
        image1: Volume shown on the left.
        image2: Volume shown on the right.
        _slice: Index of the slice to render from both volumes.

    Returns:
        A single image: left slice, a light-gray spacer, right slice.
    """
    left = to_display_image(image1, _slice)
    right = to_display_image(image2, _slice)
    spacer = np.full((left.shape[0], SPACER_WIDTH), SPACER_GRAY, np.uint8)
    return np.hstack((left, spacer, right))


# Inference only: no gradients are needed anywhere in this script.
torch.set_grad_enabled(False)

dataset = MuMapMockDataset("data/initial/")

model = UNet(in_channels=1, features=[8, 16])
device = torch.device("cpu")
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
weights = torch.load("tmp/10.pth", map_location=device)
model.load_state_dict(weights)
model = model.eval()

recon, mu_map = dataset[0]
# Add a leading dimension of size one before feeding the sample to the model.
recon = recon.unsqueeze(dim=0)
recon = norm_max(recon)

output = model(recon)
output = output * OUTPUT_SCALE

# Mean squared difference between the mu map and the scaled prediction.
diff = ((mu_map - output) ** 2).mean()
print(f"Diff: {diff:.3f}")

# Drop all singleton dimensions so that axis 0 is the slice axis for display.
output = output.squeeze().numpy()
mu_map = mu_map.squeeze().numpy()

# Both volumes are indexed with the same slice counter below.
if output.shape[0] != mu_map.shape[0]:
    raise ValueError(
        f"Slice count mismatch: prediction has {output.shape[0]} slices, "
        f"mu map has {mu_map.shape[0]}"
    )

wname = "Dataset"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)

i = 0
while True:
    frame = com(output, mu_map, i)
    cv.imshow(wname, frame)
    # Wait up to 100 ms for a key press before advancing to the next slice.
    key = cv.waitKey(100)

    if key == ord("q"):
        break

    # Wrap around to the first slice after the last one.
    i = (i + 1) % output.shape[0]

cv.destroyAllWindows()

# ---------------------------------------------------------------------------
# Disabled exploration code: statistics of reconstructions and mu maps over
# the full dataset. Kept for reference.
# ---------------------------------------------------------------------------
# dataset = MuMapDataset("data/initial")

# # print(" Recon || MuMap")
# # print(" Min | Max | Average || Min | Max | Average")
# r_max = []
# r_avg = []

# r_max_p = []
# r_avg_p = []

# r_avg_x = []
# m_max = []
# for recon, mu_map in dataset:
#     r_max.append(recon.max())
#     r_avg.append(recon.mean())

#     recon_p = recon[:, :, 16:112, 16:112]
#     r_max_p.append(recon_p.max())
#     r_avg_p.append(recon_p.mean())

#     r_avg_x.append(recon.sum() / (recon > 0.0).sum())
#     # r_min = f"{recon.min():5.3f}"
#     # r_max = f"{recon.max():5.3f}"
#     # r_avg = f"{recon.mean():5.3f}"
#     # m_min = f"{mu_map.min():5.3f}"
#     # m_max = f"{mu_map.max():5.3f}"
#     # m_avg = f"{mu_map.mean():5.3f}"
#     # print(f"{r_min} | {r_max} | {r_avg} || {m_min} | {m_max} | {m_avg}")
#     m_max.append(mu_map.max())
#     # print(mu_map.max())
# r_max = np.array(r_max)
# r_avg = np.array(r_avg)

# r_max_p = np.array(r_max_p)
# r_avg_p = np.array(r_avg_p)

# r_avg_x = np.array(r_avg_x)
# m_max = np.array(m_max)
# fig, ax = plt.subplots()
# ax.scatter(r_max, m_max)

# # fig, axs = plt.subplots(4, 3, figsize=(16, 12))
# # axs[0, 0].hist(r_max)
# # axs[0, 0].set_title("Max")
# # axs[1, 0].hist(r_avg)
# # axs[1, 0].set_title("Mean")
# # axs[2, 0].hist(r_max / r_avg)
# # axs[2, 0].set_title("Max / Mean")
# # axs[3, 0].hist(recon.flatten())
# # axs[3, 0].set_title("Example Reconstruction")

# # axs[0, 1].hist(r_max_p)
# # axs[0, 1].set_title("Max")
# # axs[1, 1].hist(r_avg_p)
# # axs[1, 1].set_title("Mean")
# # axs[2, 1].hist(r_max_p / r_avg_p)
# # axs[2, 1].set_title("Max / Mean")
# # axs[3, 1].hist(recon_p.flatten())
# # axs[3, 1].set_title("Example Reconstruction")

# # axs[0, 2].hist(r_max_p)
# # axs[0, 2].set_title("Max")
# # axs[1, 2].hist(r_avg_x)
# # axs[1, 2].set_title("Mean")
# # axs[2, 2].hist(r_max_p / r_avg_x)
# # axs[2, 2].set_title("Max / Mean")
# # axs[3, 2].hist(torch.masked_select(recon, (recon > 0.0)))
# # axs[3, 2].set_title("Example Reconstruction")

# plt.show()