Skip to content
Snippets Groups Projects
test.py 3.79 KiB
Newer Older
  • Learn to ignore specific revisions
  • import cv2 as cv
    import matplotlib.pyplot as plt
    import numpy as np
    
    Tamino Huxohl's avatar
    Tamino Huxohl committed
    import torch
    
    
    from mu_map.dataset.default import MuMapDataset
    from mu_map.dataset.mock import MuMapMockDataset
    from mu_map.dataset.normalization import norm_max
    from mu_map.models.unet import UNet
    from mu_map.util import to_grayscale, COLOR_WHITE
    
    # Inference only: disable autograd globally for the rest of the script.
    torch.set_grad_enabled(False)

    dataset = MuMapMockDataset("data/initial/")

    # The UNet configuration must match the checkpoint loaded below
    # (load_state_dict is strict by default and raises on a mismatch).
    model = UNet(in_channels=1, features=[8, 16])
    device = torch.device("cpu")
    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code -- only load checkpoints you produced yourself. For a plain state
    # dict on torch >= 1.13, ``weights_only=True`` is the safer option.
    weights = torch.load("tmp/10.pth", map_location=device)
    model.load_state_dict(weights)
    model = model.eval()

    # Take the first (reconstruction, mu-map) pair, add a leading batch axis
    # for the network and normalise the input with norm_max.
    recon, mu_map = dataset[0]
    recon = recon.unsqueeze(dim=0)
    recon = norm_max(recon)

    output = model(recon)
    # NOTE(review): hard-coded scale, presumably undoing a normalisation of
    # the mu-map targets -- TODO confirm where 40206.0 comes from.
    output = output * 40206.0

    # Mean squared error between prediction and ground truth. Assumes the
    # model output keeps the batch axis added above, which broadcasts
    # against ``mu_map`` (no batch axis) -- TODO confirm shapes.
    diff = ((mu_map - output) ** 2).mean()
    print(f"Diff: {diff:.3f}")

    output = output.squeeze().numpy()
    mu_map = mu_map.squeeze().numpy()

    # Both volumes are browsed with the same slice index later on, so they
    # must have the same number of slices. Raise explicitly instead of using
    # ``assert`` so the check is not stripped under ``python -O``.
    if output.shape[0] != mu_map.shape[0]:
        raise ValueError(
            f"Slice count mismatch: prediction has {output.shape[0]}, "
            f"mu-map has {mu_map.shape[0]}"
        )
    
    # Light-grey (239) uint8 strip, 1024 px tall and 10 px wide.
    # NOTE(review): com() builds its own separator sized to its images, so
    # this module-level value appears unused -- confirm before removing.
    space = np.full(shape=(1024, 10), fill_value=239, dtype=np.uint8)

    # Resizable preview window, initially 1600x900, named "Dataset".
    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)
    
    def to_display_image(image, _slice):
        """Render one slice of a volume as a labelled 1024x1024 display image.

        Args:
            image: volume indexed along its first axis (``image[_slice]``);
                presumably a numpy array of shape (slices, H, W) -- TODO confirm.
            _slice: index of the slice to render.

        Returns:
            The slice converted with ``to_grayscale`` (given the min/max of the
            whole volume), resized to 1024x1024 with area interpolation, and
            annotated with ``"<slice>/<total>"`` near the top-left corner.
        """
        # Pass the whole volume's min/max, presumably so every slice is mapped
        # onto the same intensity scale while scrolling.
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # The third positional parameter of cv.resize is ``dst``, not the
        # interpolation flag; it must be given by keyword to take effect.
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        n_slices = str(image.shape[0])
        # Right-align the index to the width of the total so the label does
        # not shift horizontally as the index gains digits.
        _text = f"{str(_slice):>{len(n_slices)}}/{n_slices}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image
    
    def com(image1, image2, _slice):
        """Compose the same slice of two volumes side by side.

        Each volume is rendered with ``to_display_image``, and the results are
        stacked horizontally with a 10 px wide light-grey (239) strip between
        them.

        Args:
            image1: volume shown on the left.
            image2: volume shown on the right.
            _slice: slice index rendered from both volumes.

        Returns:
            The horizontally stacked image: left | separator | right.
        """
        left, right = (to_display_image(volume, _slice) for volume in (image1, image2))
        separator = np.full((left.shape[0], 10), 239, dtype=np.uint8)
        return np.hstack([left, separator, right])
    
    
    # Cycle through the slices of both volumes, waiting up to 100 ms per
    # frame, until "q" is pressed.
    slice_idx = 0
    try:
        while True:
            frame = com(output, mu_map, slice_idx)
            cv.imshow(wname, frame)
            # Keep only the low byte: older OpenCV versions/backends can report
            # modifier bits (e.g. NumLock) in the upper bits, which would make a
            # plain comparison with ord("q") fail. Harmless on newer versions.
            key = cv.waitKey(100) & 0xFF

            if key == ord("q"):
                break

            slice_idx = (slice_idx + 1) % output.shape[0]
    finally:
        # Release the GUI window, also when the loop is left via Ctrl-C.
        cv.destroyAllWindows()
    
    
    
    
    # dataset = MuMapDataset("data/initial")
    
    # # print("                Recon ||                MuMap")
    # # print("     Min |      Max | Average ||     Min |      Max | Average")
    # r_max = []
    # r_avg = []
    
    # r_max_p = []
    # r_avg_p = []
    
    # r_avg_x = []
    
    # m_max = []
    # for recon, mu_map in dataset:
        # r_max.append(recon.max())
        # r_avg.append(recon.mean())
    
        # recon_p = recon[:, :, 16:112, 16:112]
        # r_max_p.append(recon_p.max())
        # r_avg_p.append(recon_p.mean())
    
        # r_avg_x.append(recon.sum() / (recon > 0.0).sum())
        # # r_min = f"{recon.min():5.3f}"
        # # r_max = f"{recon.max():5.3f}"
        # # r_avg = f"{recon.mean():5.3f}"
        # # m_min = f"{mu_map.min():5.3f}"
        # # m_max = f"{mu_map.max():5.3f}"
        # # m_avg = f"{mu_map.mean():5.3f}"
        # # print(f"{r_min} | {r_max} |    {r_avg} || {m_min} | {m_max} |    {m_avg}")
        # m_max.append(mu_map.max())
        # # print(mu_map.max())
    # r_max = np.array(r_max)
    # r_avg = np.array(r_avg)
    
    # r_max_p = np.array(r_max_p)
    # r_avg_p = np.array(r_avg_p)
    
    # r_avg_x = np.array(r_avg_x)
    
    # m_max = np.array(m_max)
    # fig, ax = plt.subplots()
    # ax.scatter(r_max, m_max)
    
    # # fig, axs = plt.subplots(4, 3, figsize=(16, 12))
    # # axs[0, 0].hist(r_max)
    # # axs[0, 0].set_title("Max")
    # # axs[1, 0].hist(r_avg)
    # # axs[1, 0].set_title("Mean")
    # # axs[2, 0].hist(r_max / r_avg)
    # # axs[2, 0].set_title("Max / Mean")
    # # axs[3, 0].hist(recon.flatten())
    # # axs[3, 0].set_title("Example Reconstruction")
    
    # # axs[0, 1].hist(r_max_p)
    # # axs[0, 1].set_title("Max")
    # # axs[1, 1].hist(r_avg_p)
    # # axs[1, 1].set_title("Mean")
    # # axs[2, 1].hist(r_max_p / r_avg_p)
    # # axs[2, 1].set_title("Max / Mean")
    # # axs[3, 1].hist(recon_p.flatten())
    # # axs[3, 1].set_title("Example Reconstruction")
    
    # # axs[0, 2].hist(r_max_p)
    # # axs[0, 2].set_title("Max")
    # # axs[1, 2].hist(r_avg_x)
    # # axs[1, 2].set_title("Mean")
    # # axs[2, 2].hist(r_max_p / r_avg_x)
    # # axs[2, 2].set_title("Max / Mean")
    # # axs[3, 2].hist(torch.masked_select(recon, (recon > 0.0)))
    # # axs[3, 2].set_title("Example Reconstruction")