from typing import List, Optional

import numpy as np

from mu_map.dataset.util import align_images
from mu_map.file.util import load_img


def join_images(
    images: List[np.ndarray],
    separator: Optional[np.ndarray] = None,
    vertical: bool = False,
) -> np.ndarray:
    """
    Create a new image by joining all input images along a separator image.
    If joining horizontally, their shapes must be equal along their first axis.
    If joining vertically, their shapes must be equal along their second axis.

    Parameters
    ----------
    images: list of np.ndarray
        a list of images to join
    separator: np.ndarray, optional
        a separator image inserted between all images - if None, a default
        separator with a width of 10 pixels is created
    vertical: bool
        join images vertically instead of horizontally

    Returns
    -------
    np.ndarray
        a new image joined as described above
    """
    if separator is None:
        # default separator: a light-gray bar, 10 pixels wide (or high if vertical)
        shape = (
            (10, *images[0].shape[1:])
            if vertical
            else (images[0].shape[0], 10, *images[0].shape[2:])
        )
        separator = np.full(shape, 239, np.uint8)

    # interleave images with the separator and drop the trailing separator
    res = []
    for image in images:
        res += [image, separator]
    return np.vstack(res[:-1]) if vertical else np.hstack(res[:-1])


if __name__ == "__main__":
    import argparse

    import cv2 as cv

    from mu_map.util import to_grayscale, COLOR_WHITE

    parser = argparse.ArgumentParser(
        description="Visualize 3D volumes as a video of their slices",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("images", type=str, nargs="+", help="the images to visualize")
    parser.add_argument(
        "--resize", type=int, default=512, help="resize images to this size"
    )
    parser.add_argument(
        "--fps", type=int, default=10, help="frames (slices) to show per second"
    )
    parser.add_argument("--align", action="store_true", help="center align images")
    parser.add_argument(
        "--shared_range",
        action="store_true",
        help="normalize all images to the same value range",
    )
    parser.add_argument(
        "--vertical",
        action="store_true",
        help="join images vertically instead of horizontally",
    )
    parser.add_argument(
        "--window_name", type=str, default="Slices", help="name of the displayed window"
    )
    args = parser.parse_args()

    images = list(map(load_img, args.images))
    if args.align:
        # center-align all volumes to the one with the fewest slices
        image_with_least_slices = sorted(images, key=lambda image: image.shape[0])[0]
        images = list(
            map(lambda image: align_images(image, image_with_least_slices)[0], images)
        )

    # per-image scale factor so that every slice is resized to args.resize pixels
    scales = list(
        map(
            lambda image: args.resize / image.shape[2]
            if args.vertical
            else args.resize / image.shape[1],
            images,
        )
    )
    slices = [0] * len(images)
    # separator drawn between the displayed slices
    space = (
        np.full((10, args.resize), 239, np.uint8)
        if args.vertical
        else np.full((args.resize, 10), 239, np.uint8)
    )

    min_vals = list(map(lambda image: image.min(), images))
    max_vals = list(map(lambda image: image.max(), images))
    if args.shared_range:
        min_vals = [min(min_vals)] * len(images)
        max_vals = [max(max_vals)] * len(images)

    cv.namedWindow(args.window_name, cv.WINDOW_NORMAL)
    cv.resizeWindow(args.window_name, 1600, 900)

    timeout = 1000 // args.fps
    current_timeout = timeout
    while True:
        _images = []
        for i, (image, _slice, scale, min_val, max_val) in enumerate(
            zip(images, slices, scales, min_vals, max_vals)
        ):
            _image = to_grayscale(image[_slice], min_val=min_val, max_val=max_val)
            _image = cv.resize(_image, None, fx=scale, fy=scale)
            _image = cv.putText(
                _image,
                str(_slice + 1),
                (0, 30),
                cv.FONT_HERSHEY_SIMPLEX,
                1,
                COLOR_WHITE,
                3,
            )
            _images.append(_image)

            # advance to the next slice of this volume (wrapping around)
            slices[i] = (_slice + 1) % image.shape[0]
        image = join_images(_images, space, vertical=args.vertical)

        cv.imshow(args.window_name, image)
        key = cv.waitKey(current_timeout)

        if key == ord("q"):
            break
        elif key == ord("p"):
            # toggle pause: a timeout of 0 makes waitKey block until the next key press
            current_timeout = 0 if current_timeout > 0 else timeout
        elif key == 81 or key == 84:  # back or down arrow keys
            # slices were already advanced above, so -2 steps one slice back
            for i, (image, _slice) in enumerate(zip(images, slices)):
                slices[i] = (_slice - 2) % image.shape[0]