Skip to content
Snippets Groups Projects
datasets.py 3.1 KiB
Newer Older
  • Learn to ignore specific revisions
  • Tamino Huxohl's avatar
    Tamino Huxohl committed
    import os
    
    import pandas as pd
    import pydicom
    
    import numpy as np
    
    Tamino Huxohl's avatar
    Tamino Huxohl committed
    from torch.utils.data import Dataset
    
    
    
    class MuMapDataset(Dataset):
        """Dataset of paired non-attenuation-corrected reconstructions and mu maps.

        The pairs are listed in a CSV table inside ``dataset_dir``. Each row names a
        reconstruction (column ``file_recon_no_ac``) and an attenuation map
        (column ``file_mu_map``), both DICOM files located in ``images_dir``.
        """

        def __init__(
            self, dataset_dir: str, csv_file: str = "meta.csv", images_dir: str = "images"
        ):
            """Read the metadata table of the dataset.

            :param dataset_dir: root directory of the dataset
            :param csv_file: name of the CSV metadata file, relative to ``dataset_dir``
            :param images_dir: name of the directory holding the DICOM files,
                relative to ``dataset_dir``
            """
            super().__init__()

            self.dir = dataset_dir
            self.dir_images = os.path.join(dataset_dir, images_dir)
            self.csv_file = os.path.join(dataset_dir, csv_file)

            # The table indexes the dataset: every row references one pair of
            # DICOM files, which are only read in __getitem__.
            self.table = pd.read_csv(self.csv_file)

        def __getitem__(self, index: int):
            """Load the pair at ``index`` and crop the reconstruction to the mu map.

            :param index: positional index into the metadata table
            :return: tuple ``(recon, mu_map)`` of DICOM pixel arrays, where ``recon``
                is cropped along its first axis to the number of slices of ``mu_map``
            """
            row = self.table.iloc[index]

            recon_file = os.path.join(self.dir_images, row["file_recon_no_ac"])
            mu_map_file = os.path.join(self.dir_images, row["file_mu_map"])

            recon = pydicom.dcmread(recon_file).pixel_array
            mu_map = pydicom.dcmread(mu_map_file).pixel_array
            return self.align(recon, mu_map)

        def __len__(self):
            """Return the number of pairs listed in the metadata table."""
            return len(self.table)

        def align(self, recon, mu_map):
            """Crop ``recon`` to the central slices spanning the extent of ``mu_map``.

            Slices are taken along axis 0 of both arrays. NOTE(review): this assumes
            both volumes share the same center slice; confirm for all acquisitions.

            :param recon: non-attenuation-corrected reconstruction, slices along axis 0
            :param mu_map: attenuation map, slices along axis 0
            :return: tuple ``(recon, mu_map)``. ``recon`` has the same number of
                slices as ``mu_map``; ``mu_map`` is returned unchanged.
            :raises ValueError: if ``recon`` has fewer slices than ``mu_map``
            """
            n_recon, n_mu_map = recon.shape[0], mu_map.shape[0]
            # An explicit exception (instead of ``assert``) keeps the check
            # active under ``python -O``. Equal slice counts are valid: the crop
            # below then returns the full reconstruction.
            if n_recon < n_mu_map:
                raise ValueError(
                    "Alignment requires the NoAC recon to have at least as many slices "
                    f"as the attenuation map, but got {n_recon} < {n_mu_map}"
                )

            # Split the mu map extent around its center slice and cut the same
            # extent around the center slice of the reconstruction.
            left = n_mu_map // 2
            right = n_mu_map - left
            center = n_recon // 2
            return recon[center - left : center + right], mu_map
    
    Tamino Huxohl's avatar
    Tamino Huxohl committed
    
    
    # Names exported by ``from <module> import *``.
    __all__ = ["MuMapDataset"]
    
    
    if __name__ == "__main__":
        # Interactive viewer: cycles through the slices of each pair side by side.
        # Press "n" to advance to the next pair and "q" to quit.
        dataset = MuMapDataset("data/tmp")
        print(f"Images: {len(dataset)}")

        import sys

        import cv2 as cv

        wname = "Images"
        cv.namedWindow(wname, cv.WINDOW_NORMAL)
        cv.resizeWindow(wname, 1024, 512)

        def to_grayscale(img: np.ndarray, min_val=None, max_val=None):
            """Linearly map ``img`` from ``[min_val, max_val]`` to an 8-bit image.

            ``min_val`` and ``max_val`` default to the extrema of ``img``. A zero-width
            range yields an all-zero image instead of dividing by zero.
            """
            if min_val is None:
                min_val = img.min()
            if max_val is None:
                max_val = img.max()

            value_range = max_val - min_val
            if value_range == 0:
                return np.zeros(img.shape, np.uint8)

            _img = (img - min_val) / value_range
            return (_img * 255).astype(np.uint8)

        for i in range(len(dataset)):
            recon, mu_map = dataset[i]
            print(f"{i+1}/{len(dataset)} - {recon.shape} - {mu_map.shape}")

            # Normalize every slice with the extrema of its whole volume, so
            # intensities stay comparable while scrolling. Computed once per pair.
            recon_min, recon_max = recon.min(), recon.max()
            mu_map_min, mu_map_max = mu_map.min(), mu_map.max()
            # The separator takes its height from the slices (instead of a fixed
            # 128) so that np.hstack works for any resolution.
            space = np.full((recon.shape[1], 10), 239, np.uint8)

            ir = 0
            im = 0
            while True:
                to_show = np.hstack(
                    (
                        to_grayscale(recon[ir], min_val=recon_min, max_val=recon_max),
                        space,
                        to_grayscale(mu_map[im], min_val=mu_map_min, max_val=mu_map_max),
                    )
                )
                cv.imshow(wname, to_show)

                # Check the key for every displayed slice, including the first one.
                key = cv.waitKey(100)
                if key == ord("n"):
                    break
                if key == ord("q"):
                    cv.destroyAllWindows()
                    sys.exit(0)

                ir = (ir + 1) % recon.shape[0]
                im = (im + 1) % mu_map.shape[0]

        cv.destroyAllWindows()