Skip to content
Snippets Groups Projects
Commit 4d245893 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

slight changes to test script for visulization

parent 8dc049f8
No related branches found
No related tags found
No related merge requests found
...@@ -5,29 +5,31 @@ import torch ...@@ -5,29 +5,31 @@ import torch
from mu_map.dataset.default import MuMapDataset from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.mock import MuMapMockDataset from mu_map.dataset.mock import MuMapMockDataset
from mu_map.dataset.normalization import norm_max from mu_map.dataset.normalization import norm_max, norm_gaussian
from mu_map.models.unet import UNet from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE from mu_map.util import to_grayscale, COLOR_WHITE
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
dataset = MuMapMockDataset("data/initial/") dataset = MuMapMockDataset("data/second/")
model = UNet(in_channels=1, features=[8, 16]) # model = UNet(in_channels=1, features=[8, 16])
model = UNet(in_channels=1)
device = torch.device("cpu") device = torch.device("cpu")
weights = torch.load("train_data/snapshots/10.pth", map_location=device) weights = torch.load("train_data/snapshots/val_min_Model.pth", map_location=device)
model.load_state_dict(weights) model.load_state_dict(weights)
model = model.eval() model = model.eval()
recon, mu_map = dataset[0] recon, mu_map = dataset[0]
recon = recon.unsqueeze(dim=0) recon = recon.unsqueeze(dim=0)
recon = norm_max(recon) # recon = norm_max(recon)
recon = norm_gaussian(recon)
output = model(recon) output = model(recon)
# output = output * 40206.0 # output = output * 40206.0
diff = ((mu_map - output) ** 2).mean() diff = ((mu_map - output) ** 2).mean()
print(f"Diff: {diff:.3f}") print(f"Diff: {diff:.5f}")
output = output.squeeze().numpy() output = output.squeeze().numpy()
mu_map = mu_map.squeeze().numpy() mu_map = mu_map.squeeze().numpy()
...@@ -39,8 +41,11 @@ cv.namedWindow(wname, cv.WINDOW_NORMAL) ...@@ -39,8 +41,11 @@ cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900) cv.resizeWindow(wname, 1600, 900)
space = np.full((1024, 10), 239, np.uint8) space = np.full((1024, 10), 239, np.uint8)
def to_display_image(image, _slice): def to_display_image(image, _slice, _min=None, _max=None):
_image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max()) _max = _max if _max is not None else image.max()
_min = _min if _min is not None else image.min()
_image = to_grayscale(image[_slice], min_val=_min, max_val=_max)
_image = cv.resize(_image, (1024, 1024), cv.INTER_AREA) _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
_text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}" _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
_image = cv.putText( _image = cv.putText(
...@@ -54,6 +59,7 @@ def com(image1, image2, _slice): ...@@ -54,6 +59,7 @@ def com(image1, image2, _slice):
space = np.full((image1.shape[0], 10), 239, np.uint8) space = np.full((image1.shape[0], 10), 239, np.uint8)
return np.hstack((image1, space, image2)) return np.hstack((image1, space, image2))
output = np.clip(output, 0, mu_map.max())
i = 0 i = 0
while True: while True:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment