Newer
Older

Tamino Huxohl
committed
from typing import Callable, Optional
import cv2 as cv
import numpy as np
import torch
# The model is only evaluated in this script (no training), so autograd is
# switched off at import time for the current thread.
torch.set_grad_enabled(False)

Tamino Huxohl
committed
from mu_map.data.prepare import headers
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import (
GaussianNormTransform,
MeanNormTransform,
MaxNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.eval.measures import nmae, mse
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE
from mu_map.vis.slices import join_images
from mu_map.random_search.cgan import load_params

Tamino Huxohl
committed
def main(
    model: torch.nn.Module,
    dataset: MuMapDataset,
    wname: str = "Prediction",
    action: Optional[Callable[[int], bool]] = None,
    _print: bool = True,
    args: object = None,
):
    """
    Visualize the predictions of a model for all reconstructions
    in a dataset.

    For every item, the NMAE and MSE between prediction and target are
    printed. Then the prediction, the target attenuation map and their
    absolute difference are shown side by side, animated slice by slice.

    Key bindings:
      q           exit the program
      n           continue with the next item of the dataset
      left arrow  step one slice back (key code 81)
      p           pause / resume the animation
      t           toggle the slice counter overlay
      s           write the currently displayed slice to prediction.png,
                  mu_map.png and difference.png, and write info.txt
                  (all in the current working directory)

    Parameters
    ----------
    model: torch.nn.Module
        the model with which predictions are computed
    dataset: MuMapDataset
        the dataset containing reconstructions for which predictions are
        computed and target attenuation maps for comparison
    wname: str
        the name of the display window
    action: Callable[[int], bool], optional
        Add control behaviour by providing a callable reacting to key presses.
        It receives the value returned by ``cv.waitKey``.
        If it returns true, the display is stopped and this function returns.
    _print: bool
        if measures for predictions should be printed
    args: optional
        arguments recorded in info.txt when a slice is saved. If not given,
        a module-level ``args`` (set when this file is run as a script) is
        used if it exists; otherwise ``None`` is recorded.
    """
    import sys

    default_timeout = 100  # delay in ms between two slices of the animation
    key_left = 81  # code reported by cv.waitKey for the left arrow key (typ. GTK builds)

    # disable print if not wanted
    print_func = print if _print else lambda *_a, **_kw: None
    # print header
    print_func(" Id | NMAE | MSE")
    print_func("----|----------|---------")
    for i, (recon, mu_map) in enumerate(dataset):
        _id = dataset.table.iloc[i][headers.id]
        prediction = model(recon.unsqueeze(dim=0)).squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
        # measures are computed on the raw (not yet clipped) prediction
        _nmae = nmae(prediction, mu_map)
        _mse = mse(prediction, mu_map)
        print_func(f"{_id:03d} | {_nmae:.6f} | {_mse:.6f}")

        # clip so that the prediction is displayed on the value range of the target
        prediction = np.clip(prediction, 0, mu_map.max())
        diff = np.abs(prediction - mu_map)
        volumes = [prediction, mu_map, diff]
        min_val = 0
        max_val = mu_map.max()
        n_slices = mu_map.shape[0]
        _slice = 0

        # animation state is reset for every dataset item
        timeout = default_timeout
        show_text = True
        while True:
            images = [
                cv.resize(
                    to_grayscale(v[_slice], min_val=min_val, max_val=max_val),
                    (512, 512),
                )
                for v in volumes
            ]
            if show_text:
                txt = f"{str(_slice):{len(str(n_slices))}}/{n_slices}"
                cv.putText(images[0], txt, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, 255, 3)
            # advance before waiting; key handlers below compensate for this
            # when they want to stay on (or go back from) the displayed slice
            _slice = (_slice + 1) % n_slices
            cv.imshow(wname, join_images(images))
            key = cv.waitKey(timeout)

            if action is not None and action(key):
                return
            if key == ord("q"):
                sys.exit(0)
            elif key == ord("n"):
                break
            elif key == key_left:
                _slice = (_slice - 2) % n_slices
            elif key == ord("p"):
                timeout = 0 if timeout > 0 else default_timeout
            elif key == ord("t"):
                show_text = not show_text
                _slice = (_slice - 1) % n_slices
            elif key == ord("s"):
                _slice = (_slice - 1) % n_slices
                cv.imwrite("prediction.png", images[0])
                cv.imwrite("mu_map.png", images[1])
                cv.imwrite("difference.png", images[2])
                # fall back to the script-level `args` so that calling main()
                # from another module does not raise a NameError
                saved_args = args if args is not None else globals().get("args")
                with open("info.txt", mode="w") as f:
                    f.write(f"Arguments: {saved_args}\n")
                    f.write("\n")
                    f.write(f"Id: {_id}\n")
                    f.write(f"Slice: {_slice}\n")
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
if __name__ == "__main__":
    import argparse
    import os

    parser = argparse.ArgumentParser(
        description="visualize the results of a random search run",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dir", type=str, help="directory containing the data of the random search run"
    )
    parser.add_argument(
        "--split",
        choices=["train", "validation", "test"],
        default="validation",
        help="the split of the dataset used",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/second/",
        help="directory of the dataset for which predictions are visualized",
    )
    # NOTE: this must stay a module-level name called `args`: main() reads it
    # when writing info.txt after a slice is saved
    args = parser.parse_args()

    params = load_params(os.path.join(args.dir, "params.json"))
    dataset = MuMapDataset(
        args.dataset_dir,
        transform_normalization=SequenceTransform(
            [params["normalization"], PadCropTranform(dim=3, size=32)]
        ),
        split_name=args.split,
        scatter_correction=False,
    )

    device = torch.device("cpu")
    model = UNet(features=params["generator_features"])
    # NOTE(review): torch.load unpickles the file; only load snapshots from
    # trusted random-search directories
    weights = torch.load(
        os.path.join(args.dir, "snapshots", "val_min_generator.pth"),
        map_location=device,
    )
    model.load_state_dict(weights)
    model = model.to(device).eval()

    wname = "Prediction"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    main(model, dataset, wname=wname)
    # release the display window once all items have been shown
    cv.destroyAllWindows()