# Tamino Huxohl authored
# slices.py 4.56 KiB
from typing import List, Optional
import numpy as np
from mu_map.dataset.util import align_images
from mu_map.file.util import load_img
def join_images(
    images: List[np.ndarray],
    separator: Optional[np.ndarray] = None,
    vertical: bool = False,
) -> np.ndarray:
    """
    Create a new image by joining all input images along a separator image.

    If joining horizontally, their shape must be equal along their first axis.
    If joining vertically, their shape must be equal along their second axis.

    Parameters
    ----------
    images: list of np.ndarray
        a list of images to join
    separator: np.ndarray, optional
        a separator image inserted between all images - if None a default
        separator with a width of 10 pixels is created
    vertical: bool
        join images vertically instead of horizontally

    Returns
    -------
    np.ndarray
        a new image joined as described above

    Raises
    ------
    ValueError
        if ``images`` is empty
    """
    # Fail with a clear message instead of an IndexError on images[0]
    # (or numpy's generic "need at least one array" error).
    if len(images) == 0:
        raise ValueError("join_images requires at least one image")

    if separator is None:
        # Default separator: a uint8 strip filled with 239, 10 pixels thick
        # along the joining axis and matching the first image on all
        # remaining axes (so channel axes are preserved).
        first_shape = images[0].shape
        shape = (
            (10, *first_shape[1:])
            if vertical
            else (first_shape[0], 10, *first_shape[2:])
        )
        separator = np.full(shape, 239, np.uint8)

    # Interleave: image, separator, image, ..., image (no trailing separator).
    parts = []
    for index, image in enumerate(images):
        if index > 0:
            parts.append(separator)
        parts.append(image)

    return np.vstack(parts) if vertical else np.hstack(parts)
if __name__ == "__main__":
# Command-line viewer: loads one or more volumes and shows one slice of each
# per frame in an OpenCV window, joined side by side (or stacked with
# --vertical), advancing every volume by one slice per frame.
import argparse
import cv2 as cv
from mu_map.util import to_grayscale, COLOR_WHITE
parser = argparse.ArgumentParser(
description="Visualize 3D Volumes as a video of their slices",
parser.add_argument("images", type=str, nargs="+", help="the images to visualize")
"--resize", type=int, default=512, help="resize images to this size"
"--fps", type=int, default=10, help="frames (slices) to show per second"
parser.add_argument("--align", action="store_true", help="center align images")
help="normalize all images to the same value range",
help="join images vertically instead of horizontally",
"--window_name", type=str, default="Slices", help="name of the displayed window"
args = parser.parse_args()
# Load every volume named on the command line via load_img.
images = list(map(load_img, args.images))
# Optionally align every volume against the one with the fewest slices
# (smallest shape[0]); only the first element returned by align_images is kept.
if args.align:
image_with_least_slices = sorted(images, key=lambda image: image.shape[0])[0]
images = list(
map(lambda image: align_images(image, image_with_least_slices)[0], images)
# Per-volume scale factor so each displayed slice image[_slice] is
# args.resize pixels along the axis the images must agree on: its width
# (shape[2]) when stacking vertically, its height (shape[1]) otherwise.
scales = list(
lambda image: args.resize / image.shape[2]
if args.vertical
else args.resize / image.shape[1],
# Index of the next slice to display for each volume.
slices = [0] * len(images)
# Separator strip (value 239, 10 px thick) whose long side equals
# args.resize so it matches the rescaled slices.
space = (
np.full((10, args.resize), 239, np.uint8)
if args.vertical
else np.full((args.resize, 10), 239, np.uint8)
# Intensity range used for grayscale conversion of each volume; with
# --shared_range all volumes share the global min/max.
min_vals = list(map(lambda image: image.min(), images))
max_vals = list(map(lambda image: image.max(), images))
if args.shared_range:
min_vals = [min(min_vals)] * len(images)
max_vals = [max(max_vals)] * len(images)
# Resizable window with an initial size of 1600x900.
cv.namedWindow(args.window_name, cv.WINDOW_NORMAL)
cv.resizeWindow(args.window_name, 1600, 900)
# Frame delay in milliseconds. cv.waitKey(0) blocks until a key is pressed,
# which is how pausing works below.
# NOTE(review): --fps above 1000 makes this 0 (blocking, like a pause) and
# --fps 0 raises ZeroDivisionError — consider validating --fps.
timeout = 1000 // args.fps
current_timeout = timeout
# Render loop: draw the current slice of every volume, then advance each
# volume's slice index by one, wrapping around at its slice count.
while True:
_images = []
for i, (image, _slice, scale, min_val, max_val) in enumerate(
zip(images, slices, scales, min_vals, max_vals)
_image = to_grayscale(image[_slice], min_val=min_val, max_val=max_val)
_image = cv.resize(_image, None, fx=scale, fy=scale)
# Label the slice with its 1-based index.
_image = cv.putText(
str(_slice + 1),
(0, 30),
slices[i] = (_slice + 1) % image.shape[0]
image = join_images(_images, space, vertical=args.vertical)
cv.imshow(args.window_name, image)
# Wait for a key press, or for the frame timeout to elapse.
key = cv.waitKey(current_timeout)
if key == ord("q"):
# 'p' toggles pause by switching between a blocking wait (0) and the
# frame timeout.
elif key == ord("p"):
current_timeout = 0 if current_timeout > 0 else timeout
# Step back: the loop above already advanced each index by one, so
# subtracting 2 lands one slice before the frame just shown.
# NOTE(review): 81/84 are raw cv.waitKey codes that are platform/backend
# dependent — verify they are the left/down arrows on target systems.
elif key == 81 or key == 84: # back or down arrow keys
for i, (image, _slice) in enumerate(zip(images, slices)):
slices[i] = (_slice - 2) % image.shape[0]