import os

import numpy as np
import pandas as pd
import pydicom
from torch.utils.data import Dataset


class MuMapDataset(Dataset):
    """Pairs of non-attenuation-corrected reconstructions and attenuation maps.

    A CSV meta table (one row per sample) names two DICOM files per sample in
    the columns ``file_recon_no_ac`` and ``file_mu_map``. Both are resolved
    relative to an images directory inside ``dataset_dir`` and read on access.
    """

    def __init__(
        self, dataset_dir: str, csv_file: str = "meta.csv", images_dir: str = "images"
    ):
        """
        :param dataset_dir: root directory of the dataset
        :param csv_file: name of the meta CSV file inside ``dataset_dir``
        :param images_dir: name of the directory inside ``dataset_dir`` that
            holds the DICOM files referenced by the meta table
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)

        # read csv file and from that access dicom files
        self.table = pd.read_csv(self.csv_file)

    def __getitem__(self, index: int):
        """Load sample ``index`` and return ``(recon, mu_map)``.

        Both values are the DICOM ``pixel_array`` of the referenced files.
        ``recon`` is cropped along axis 0 by :meth:`align` so that it has the
        same number of slices as ``mu_map``.
        """
        row = self.table.iloc[index]

        recon_file = os.path.join(self.dir_images, row["file_recon_no_ac"])
        mu_map_file = os.path.join(self.dir_images, row["file_mu_map"])

        recon = pydicom.dcmread(recon_file).pixel_array
        mu_map = pydicom.dcmread(mu_map_file).pixel_array

        recon, mu_map = self.align(recon, mu_map)

        return recon, mu_map

    def __len__(self):
        """Return the number of samples, i.e. the number of rows in the meta table."""
        return len(self.table)

    def align(self, recon, mu_map):
        """Crop ``recon`` along axis 0 to ``mu_map``'s slice count, centred.

        The crop window is chosen so that the middle slice index of ``recon``
        lines up with the middle slice index of ``mu_map``. ``mu_map`` is
        returned unchanged.

        :param recon: volume whose axis 0 is treated as slices (presumably
            (slices, rows, cols) — TODO confirm against the DICOM data)
        :param mu_map: volume with the same axis-0 convention
        :return: ``(cropped_recon, mu_map)`` with equal ``shape[0]``
        :raises ValueError: if ``recon`` has fewer slices than ``mu_map``
        """
        n_recon = recon.shape[0]
        n_mu_map = mu_map.shape[0]
        # explicit check instead of ``assert`` (stripped under ``python -O``);
        # without it a negative slice start would silently wrap around
        if n_recon < n_mu_map:
            raise ValueError(
                f"Alignment is based on the fact that the NoAC Recon has at least as "
                f"many slices {n_recon} as the attenuation map {n_mu_map}"
            )

        cm = n_mu_map // 2
        cr = n_recon // 2
        start = cr - cm  # >= 0 because n_recon >= n_mu_map
        recon = recon[start : start + n_mu_map]
        return recon, mu_map


__all__ = [MuMapDataset.__name__]

if __name__ == "__main__":
    # Interactive viewer: cycles through the slices of each sample side by side.
    # Press "n" for the next sample, "q" to quit.
    import sys

    dataset = MuMapDataset("data/tmp")
    print(f"Images: {len(dataset)}")

    import cv2 as cv

    wname = "Images"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1024, 512)

    def to_grayscale(img: np.ndarray, min_val=None, max_val=None) -> np.ndarray:
        """Linearly map ``img`` from ``[min_val, max_val]`` to uint8 ``[0, 255]``.

        ``min_val``/``max_val`` default to the image's own minimum/maximum.
        Values outside the range are clipped. A degenerate range
        (``max_val == min_val``) yields an all-black image.
        """
        if min_val is None:
            min_val = img.min()
        if max_val is None:
            max_val = img.max()

        value_range = float(max_val) - float(min_val)
        if value_range == 0:
            # constant image: 0/0 would give NaN, whose uint8 cast is undefined
            return np.zeros(img.shape, dtype=np.uint8)

        # cast first: subtracting from unsigned pixel data could wrap around
        _img = (img.astype(np.float64) - float(min_val)) / value_range
        _img = np.clip(_img, 0.0, 1.0)
        return (_img * 255).astype(np.uint8)

    try:
        for i in range(len(dataset)):
            recon, mu_map = dataset[i]
            print(f"{i+1}/{len(dataset)} - {recon.shape} - {mu_map.shape}")

            # normalisation bounds are fixed per sample, so compute them once
            recon_min, recon_max = recon.min(), recon.max()
            mu_map_min, mu_map_max = mu_map.min(), mu_map.max()
            # separator column sized to the slice height (was hard-coded to 128)
            space = np.full((recon.shape[1], 10), 239, np.uint8)

            ir = 0
            im = 0
            while True:
                to_show = np.hstack(
                    (
                        to_grayscale(recon[ir], min_val=recon_min, max_val=recon_max),
                        space,
                        to_grayscale(
                            mu_map[im], min_val=mu_map_min, max_val=mu_map_max
                        ),
                    )
                )
                cv.imshow(wname, to_show)

                key = cv.waitKey(100)
                if key == ord("n"):
                    break
                if key == ord("q"):
                    sys.exit(0)

                # advance both volumes independently, wrapping around
                ir = (ir + 1) % recon.shape[0]
                im = (im + 1) % mu_map.shape[0]
    finally:
        cv.destroyAllWindows()