Skip to content
Snippets Groups Projects
Commit 032ccb09 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

update and fix create video script

parent 0b1f8aa0
No related branches found
No related tags found
No related merge requests found
"""
Script to create a video of mu map slices compared to their predictions.
"""
import argparse
import os
import cv2 as cv
......@@ -13,29 +17,72 @@ from mu_map.vis.slices import join_images
torch.set_grad_enabled(False)
random_search_iter_dir = "cgan_random_search/001"
parser = argparse.ArgumentParser(
description="Create a video similar to the visualization of the mu_map.random_search.eval.show_predictions script",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--dataset_dir", default="data/second", type=str, help="directory of the dataset"
)
parser.add_argument(
"--random_search_dir",
default="cgan_random_search/best",
type=str,
help="directory of the random search iteration",
)
parser.add_argument(
"--params",
default="params.json",
type=str,
help="file under <random_search_dir> containing the parameters",
)
parser.add_argument(
"--weights",
default="snapshots/val_min_generator.pth",
type=str,
help="file under <random_search_dir> containing the model weights",
)
parser.add_argument("--id", default=1, type=int, help="id of the study to visualize")
parser.add_argument(
"--fps", default=25, type=int, help="frames per second of the resulting video"
)
parser.add_argument(
"--slices_per_frame",
default=2,
type=int,
help="how often a slice should be repeated",
)
parser.add_argument("--size", default=512, type=int, help="size for a single image")
parser.add_argument(
"--hide_slice_number", action="store_true", help="do not print the slice number"
)
parser.add_argument(
"--out",
type=str,
default="mu_map_comparison.mp4",
help="the filename of the output video",
)
args = parser.parse_args()
params = load_params(os.path.join(random_search_iter_dir, "params.json"))
print(params["normalization"])
print(params["generator_features"])
args.params = os.path.join(args.random_search_dir, args.params)
args.weights = os.path.join(args.random_search_dir, args.weights)
params = load_params(args.params)
dataset = MuMapDataset(
"data/second/",
args.dataset_dir,
transform_normalization=SequenceTransform(
[params["normalization"], PadCropTranform(dim=3, size=32)]
),
split_name="test",
scatter_correction=False,
)
recon, mu_map_ct = dataset[1]
mu_map_ct = mu_map_ct.squeeze().numpy()
device = torch.device("cpu")
model = UNet(features=params["generator_features"])
model.load_state_dict(
torch.load(
os.path.join(random_search_iter_dir, "snapshots/val_min_generator.pth"), map_location=device
)
)
model.load_state_dict(torch.load(args.weights, map_location=device))
model = model.eval()
recon, mu_map_ct = dataset.get_item_by_id(args.id)
mu_map_ct = mu_map_ct.squeeze().numpy()
mu_map_dl = model(recon.unsqueeze(dim=0)).squeeze().numpy()
mu_map_dl = np.clip(mu_map_dl, 0, mu_map_ct.max())
......@@ -43,25 +90,26 @@ mu_map_dl = np.clip(mu_map_dl, 0, mu_map_ct.max())
volumes = [mu_map_ct, mu_map_dl, np.abs(mu_map_dl - mu_map_ct)]
min_val = 0
max_val = max(mu_map_ct.max(), mu_map_dl.max())
print(mu_map_ct.max(), mu_map_dl.max())
fourcc = cv.VideoWriter_fourcc(*"mp4v")
frame_size = (512, 3 * 512 + 2 * 10)
print(f"Frame size {frame_size}")
# video_writer = cv.VideoWriter("mu_map_comparison.mp4", fourcc, 25, frame_size, isColor=False)
video_writer = cv.VideoWriter("mu_map_comparison.mp4", cv.VideoWriter_fourcc(*"mp4v"), 25, frame_size[::-1], isColor=False)
video_writer = cv.VideoWriter(
args.out,
cv.VideoWriter_fourcc(*"mp4v"),
args.fps,
(3 * args.size + 2 * 10, args.size),
isColor=False,
)
for i in range(mu_map_ct.shape[0]):
images = map(lambda volume: volume[i], volumes)
images = map(
lambda img: to_grayscale(img, min_val=min_val, max_val=max_val), images
)
images = map(lambda img: cv.resize(img, (512, 512)), images)
images = map(lambda img: cv.resize(img, (args.size, args.size)), images)
images = list(images)
txt = f"{str(i):{len(str(mu_map_ct.shape[0]))}}/{mu_map_ct.shape[0]}"
txt = f"{str(i + 1):{len(str(mu_map_ct.shape[0]))}}/{mu_map_ct.shape[0]}"
cv.putText(images[0], txt, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, 255, 3)
image = join_images(images)
for i in range(5):
for i in range(args.slices_per_frame):
video_writer.write(image)
video_writer.release()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment