# default.py
import os
from typing import Optional

import cv2 as cv
import pandas as pd
import pydicom
import numpy as np
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.logging import get_logger


"""
Since DICOM images only allow images stored in short integer format,
the Siemens scanner software multiplies values by a factor before storing
so that no precision is lost.
The scale can be found in this private DICOM tag.
"""
# DICOM tags are (group, element) pairs packed into one 32-bit integer:
# group 0x0033 in the upper 16 bits, element 0x1038 in the lower 16 bits.
DCM_TAG_PIXEL_SCALE_FACTOR = (0x0033 << 16) | 0x1038


def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).

    The first image is cropped along its first axis so that it ends up with
    as many slices as the second image and the central slices of both images
    lie on top of each other. If both images already have the same number of
    slices, image_1 is returned uncropped.

    :param image_1: the image to be aligned, must have at least as many slices as image_2
    :param image_2: the image to which image_1 is aligned
    :return: the aligned (center-cropped) image_1
    :raises AssertionError: if image_1 has fewer slices than image_2
    """
    slices_1 = image_1.shape[0]
    slices_2 = image_2.shape[0]
    assert slices_1 >= slices_2, (
        "Alignment is based on the fact that image 1 has at least as many "
        f"slices {slices_1} as image_2 {slices_2}"
    )

    # central slice of image 2
    c_2 = slices_2 // 2
    # number of slices to the left and right of the center of image 2
    # (for an odd number of slices the extra slice goes to the right)
    left = c_2
    right = slices_2 - c_2

    # central slice of image 1
    c_1 = slices_1 // 2
    # select the center and the same amount to the left/right as in image_2
    return image_1[(c_1 - left) : (c_1 + right)]


class MuMapDataset(Dataset):
    """
    Dataset of (reconstruction, mu map) pairs which are read from DICOM files.

    All images referenced by the meta CSV file are loaded, pre-processed and
    normalized once when the dataset is created and are then kept in memory.
    The augmentation transform is applied every time an item is accessed.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Transform = Transform(),
        transform_augmentation: Transform = Transform(),
        logger=None,
    ):
        """
        Create the dataset and pre-load all of its images.

        :param dataset_dir: directory of the dataset, all other paths are relative to it
        :param csv_file: name of the CSV file which lists the images of the dataset
        :param split_file: name of the CSV file which is handed to `split_csv`
        :param split_name: name of the split to use, if not given the full table is used
        :param images_dir: name of the sub-directory which contains the DICOM files
        :param bed_contours_file: name of the file with bed contours, if not given
                                  the bed is not removed from the mu maps
        :param discard_mu_map_slices: pass mu maps through `discard_slices`
        :param align: crop reconstructions with `align_images` to match their mu map
        :param scatter_correction: read the reconstruction from the column
                                   `headers.file_recon_nac_sc` instead of
                                   `headers.file_recon_nac_nsc`
        :param transform_normalization: transform applied once to every pair when it is loaded
        :param transform_augmentation: transform applied to a pair every time it is accessed
        :param logger: logger used for messages, defaults to `get_logger()`
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        self.transform_normalization = transform_normalization
        self.transform_augmentation = transform_augmentation
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
        # column of the table which contains the filename of the reconstruction
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        # pre-processed images, both dicts are keyed by the id of a table row
        self.reconstructions = {}
        self.mu_maps = {}
        self.pre_load_images()

    def pre_load_images(self):
        """
        Load, pre-process and normalize the images of all rows of the table.

        The results are stored in `self.mu_maps` and `self.reconstructions`
        as float32 tensors with an additional leading dimension.
        """
        self.logger.debug("Pre-loading images ...")
        for row_index in range(len(self.table)):
            row = self.table.iloc[row_index]
            _id = row["id"]

            mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
            mu_map = pydicom.dcmread(mu_map_file)
            # undo the scaling which was applied before storing the pixel values
            mu_map = mu_map.pixel_array / mu_map[DCM_TAG_PIXEL_SCALE_FACTOR].value
            if self.discard_mu_map_slices:
                mu_map = discard_slices(row, mu_map)
            if self.bed_contours:
                bed_contour = self.bed_contours[_id]
                # a negative thickness fills the contour, so the bed is
                # overwritten with the value 0.0 in every slice
                for slice_index in range(mu_map.shape[0]):
                    mu_map[slice_index] = cv.drawContours(
                        mu_map[slice_index], [bed_contour], -1, 0.0, -1
                    )

            recon_file = os.path.join(self.dir_images, row[self.header_recon])
            recon = pydicom.dcmread(recon_file)
            recon = recon.pixel_array / recon[DCM_TAG_PIXEL_SCALE_FACTOR].value
            if self.align:
                recon = align_images(recon, mu_map)

            mu_map = mu_map.astype(np.float32)
            mu_map = torch.from_numpy(mu_map)
            mu_map = mu_map.unsqueeze(dim=0)

            recon = recon.astype(np.float32)
            recon = torch.from_numpy(recon)
            recon = recon.unsqueeze(dim=0)

            recon, mu_map = self.transform_normalization(recon, mu_map)

            self.mu_maps[_id] = mu_map
            self.reconstructions[_id] = recon
        self.logger.debug("Pre-loading images done!")

    def __getitem__(self, index: int):
        """
        Get the pair of reconstruction and mu map of a table row.

        :param index: positional index of the row in the table
        :return: the augmented reconstruction and mu map
        """
        row = self.table.iloc[index]
        _id = row["id"]

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon, mu_map = self.transform_augmentation(recon, mu_map)

        return recon, mu_map

    def __len__(self):
        """
        :return: the number of rows in the table, i.e. the number of image pairs
        """
        return len(self.table)


# public API of this module
__all__ = ["MuMapDataset"]

def main(dataset):
    """
    Interactively display all image pairs of a dataset in an OpenCV window.

    For every item, the slices of the reconstruction (left) and the mu map
    (right) are shown side by side and are cycled through automatically.

    Keys:
      * n: go to the next item of the dataset
      * q: quit the program
      * p: pause/resume the automatic cycling (while paused any key advances)
      * right arrow: go to the next slice
      * left arrow: go to the previous slice
      * s: save the currently displayed image as `<counter>.png`

    :param dataset: the dataset to visualize, items are pairs of tensors
    """
    import sys

    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    timeout = 100
    # counter for saved images, shared by all items so that images saved
    # for one item are not overwritten by images saved for a later one
    running = 0

    def to_display_image(image, _slice):
        """Render one slice as a 1024x1024 grayscale image with a slice label."""
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # interpolation has to be given by keyword because the third
        # positional parameter of cv.resize is the destination image
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        """Put the display images of two volumes next to each other."""
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        # light gray vertical bar which separates both images
        space = np.full((image_1.shape[0], 10), 239, np.uint8)
        return np.hstack((image_1, space, image_2))

    for i in range(len(dataset)):
        ir = 0
        im = 0

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")

        # show the first slice before the cycling starts
        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        cv.waitKey(timeout)

        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)
            key = cv.waitKey(timeout)

            if key == ord("n"):
                break
            elif key == ord("q"):
                sys.exit(0)
            elif key == ord("p"):
                # a timeout of 0 makes waitKey block until a key is pressed
                timeout = 0 if timeout > 0 else 100
            elif key == 83:  # right arrow key
                continue
            elif key == 81:  # left arrow key
                # step back by two because the loop advances by one again;
                # wrap around so that the first slice stays reachable
                ir = (ir - 2) % recon.shape[0]
                im = (im - 2) % mu_map.shape[0]
            elif key == ord("s"):
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1


if __name__ == "__main__":
    # Command line entry point: build a dataset from the arguments and
    # open the interactive viewer for it.
    import argparse

    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Visualize the images of a MuMapDataset",
    )
    parser.add_argument(
        "dataset_dir",
        help="the directory from which the dataset is loaded",
        type=str,
    )
    parser.add_argument(
        "--split",
        help="choose the split of the data for the dataset",
        choices=["train", "validation", "test"],
        type=str,
    )
    # switches which each disable one pre-processing step of the dataset
    for flag, help_text in (
        (
            "--unaligned",
            "do not perform center alignment of reconstruction an mu-map slices",
        ),
        ("--show_bed", "do not remove the bed contour from the mu map"),
        ("--full_mu_map", "do not remove broken slices of the mu map"),
    ):
        parser.add_argument(flag, action="store_true", help=help_text)
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    main(
        MuMapDataset(
            args.dataset_dir,
            align=not args.unaligned,
            discard_mu_map_slices=not args.full_mu_map,
            bed_contours_file=(
                None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
            ),
            split_name=args.split,
            logger=get_logger_by_args(args),
        )
    )