    # datasets.py
    # (blame annotation) committed by Tamino Huxohl
    import os
    
    from typing import Optional
    
    import cv2 as cv
    
    # (blame annotation) committed by Tamino Huxohl
    import pandas as pd
    import pydicom
    
    import numpy as np
    
    # (blame annotation) committed by Tamino Huxohl
    from torch.utils.data import Dataset
    
    
    from mu_map.data.prepare import headers
    
    from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
    
    from mu_map.data.review_mu_map import discard_slices
    
    def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
        """
        Align one image to another on the first axis (z-axis).

        It is assumed that the second image does not have more slices than the first.
        The first image is cropped along the first axis so that the central slices of
        both images lie on top of each other and both have the same number of slices.
        If both images already have the same number of slices, image_1 is returned
        uncropped (as a full slice/view).

        :param image_1: the image to be aligned
        :param image_2: the image to which image_1 is aligned
        :return: the aligned image_1 (a slice of image_1 with image_2.shape[0] slices)
        :raises AssertionError: if image_1 has fewer slices than image_2
        """
        n_slices_1 = image_1.shape[0]
        n_slices_2 = image_2.shape[0]
        assert n_slices_1 >= n_slices_2, (
            "Alignment is based on the fact that image_1 has at least as many slices "
            f"({n_slices_1}) as image_2 ({n_slices_2})"
        )

        # Place the central slice of image_2 on the central slice of image_1:
        # image_2 has (n_slices_2 // 2) slices before its center, so the crop of
        # image_1 starts that many slices before the center of image_1.
        start = n_slices_1 // 2 - n_slices_2 // 2
        return image_1[start : start + n_slices_2]
    
    
    
    class MuMapDataset(Dataset):
        """
        Dataset of (reconstruction, mu map) image pairs.

        The pairs are listed in a CSV file inside the dataset directory. All images
        are read from DICOM files once, when the dataset is created, and are kept in
        memory so that item access is a plain dictionary lookup.
        """

        def __init__(
            self,
            dataset_dir: str,
            csv_file: str = "meta.csv",
            images_dir: str = "images",
            bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
            discard_mu_map_slices: bool = True,
            align: bool = True,
        ):
            """
            Create the dataset and pre-load all of its images.

            :param dataset_dir: directory containing the CSV file and the images directory
            :param csv_file: name of the CSV file (relative to dataset_dir) listing the images
            :param images_dir: name of the directory (relative to dataset_dir) containing the DICOM files
            :param bed_contours_file: name of the bed contours file (relative to dataset_dir);
                                      if None (or empty), the bed is not removed from the mu maps
            :param discard_mu_map_slices: if True, mu maps are passed through discard_slices
            :param align: if True, reconstructions are center-aligned to their mu map (see align_images)
            """
            super().__init__()

            self.dir = dataset_dir
            self.dir_images = os.path.join(dataset_dir, images_dir)
            self.csv_file = os.path.join(dataset_dir, csv_file)

            self.bed_contours_file = (
                os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
            )
            self.bed_contours = (
                load_contours(self.bed_contours_file) if bed_contours_file else None
            )

            # read CSV file and from that access DICOM files
            self.table = pd.read_csv(self.csv_file)
            self.table["id"] = self.table["id"].apply(int)

            self.discard_mu_map_slices = discard_mu_map_slices
            self.align = align

            # images are cached by the value of the "id" column
            self.reconstructions = {}
            self.mu_maps = {}
            self.pre_load_images()

        def pre_load_images(self):
            """
            Read all images listed in the table and cache them by id.

            For every row, the mu map is read, optionally reduced via discard_slices
            and optionally has the bed contour filled with 0 in every slice. The
            reconstruction is read afterwards and optionally aligned to the (already
            processed) mu map.
            """
            print("Pre-loading images ...", end="\r")

            for index in range(len(self.table)):
                row = self.table.iloc[index]
                _id = row["id"]

                mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
                mu_map = pydicom.dcmread(mu_map_file).pixel_array
                if self.discard_mu_map_slices:
                    mu_map = discard_slices(row, mu_map)
                if self.bed_contours:
                    bed_contour = self.bed_contours[_id]
                    # thickness -1 fills the contour, i.e. the bed region is set to 0
                    for i_slice in range(mu_map.shape[0]):
                        mu_map[i_slice] = cv.drawContours(
                            mu_map[i_slice], [bed_contour], -1, 0.0, -1
                        )
                self.mu_maps[_id] = mu_map

                recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
                recon = pydicom.dcmread(recon_file).pixel_array
                if self.align:
                    recon = align_images(recon, mu_map)
                self.reconstructions[_id] = recon

            print("Pre-loading images done!")

        def __getitem__(self, index: int):
            """
            Return the pre-loaded image pair for a row of the table.

            :param index: positional index of the row in the table
            :return: a tuple (reconstruction, mu map)
            """
            _id = self.table.iloc[index]["id"]
            return self.reconstructions[_id], self.mu_maps[_id]

        def __len__(self):
            """
            :return: the number of rows in the table, i.e., the number of image pairs
            """
            return len(self.table)
    
    
    # (blame annotation) committed by Tamino Huxohl
    
    # public API of this module
    __all__ = ["MuMapDataset"]
    
    
    if __name__ == "__main__":
        # Interactive viewer: shows the reconstruction and the mu map of every
        # dataset entry side by side and cycles through their slices.
        #
        # Keys: n = next dataset entry, q = quit, p = pause/resume,
        #       right/left arrow = step one slice forward/backward,
        #       s = save the currently displayed image as a PNG file.
        import argparse

        from mu_map.util import to_grayscale, COLOR_WHITE

        parser = argparse.ArgumentParser(
            description="Visualize the images of a MuMapDataset",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "dataset_dir", type=str, help="the directory from which the dataset is loaded"
        )
        parser.add_argument(
            "--unaligned",
            action="store_true",
            help="do not perform center alignment of reconstruction an mu-map slices",
        )
        parser.add_argument(
            "--show_bed",
            action="store_true",
            help="do not remove the bed contour from the mu map",
        )
        parser.add_argument(
            "--full_mu_map",
            action="store_true",
            help="do not remove broken slices of the mu map",
        )
        args = parser.parse_args()

        # the command line flags are negations of the dataset parameters
        align = not args.unaligned
        discard_mu_map_slices = not args.full_mu_map
        bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME

        dataset = MuMapDataset(
            args.dataset_dir,
            align=align,
            discard_mu_map_slices=discard_mu_map_slices,
            bed_contours_file=bed_contours_file,
        )

        wname = "Dataset"
        cv.namedWindow(wname, cv.WINDOW_NORMAL)
        cv.resizeWindow(wname, 1600, 900)

        # time (ms) to wait for a key before advancing; 0 blocks until a key is pressed
        timeout = 100
        # counter used to name the images saved with the "s" key
        running = 0

        def to_display_image(image, _slice):
            """
            Convert one slice of an image to a 1024x1024 grayscale image labeled with
            the slice index.

            :param image: the image (first axis indexes the slices)
            :param _slice: index of the slice to display
            :return: the image to display
            """
            _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
            # interpolation must be passed by keyword: the third positional
            # parameter of cv.resize is dst, not the interpolation flag
            _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
            _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
            _image = cv.putText(
                _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
            )
            return _image

        def combine_images(images, slices):
            """
            Put the selected slices of two images next to each other.

            :param images: pair of images
            :param slices: pair of slice indices, one per image
            :return: both display images stacked horizontally with a 10 pixel wide separator
            """
            image_1 = to_display_image(images[0], slices[0])
            image_2 = to_display_image(images[1], slices[1])
            space = np.full((image_1.shape[0], 10), 239, np.uint8)
            return np.hstack((image_1, space, image_2))

        for i in range(len(dataset)):
            ir = 0
            im = 0

            recon, mu_map = dataset[i]

            print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")

            cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
            key = cv.waitKey(timeout)

            while True:
                # every iteration advances both images by one slice (wrapping around)
                ir = (ir + 1) % recon.shape[0]
                im = (im + 1) % mu_map.shape[0]

                to_show = combine_images((recon, mu_map), (ir, im))
                cv.imshow(wname, to_show)

                key = cv.waitKey(timeout)

                if key == ord("n"):
                    break
                elif key == ord("q"):
                    # NOTE(review): the body of this branch was missing in the
                    # reviewed source; quitting the viewer is assumed - confirm
                    raise SystemExit(0)
                elif key == ord("p"):
                    # toggle between blocking (paused) and automatic playback
                    timeout = 0 if timeout > 0 else 100
                elif key == 83:  # right arrow key
                    # NOTE(review): the body of this branch was missing in the
                    # reviewed source; the loop head already steps one slice
                    # forward, so a no-op is assumed - confirm
                    continue
                elif key == 81:  # left arrow key
                    # go back two slices because the loop head advances by one
                    ir = max(ir - 2, 0)
                    im = max(im - 2, 0)
                elif key == ord("s"):
                    cv.imwrite(f"{running:03d}.png", to_show)
                    running += 1