-
Authored by Tamino Huxohl
test.py 1.79 KiB
import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import torch
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.mock import MuMapMockDataset
from mu_map.dataset.normalization import norm_max
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE
# Inference only: disable autograd globally so no computation graphs are built.
torch.set_grad_enabled(False)

dataset = MuMapMockDataset("data/initial/")

# Rebuild the network and restore the weights of a training snapshot on the CPU.
model = UNet(in_channels=1, features=[8, 16])
device = torch.device("cpu")
# NOTE(review): torch.load unpickles the file, so only load snapshots from a
# trusted location (newer torch versions support ``weights_only=True``).
weights = torch.load("train_data/snapshots/10.pth", map_location=device)
model.load_state_dict(weights)
model = model.eval()

# First (input, target) pair of the dataset.
recon, mu_map = dataset[0]
# Add a leading (presumably batch) dimension, then normalize with norm_max.
recon = recon.unsqueeze(dim=0)
recon = norm_max(recon)
output = model(recon)
# NOTE(review): a rescaling of the prediction (``output * 40206.0``) was left
# disabled here; confirm whether it is still needed before comparing.
# Mean squared difference between target and prediction.
diff = ((mu_map - output) ** 2).mean()
print(f"Diff: {diff:.3f}")

# Drop singleton dimensions and convert to numpy for display.
output = output.squeeze().numpy()
mu_map = mu_map.squeeze().numpy()
# The viewer pairs prediction and target slice-by-slice along axis 0. Use an
# explicit check rather than ``assert`` so it still runs under ``python -O``.
if output.shape[0] != mu_map.shape[0]:
    raise ValueError(
        f"Slice count mismatch: prediction has {output.shape[0]} slices, "
        f"target has {mu_map.shape[0]}"
    )
# Resizable preview window used to show prediction and target side by side.
wname = "Dataset"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
def to_display_image(image, _slice, size=1024):
    """Render one slice of a volume as a square, annotated display image.

    Args:
        image: volume indexed along its first axis (``image[_slice]``);
            presumably a numpy array of shape (slices, height, width) —
            confirm against callers.
        _slice: index of the slice to render.
        size: edge length in pixels of the square output image.

    Returns:
        The slice converted with ``to_grayscale``, resized to ``size`` x
        ``size``, with the label ``"<slice>/<total>"`` drawn near the
        top-left corner.
    """
    # Scale intensities with the min/max of the whole volume rather than the
    # single slice, so brightness stays comparable across slices.
    _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
    # ``interpolation`` must be a keyword argument: the third positional
    # parameter of cv.resize is ``dst``, so passing the flag positionally
    # silently falls back to the default interpolation.
    _image = cv.resize(_image, (size, size), interpolation=cv.INTER_AREA)
    total = image.shape[0]
    # Right-align the index to the width of the total so the label keeps a
    # constant width while cycling through slices.
    _text = f"{str(_slice):>{len(str(total))}}/{total}"
    _image = cv.putText(
        _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
    )
    return _image
def com(image1, image2, _slice):
    """Show slice ``_slice`` of two volumes next to each other.

    Each volume is rendered with ``to_display_image``; the two results are
    joined horizontally with a 10-pixel-wide column of value 239 between them.

    Returns:
        The horizontally stacked array ``[image1 | gap | image2]``.
    """
    left, right = (to_display_image(volume, _slice) for volume in (image1, image2))
    # The gap column must match the rendered height to be stackable.
    gap = np.full((left.shape[0], 10), 239, np.uint8)
    return np.hstack([left, gap, right])
# Cycle through the slices, showing the prediction (left) next to the target
# (right), until the user presses "q". Each frame waits up to 100 ms for a key.
i = 0
while True:
    x = com(output, mu_map, i)
    cv.imshow(wname, x)
    # Keep only the low byte: some platforms/backends set modifier bits above
    # it, which would make the comparison with ord("q") never match.
    key = cv.waitKey(100) & 0xFF
    if key == ord("q"):
        break
    i = (i + 1) % output.shape[0]
# Release the preview window once the viewer is exited.
cv.destroyWindow(wname)