Skip to content
Snippets Groups Projects
Commit fa4db774 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

add script to show predictions of a random search

parent b5e64aeb
No related branches found
No related tags found
No related merge requests found
import cv2 as cv
import numpy as np
import torch
torch.set_grad_enabled(False)
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.normalization import (
GaussianNormTransform,
MeanNormTransform,
MaxNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.eval.measures import nmae, mse
from mu_map.models.unet import UNet
from mu_map.util import to_grayscale, COLOR_WHITE
from mu_map.vis.slices import join_images
from mu_map.random_search.cgan import load_params
def main(model: torch.nn.Module, dataset: MuMapDataset):
"""
Visualize the predictions of a model for all reconstructions
in a dataset.
"""
timeout = 100
print(" Id | NMAE | MSE")
print("----|----------|---------")
for i, (recon, mu_map) in enumerate(dataset):
_id = dataset.table.iloc[i]["id"]
prediction = model(recon.unsqueeze(dim=0)).squeeze().numpy()
mu_map = mu_map.squeeze().numpy()
_nmae = nmae(prediction, mu_map)
_mse = mse(prediction, mu_map)
print(f"{_id:03d} | {_nmae:.6f} | {_mse:.6f}")
prediction = np.clip(prediction, 0, prediction.max())
diff = np.abs(prediction - mu_map)
volumes = [prediction, mu_map, diff]
min_val = 0
max_val = mu_map.max()
n_slices = mu_map.shape[0]
_slice = 0
while True:
images = map(lambda v: v[_slice], volumes)
images = map(
lambda img: to_grayscale(img, min_val=min_val, max_val=max_val), images
)
images = map(lambda img: cv.resize(img, (512, 512)), images)
images = list(images)
txt = f"{str(_slice):{len(str(n_slices))}}/{n_slices}"
cv.putText(images[0], txt, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, 255, 3)
_slice = (_slice + 1) % n_slices
cv.imshow(wname, join_images(images))
key = cv.waitKey(100)
if key == ord("q"):
exit(0)
elif key == ord("n"):
break
elif key == 81:
_slice = (_slice - 2) % n_slices
elif key == ord("p"):
timeout = 0 if timeout > 0 else 100
if __name__ == "__main__":
import argparse
import os
parser = argparse.ArgumentParser(
description="visualize the results of a random search run",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"dir", type=str, help="directory containing the data of the random search run"
)
parser.add_argument(
"--split",
choices=["train", "validation", "test"],
default="validation",
help="the split of the dataset used",
)
args = parser.parse_args()
params = load_params(os.path.join(args.dir, "params.json"))
dataset = MuMapDataset(
"data/second/",
transform_normalization=SequenceTransform(
[params["normalization"], PadCropTranform(dim=3, size=32)]
),
split_name=args.split,
scatter_correction=False,
)
device = torch.device("cpu")
model = UNet(features=params["generator_features"])
weights = torch.load(
os.path.join(args.dir, "snapshots", "val_min_generator.pth"),
map_location=device,
)
model.load_state_dict(weights)
model = model.to(device).eval()
wname = "Dataset"
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
main(model, dataset)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment