Skip to content
Snippets Groups Projects
Commit ff167f5d authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

improve random_search show predicion so that it is usable from other scripts

parent 0b357a4b
No related branches found
No related tags found
No related merge requests found
from typing import Callable
import cv2 as cv
import numpy as np
import torch
torch.set_grad_enabled(False)
from mu_map.data.prepare import headers
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import (
GaussianNormTransform,
......@@ -18,25 +21,49 @@ from mu_map.vis.slices import join_images
from mu_map.random_search.cgan import load_params
def main(model: torch.nn.Module, dataset: MuMapDataset):
def main(
model: torch.nn.Module,
dataset: MuMapDataset,
wname: str = "Prediction",
action: Callable[int, bool] = None,
_print: bool = True,
):
"""
Visualize the predictions of a model for all reconstructions
in a dataset.
Parameters
----------
model: torch.nn.Module
the mode with which predictions are computed
dataset: MuMapDataset
the dataset containing reconstructions for which images are computed and target attenuation maps for comparison
wname: str
the name of the display window
action: Callable[int, bool]
Add control behaviour by providing a callable reacting to key presses.
If it returns true, the display is stopped.
_print: bool
if measures for predictions should be printed
"""
timeout = 100
print(" Id | NMAE | MSE")
print("----|----------|---------")
# disable print if not wanted
print_func = print if _print else lambda x: x
# print header
print_func(" Id | NMAE | MSE")
print_func("----|----------|---------")
for i, (recon, mu_map) in enumerate(dataset):
_id = dataset.table.iloc[i]["id"]
_id = dataset.table.iloc[i][headers.id]
prediction = model(recon.unsqueeze(dim=0)).squeeze().numpy()
mu_map = mu_map.squeeze().numpy()
_nmae = nmae(prediction, mu_map)
_mse = mse(prediction, mu_map)
print(f"{_id:03d} | {_nmae:.6f} | {_mse:.6f}")
print_func(f"{_id:03d} | {_nmae:.6f} | {_mse:.6f}")
prediction = np.clip(prediction, 0, prediction.max())
diff = np.abs(prediction - mu_map)
......@@ -47,6 +74,7 @@ def main(model: torch.nn.Module, dataset: MuMapDataset):
n_slices = mu_map.shape[0]
_slice = 0
_break_outer = False
while True:
images = map(lambda v: v[_slice], volumes)
images = map(
......@@ -62,6 +90,10 @@ def main(model: torch.nn.Module, dataset: MuMapDataset):
cv.imshow(wname, join_images(images))
key = cv.waitKey(100)
if action is not None and action(key):
_break_outer = True
break
if key == ord("q"):
exit(0)
elif key == ord("n"):
......@@ -71,6 +103,9 @@ def main(model: torch.nn.Module, dataset: MuMapDataset):
elif key == ord("p"):
timeout = 0 if timeout > 0 else 100
if _break_outer:
break
if __name__ == "__main__":
import argparse
......@@ -109,8 +144,8 @@ if __name__ == "__main__":
model.load_state_dict(weights)
model = model.to(device).eval()
wname = "Dataset"
wname = "Prediction"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
main(model, dataset)
main(model, dataset, wname=wname)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment