Skip to content
Snippets Groups Projects
datasets.py 7.16 KiB
import os
from typing import Optional

import pandas as pd
import pydicom
import numpy as np
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices


def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).

    It is assumed that the second image has at most as many slices as the first.
    The first image is cropped along the first axis so that the central slices of
    both images lie on top of each other. The result has exactly as many slices
    as image_2.

    :param image_1: the image to be aligned
    :param image_2: the image to which image_1 is aligned
    :return: the aligned (cropped) image_1 - a slice view into image_1, not a copy
    :raises AssertionError: if image_1 has fewer slices than image_2
    """
    n_slices_1 = image_1.shape[0]
    n_slices_2 = image_2.shape[0]
    # equal slice counts are fine: the crop below then selects all of image_1
    assert (
        n_slices_1 >= n_slices_2
    ), f"Alignment is based on the fact that image 1 has at least as many slices {n_slices_1} as image 2 {n_slices_2}"

    # central slice of image 2
    c_2 = n_slices_2 // 2
    # number of slices before the center and from the center onwards in image 2
    left = c_2
    right = n_slices_2 - c_2

    # central slice of image 1
    c_1 = n_slices_1 // 2
    # select center and same amount to the left/right as image_2;
    # since n_slices_1 >= n_slices_2 both bounds stay inside [0, n_slices_1]
    return image_1[(c_1 - left) : (c_1 + right)]


class MuMapDataset(Dataset):
    """
    Dataset of (reconstruction, mu map) image pairs.

    The pairs are listed in a CSV file (one row per pair, identified by an integer
    "id" column) and reference DICOM files in an images directory. All images are
    read once on construction and kept in memory.

    :param dataset_dir: the directory from which the dataset is loaded
    :param csv_file: name of the CSV meta file, relative to dataset_dir
    :param images_dir: name of the directory containing the DICOM files, relative to dataset_dir
    :param bed_contours_file: name of the bed contours file, relative to dataset_dir;
                              if falsy (e.g. None), the bed is not removed from the mu maps
    :param discard_mu_map_slices: if slices of the mu maps are discarded using `discard_slices`
    :param align: if reconstructions are center-cropped on the z-axis to the number of
                  slices of their mu map (see `align_images`)
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
    ):
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align

        # pre-loaded images keyed by the "id" column of the table
        self.reconstructions = {}
        self.mu_maps = {}
        self.pre_load_images()

    def pre_load_images(self):
        """
        Read all mu maps and reconstructions listed in the table into memory.

        The images are stored in `self.mu_maps` and `self.reconstructions`, keyed by
        the "id" column of the table.
        """
        print("Pre-loading images ...", end="\r")
        for index in range(len(self.table)):
            row = self.table.iloc[index]
            _id = row["id"]

            mu_map = self._load_mu_map(row)
            self.mu_maps[_id] = mu_map
            # the (processed) mu map is needed to align the reconstruction
            self.reconstructions[_id] = self._load_reconstruction(row, mu_map)
        print("Pre-loading images done!")

    def _load_mu_map(self, row: pd.Series) -> np.ndarray:
        """
        Read the mu map of a table row, discarding slices and removing the bed if configured.

        :param row: a row of the meta table
        :return: the processed mu map
        """
        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
        mu_map = pydicom.dcmread(mu_map_file).pixel_array
        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)
        if self.bed_contours:
            # cv2 is not imported at module level, so import it where it is needed
            import cv2 as cv

            bed_contour = self.bed_contours[row["id"]]
            for _slice in range(mu_map.shape[0]):
                # thickness -1 fills the inside of the contour, i.e. sets the bed to 0
                mu_map[_slice] = cv.drawContours(
                    mu_map[_slice], [bed_contour], -1, 0.0, -1
                )
        return mu_map

    def _load_reconstruction(self, row: pd.Series, mu_map: np.ndarray) -> np.ndarray:
        """
        Read the reconstruction of a table row, aligning it to its mu map if configured.

        :param row: a row of the meta table
        :param mu_map: the processed mu map belonging to the same row
        :return: the (possibly aligned) reconstruction
        """
        recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
        recon = pydicom.dcmread(recon_file).pixel_array
        if self.align:
            recon = align_images(recon, mu_map)
        return recon

    def __getitem__(self, index: int):
        """
        Get the pre-loaded image pair of the table row at position index.

        :param index: positional index of the row in the table
        :return: a tuple of the reconstruction and the mu map
        """
        _id = self.table.iloc[index]["id"]
        return self.reconstructions[_id], self.mu_maps[_id]

    def __len__(self):
        return len(self.table)


__all__ = [MuMapDataset.__name__]

if __name__ == "__main__":
    import argparse
    import sys

    import cv2 as cv

    from mu_map.util import to_grayscale, COLOR_WHITE

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
    )

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    # delay in ms passed to cv.waitKey; 0 blocks until a key is pressed (paused)
    timeout = 100

    def to_display_image(image, _slice):
        """Render one slice of a volume as a 1024x1024 grayscale image labelled with its index."""
        # normalize with the range of the whole volume so slices are comparable
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # the third positional argument of cv.resize is dst, so pass interpolation by keyword
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        """Place two rendered slices side by side, separated by a light gray bar."""
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        space = np.full((image_1.shape[0], 10), 239, np.uint8)
        return np.hstack((image_1, space, image_2))

    for i in range(len(dataset)):
        recon, mu_map = dataset[i]
        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")

        # current slice of the reconstruction / mu map
        ir = 0
        im = 0
        while True:
            cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
            key = cv.waitKey(timeout)

            if key == ord("n"):
                break
            elif key == ord("q"):
                cv.destroyAllWindows()
                sys.exit(0)
            elif key == ord("p"):
                timeout = 0 if timeout > 0 else 100
            elif key == 81:  # left arrow key: go back one slice
                ir = max(ir - 1, 0)
                im = max(im - 1, 0)
                continue

            # a timeout, the right arrow key (83) or any other key advances one slice;
            # both volumes wrap around independently as they may differ in length
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

    cv.destroyAllWindows()