datasets.py
    import os
    
    import pandas as pd
    import pydicom
    
    import numpy as np
    
    from torch.utils.data import Dataset
    
    
    
    # column names in the meta CSV file that flag whether the first/last slice of
    # a μ-map should be discarded
    HEADER_DISC_FIRST = "discard_first"
    HEADER_DISC_LAST = "discard_last"
    
    
    def discard_slices(row, μ_map):
        """
        Discard slices based on the flags in the corresponding row of the meta table.
        The row is expected to contain the flags 'discard_first' and 'discard_last'.

        :param row: the row of the meta configuration file of a dataset
        :param μ_map: the μ-map whose slices may be discarded
        :return: the μ-map with the flagged slices removed
        """
        _res = μ_map
    
        if row[HEADER_DISC_FIRST]:
            _res = _res[1:]
    
        if row[HEADER_DISC_LAST]:
            _res = _res[:-1]
    
        return _res
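
    # A minimal usage sketch (the values below are made up for illustration): a row
    # whose 'discard_first' flag is set removes exactly one slice from the front of
    # the volume along the first axis, e.g.
    #
    #   row = pd.Series({HEADER_DISC_FIRST: True, HEADER_DISC_LAST: False})
    #   μ_map = np.zeros((64, 128, 128))
    #   discard_slices(row, μ_map).shape  # -> (63, 128, 128)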
    
    
    def align_images(image_1: np.ndarray, image_2: np.ndarray):
        """
        Align one image to another on the first axis (z-axis).
        It is assumed that the second image has fewer slices than the first.
        The first image is then shortened so that the centers of both images lie on top of each other.
    
        :param image_1: the image to be aligned
        :param image_2: the image to which image_1 is aligned
        :return: the aligned image_1
        """
        assert (
            image_1.shape[0] > image_2.shape[0]
        ), f"Alignment assumes that image_1 has more slices ({image_1.shape[0]}) than image_2 ({image_2.shape[0]})"
    
        # central slice of image 2
        c_2 = image_2.shape[0] // 2
        # image to the left and right of the center
        left = c_2
        right = image_2.shape[0] - c_2
    
        # central slice of image 1
        c_1 = image_1.shape[0] // 2
        # select center and same amount to the left/right as image_2
        return image_1[(c_1 - left) : (c_1 + right)]
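
    # Worked example (illustrative shapes only): if image_1 has 70 slices and
    # image_2 has 64, then c_2 = 32, left = 32, right = 32 and c_1 = 35, so the
    # function returns image_1[3:67], i.e. 64 slices whose center coincides with
    # the center of image_2.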
    
    
    
    class MuMapDataset(Dataset):
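        """
        Dataset that pairs a non-attenuation-corrected reconstruction with its
        μ-map for every row of a meta CSV file. Both volumes are read from DICOM
        files in the dataset's image directory.
        """
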
        def __init__(
            self,
            dataset_dir: str,
            csv_file: str = "meta.csv",
            images_dir: str = "images",
            discard_μ_map_slices: bool = True,
        ):
            super().__init__()
    
    
            self.dir = dataset_dir
            self.dir_images = os.path.join(dataset_dir, images_dir)
            self.csv_file = os.path.join(dataset_dir, csv_file)

            # read the meta CSV file; its rows reference the DICOM files to load
            self.table = pd.read_csv(self.csv_file)
    
    
            self.discard_μ_map_slices = discard_μ_map_slices
    
        def __getitem__(self, index: int):
            row = self.table.iloc[index]
    
            recon_file = os.path.join(self.dir_images, row["file_recon_no_ac"])
            mu_map_file = os.path.join(self.dir_images, row["file_mu_map"])
    
            recon = pydicom.dcmread(recon_file).pixel_array
            mu_map = pydicom.dcmread(mu_map_file).pixel_array

            if self.discard_μ_map_slices:
                mu_map = discard_slices(row, mu_map)

            recon = align_images(recon, mu_map)

            return recon, mu_map
    
        def __len__(self):
            return len(self.table)
    
    
    __all__ = [MuMapDataset.__name__]
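
    # Hedged usage sketch (not part of the original file): the dataset can be
    # wrapped in a standard torch DataLoader. Since the volumes may differ in
    # slice count between studies, batch_size=1 avoids collate errors, e.g.
    #
    #   from torch.utils.data import DataLoader
    #   loader = DataLoader(MuMapDataset("data/tmp"), batch_size=1, shuffle=True)
    #   for recon, mu_map in loader:
    #       ...  # tensors of shape (1, slices, height, width)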
    
    
    if __name__ == "__main__":
        dataset = MuMapDataset("data/tmp")
        print(f"Images: {len(dataset)}")
    
        import cv2 as cv
    
        wname = "Images"
        cv.namedWindow(wname, cv.WINDOW_NORMAL)
        cv.resizeWindow(wname, 1024, 512)
        space = np.full((128, 10), 239, np.uint8)
    
        def to_grayscale(img: np.ndarray, min_val=None, max_val=None):
            if min_val is None:
                min_val = img.min()
    
            if max_val is None:
                max_val = img.max()
    
            _img = (img - min_val) / (max_val - min_val)
            _img = (_img * 255).astype(np.uint8)
            return _img
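
        # note: this min-max scaling maps [min_val, max_val] to [0, 255]; a
        # constant image (max_val == min_val) would cause a division by zero,
        # which the demo data is assumed not to contain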
    
        for i in range(len(dataset)):
            ir = 0
            im = 0
    
            recon, mu_map = dataset[i]
            print(f"{i+1}/{len(dataset)} - {recon.shape} - {mu_map.shape}")
    
            to_show = np.hstack(
                (
                    to_grayscale(recon[ir], min_val=recon.min(), max_val=recon.max()),
                    space,
                    to_grayscale(mu_map[im], min_val=mu_map.min(), max_val=mu_map.max()),
                )
            )
            cv.imshow(wname, to_show)
            key = cv.waitKey(100)
    
            while True:
                ir = (ir + 1) % recon.shape[0]
                im = (im + 1) % mu_map.shape[0]
    
                to_show = np.hstack(
                    (
                        to_grayscale(recon[ir], min_val=recon.min(), max_val=recon.max()),
                        space,
                        to_grayscale(
                            mu_map[im], min_val=mu_map.min(), max_val=mu_map.max()
                        ),
                    )
                )
                cv.imshow(wname, to_show)
    
                key = cv.waitKey(100)
    
                if key == ord("n"):
                    break
                if key == ord("q"):
                    exit(0)