"""
Visualize the predictions of a generator trained in a random search run.

For every sample of a dataset, the prediction of the model, the target
attenuation map and their absolute difference are shown side by side,
slice by slice, in an OpenCV window.
"""
import argparse
import os
import sys
from typing import Callable, Optional

import cv2 as cv
import numpy as np
import torch

# Only inference is performed in this script, so gradients are never needed.
torch.set_grad_enabled(False)

# NOTE(review): the module path of `headers` was lost in a corrupted import
# line and has been restored as `mu_map.data.prepare` -- confirm.
from mu_map.data.prepare import headers
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import (
    GaussianNormTransform,
    MeanNormTransform,
    MaxNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.eval.measures import nmae, mse
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE
from mu_map.vis.slices import join_images
from mu_map.random_search.cgan import load_params

# Key code for the left arrow as returned by `cv.waitKey`. This value is
# platform/backend dependent (on Linux/GTK it is 81) -- verify on other systems.
KEY_LEFT = 81
# Default delay in milliseconds between two displayed slices.
DEFAULT_TIMEOUT = 100


def _save_images(images, args, _id, _slice):
    """Write the three displayed images and a small info file to the working directory."""
    cv.imwrite("prediction.png", images[0])
    cv.imwrite("mu_map.png", images[1])
    cv.imwrite("difference.png", images[2])
    with open("info.txt", mode="w", encoding="utf-8") as f:
        f.write(f"Arguments: {args}\n")
        f.write("\n")
        f.write(f"Id: {_id}\n")
        f.write(f"Slice: {_slice}\n")


def main(
    model: torch.nn.Module,
    dataset: MuMapDataset,
    wname: str = "Prediction",
    action: Optional[Callable[[int], bool]] = None,
    _print: bool = True,
    args: Optional[argparse.Namespace] = None,
):
    """
    Visualize the predictions of a model for all reconstructions in a dataset.

    For each sample, the prediction, the target attenuation map and their
    absolute difference are shown side by side. The slices advance
    automatically every `DEFAULT_TIMEOUT` ms.

    Key bindings:
    - `q`: exit the program
    - `n`: continue with the next sample
    - left arrow: show the previous slice
    - `p`: pause/resume automatic advancing
    - `t`: toggle the slice counter
    - `s`: save the current images and an `info.txt` file

    Parameters
    ----------
    model: torch.nn.Module
        the model with which predictions are computed
    dataset: MuMapDataset
        the dataset containing reconstructions for which images are computed
        and target attenuation maps for comparison
    wname: str
        the name of the display window
    action: Callable[[int], bool], optional
        Add control behaviour by providing a callable reacting to key presses.
        It receives the key code returned by `cv.waitKey`. If it returns True,
        the display is stopped.
    _print: bool
        if measures for predictions should be printed
    args: argparse.Namespace, optional
        command line arguments written into `info.txt` when images are saved
    """
    # disable print if not wanted
    print_func = print if _print else (lambda *_a, **_k: None)

    # print header
    print_func(" Id | NMAE | MSE")
    print_func("----|----------|---------")

    # assumes the dataset iterates in the same order as the rows of `dataset.table`
    for i, (recon, mu_map) in enumerate(dataset):
        _id = dataset.table.iloc[i][headers.id]

        prediction = model(recon.unsqueeze(dim=0)).squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        _nmae = nmae(prediction, mu_map)
        _mse = mse(prediction, mu_map)
        print_func(f"{_id:03d} | {_nmae:.6f} | {_mse:.6f}")

        # clipping only affects the display; the measures use the raw prediction
        prediction = np.clip(prediction, 0, mu_map.max())
        diff = np.abs(prediction - mu_map)
        volumes = [prediction, mu_map, diff]

        min_val = 0
        max_val = mu_map.max()
        n_slices = mu_map.shape[0]

        _slice = 0
        timeout = DEFAULT_TIMEOUT
        show_text = True
        while True:
            images = [
                cv.resize(
                    to_grayscale(volume[_slice], min_val=min_val, max_val=max_val),
                    (512, 512),
                )
                for volume in volumes
            ]

            if show_text:
                txt = f"{str(_slice):{len(str(n_slices))}}/{n_slices}"
                cv.putText(
                    images[0], txt, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, 255, 3
                )
            # advance by default; key handlers below step back where the
            # current slice should be kept (or the previous one shown)
            _slice = (_slice + 1) % n_slices

            cv.imshow(wname, join_images(images))
            key = cv.waitKey(timeout)

            if action is not None and action(key):
                return

            if key == ord("q"):
                sys.exit(0)
            elif key == ord("n"):
                break
            elif key == KEY_LEFT:
                _slice = (_slice - 2) % n_slices
            elif key == ord("p"):
                # a timeout of 0 makes `cv.waitKey` block until a key is pressed
                timeout = 0 if timeout > 0 else DEFAULT_TIMEOUT
            elif key == ord("t"):
                show_text = not show_text
                _slice = (_slice - 1) % n_slices
            elif key == ord("s"):
                _slice = (_slice - 1) % n_slices
                _save_images(images, args, _id, _slice)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="visualize the results of a random search run",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dir", type=str, help="directory containing the data of the random search run"
    )
    parser.add_argument(
        "--split",
        choices=["train", "validation", "test"],
        default="validation",
        help="the split of the dataset used",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/second/",
        help="directory of the dataset",
    )
    args = parser.parse_args()

    params = load_params(os.path.join(args.dir, "params.json"))

    dataset = MuMapDataset(
        args.dataset_dir,
        transform_normalization=SequenceTransform(
            [params["normalization"], PadCropTranform(dim=3, size=32)]
        ),
        split_name=args.split,
        scatter_correction=False,
    )

    device = torch.device("cpu")
    model = UNet(features=params["generator_features"])
    # NOTE: torch.load unpickles the file; only load snapshots from trusted runs.
    weights = torch.load(
        os.path.join(args.dir, "snapshots", "val_min_generator.pth"),
        map_location=device,
    )
    model.load_state_dict(weights)
    # switch to evaluation mode (affects layers such as dropout / batch norm)
    model = model.eval()

    wname = "Prediction"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)
    main(model, dataset, wname=wname, args=args)
    cv.destroyAllWindows()