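"""Visualize the predictions of a trained model from a random search run.

For every reconstruction in the dataset, the predicted attenuation map,
the ground truth and their absolute difference are displayed side by side
while cycling through the slices of each volume. NMAE and MSE per scan
are printed to stdout.
"""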
import cv2 as cv
import numpy as np
import torch

# This script only runs inference, so gradients are never needed.
torch.set_grad_enabled(False)

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import (
    GaussianNormTransform,
    MeanNormTransform,
    MaxNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.eval.measures import nmae, mse
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE
from mu_map.vis.slices import join_images
from mu_map.random_search.cgan import load_params
def main(model: torch.nn.Module, dataset: MuMapDataset, wname: str = "Dataset"):
    """
    Visualize the predictions of a model for all reconstructions
    in a dataset.

    Keys: "q" quits, "n" shows the next scan, the left arrow steps one
    slice back, and "p" pauses/resumes the automatic slice cycling.
    """
    timeout = 100  # delay between slices in milliseconds; 0 pauses

    print(" Id | NMAE     | MSE")
    print("----|----------|---------")
    for i, (recon, mu_map) in enumerate(dataset):
        _id = dataset.table.iloc[i]["id"]

        # Predict an attenuation map for a single reconstruction and
        # compare it to the ground truth.
        prediction = model(recon.unsqueeze(dim=0)).squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        _nmae = nmae(prediction, mu_map)
        _mse = mse(prediction, mu_map)
        print(f"{_id:03d} | {_nmae:.6f} | {_mse:.6f}")

        prediction = np.clip(prediction, 0, prediction.max())  # remove negative values
        diff = np.abs(prediction - mu_map)
        volumes = [prediction, mu_map, diff]

        min_val = 0
        max_val = mu_map.max()
        n_slices = mu_map.shape[0]

        _slice = 0
        while True:
            # Render the current slice of every volume side by side.
            images = map(lambda v: v[_slice], volumes)
            images = map(
                lambda img: to_grayscale(img, min_val=min_val, max_val=max_val), images
            )
            images = map(lambda img: cv.resize(img, (512, 512)), images)
            images = list(images)

            # Annotate the prediction with the current slice index.
            txt = f"{str(_slice):{len(str(n_slices))}}/{n_slices}"
            cv.putText(images[0], txt, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, 255, 3)
            _slice = (_slice + 1) % n_slices

            cv.imshow(wname, join_images(images))
            key = cv.waitKey(timeout)

            if key == ord("q"):
                exit(0)
            elif key == ord("n"):  # continue with the next scan
                break
            elif key == 81:  # left arrow: step one slice back
                _slice = (_slice - 2) % n_slices
            elif key == ord("p"):  # toggle pause
                timeout = 0 if timeout > 0 else 100
if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="visualize the results of a random search run",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dir", type=str, help="directory containing the data of the random search run"
    )
    parser.add_argument(
        "--split",
        choices=["train", "validation", "test"],
        default="validation",
        help="the split of the dataset used",
    )
    args = parser.parse_args()

    # Restore the hyperparameters of the run and build the dataset with
    # the same normalization that was used during training.
    params = load_params(os.path.join(args.dir, "params.json"))
    dataset = MuMapDataset(
        "data/second/",
        transform_normalization=SequenceTransform(
            [params["normalization"], PadCropTranform(dim=3, size=32)]
        ),
        split_name=args.split,
        scatter_correction=False,
    )

    # Re-create the generator and load the snapshot saved at the minimal
    # validation loss.
    device = torch.device("cpu")
    model = UNet(features=params["generator_features"])
    weights = torch.load(
        os.path.join(args.dir, "snapshots", "val_min_generator.pth"),
        map_location=device,
    )
    model.load_state_dict(weights)
    model = model.to(device).eval()

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    main(model, dataset, wname)
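
# Example invocation (the script name and run directory below are
# hypothetical; use the paths of your checkout):
#
#   python visualize_random_search.py runs/001 --split test
#
# Window controls: "q" quits, "n" shows the next scan, the left arrow
# steps one slice back, and "p" toggles the automatic slice cycling.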