Skip to content
Snippets Groups Projects
slices.py 4.56 KiB
from typing import List, Optional

import numpy as np

from mu_map.dataset.util import align_images
from mu_map.file.util import load_img


def join_images(
    images: List[np.ndarray],
    separator: Optional[np.ndarray] = None,
    vertical: bool = False,
    *,
    separator_width: int = 10,
    separator_value: int = 239,
) -> np.ndarray:
    """
    Create a new image by joining all input images along a separator image.

    Images are laid out as rows x columns (x further axes). Joining
    horizontally uses ``np.hstack``, so all images (and the separator) must
    have equal shapes on every axis except the second one (for 2D images:
    equal heights). Joining vertically uses ``np.vstack``, so they must have
    equal shapes on every axis except the first one (for 2D images: equal
    widths).

    Parameters
    ----------
    images: list of np.ndarray
        a list of images to join
    separator: np.ndarray, optional
        a separator image inserted between all images - if None a default
        separator of dtype uint8, ``separator_width`` pixels thick and filled
        with ``separator_value`` is created, matching the first image along
        all axes that are not joined
    vertical: bool
        join images vertically instead of horizontally
    separator_width: int
        thickness in pixels of the default separator (ignored if
        ``separator`` is given)
    separator_value: int
        fill value of the default separator, must fit into uint8 (ignored if
        ``separator`` is given); the default 239 is a light gray

    Returns
    -------
    np.ndarray
        a new image joined as described above; its dtype follows NumPy's
        promotion rules for the images and the separator

    Raises
    ------
    IndexError
        if ``images`` is empty and no separator is given
    ValueError
        raised by NumPy if ``images`` is empty while a separator is given, or
        if the shapes do not match along the axes that are not joined
    """
    if separator is None:
        reference = images[0].shape
        if vertical:
            shape = (separator_width, *reference[1:])
        else:
            shape = (reference[0], separator_width, *reference[2:])
        # uint8 matches 8-bit display images; images of other dtypes are
        # promoted together with the separator by np.hstack/np.vstack.
        separator = np.full(shape, separator_value, np.uint8)

    # Interleave images and separators: [img_0, sep, img_1, sep, ..., img_n].
    parts = []
    for image in images:
        parts += [image, separator]
    parts = parts[:-1]  # no separator after the last image
    return np.vstack(parts) if vertical else np.hstack(parts)


if __name__ == "__main__":
    # Command-line entry point: play the slices (axis 0) of one or more volumes
    # side by side in an OpenCV window.
    # Keys: "q" quits, "p" toggles pause, key codes 81/84 step back one slice,
    # any other key advances one slice.
    import argparse

    import cv2 as cv

    from mu_map.util import to_grayscale, COLOR_WHITE

    parser = argparse.ArgumentParser(
        description="Visualize 3D Volumes as a video of their slices",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("images", type=str, nargs="+", help="the images to visualize")
    parser.add_argument(
        "--resize", type=int, default=512, help="resize images to this size"
    )
    parser.add_argument(
        "--fps", type=int, default=10, help="frames (slices) to show per second"
    )
    parser.add_argument("--align", action="store_true", help="center align images")
    parser.add_argument(
        "--shared_range",
        action="store_true",
        help="normalize all images to the same value range",
    )
    parser.add_argument(
        "--vertical",
        action="store_true",
        help="join images vertically instead of horizontally",
    )
    parser.add_argument(
        "--window_name", type=str, default="Slices", help="name of the displayed window"
    )
    args = parser.parse_args()
    # Reject values that would otherwise crash later (division by zero for the
    # frame delay, a zero/negative scale in cv.resize) with a usage error.
    if args.fps <= 0:
        parser.error("--fps must be a positive integer")
    if args.resize <= 0:
        parser.error("--resize must be a positive integer")

    # NOTE(review): load_img is assumed to return a numpy array with slices
    # along axis 0 - the indexing image[_slice] below relies on that.
    images = list(map(load_img, args.images))

    if args.align:
        # The volume with the fewest slices is the alignment reference; only the
        # first element returned by align_images is kept for each volume.
        image_with_least_slices = min(images, key=lambda image: image.shape[0])
        images = [align_images(image, image_with_least_slices)[0] for image in images]

    # Scale every slice so that the axis perpendicular to the join direction
    # becomes args.resize pixels long: the width (axis 2) when stacking
    # vertically, the height (axis 1) when stacking horizontally. This lines
    # all frames up with each other and with the separator below.
    scales = [
        args.resize / (image.shape[2] if args.vertical else image.shape[1])
        for image in images
    ]
    # Index of the slice shown next, per image.
    slices = [0] * len(images)
    # Light gray (239) separator strip, 10 pixels thick.
    space = (
        np.full((10, args.resize), 239, np.uint8)
        if args.vertical
        else np.full((args.resize, 10), 239, np.uint8)
    )

    # Value range used for the grayscale conversion; with --shared_range all
    # images use the global minimum and maximum.
    min_vals = [image.min() for image in images]
    max_vals = [image.max() for image in images]
    if args.shared_range:
        min_vals = [min(min_vals)] * len(images)
        max_vals = [max(max_vals)] * len(images)

    cv.namedWindow(args.window_name, cv.WINDOW_NORMAL)
    cv.resizeWindow(args.window_name, 1600, 900)

    # Delay between frames in milliseconds. cv.waitKey(0) blocks until a key is
    # pressed (used as "paused"), so the playback delay must be at least 1 ms;
    # otherwise an --fps above 1000 would start paused and "p" could never
    # resume playback.
    timeout = max(1, 1000 // args.fps)
    current_timeout = timeout
    try:
        while True:
            frames = []
            for i, (image, _slice, scale, min_val, max_val) in enumerate(
                zip(images, slices, scales, min_vals, max_vals)
            ):
                frame = to_grayscale(image[_slice], min_val=min_val, max_val=max_val)
                frame = cv.resize(frame, None, fx=scale, fy=scale)
                # Draw the 1-based slice number near the top-left corner.
                frame = cv.putText(
                    frame,
                    str(_slice + 1),
                    (0, 30),
                    cv.FONT_HERSHEY_SIMPLEX,
                    1,
                    COLOR_WHITE,
                    3,
                )
                frames.append(frame)

                # Advance to the next slice, wrapping around at the end.
                slices[i] = (_slice + 1) % image.shape[0]
            joined = join_images(frames, space, vertical=args.vertical)

            cv.imshow(args.window_name, joined)
            key = cv.waitKey(current_timeout)
            if key == ord("q"):
                break
            elif key == ord("p"):
                # Toggle pause: a timeout of 0 makes waitKey wait for a key press.
                current_timeout = 0 if current_timeout > 0 else timeout
            elif key == 81 or key == 84:  # back or down arrow keys
                # NOTE(review): arrow key codes differ between OpenCV GUI
                # backends - confirm on the target platform.
                # The slices were already advanced by one above, so step back
                # two to show the previous slice next.
                for i, (image, _slice) in enumerate(zip(images, slices)):
                    slices[i] = (_slice - 2) % image.shape[0]
    finally:
        # Close the window on "q" as well as on an interrupt (e.g. Ctrl+C).
        cv.destroyAllWindows()