import os
from typing import Tuple

import numpy as np
import pandas as pd
import pydicom
from torch.utils.data import Dataset


class MuMapDataset(Dataset):
    """Dataset yielding pairs of a NoAC reconstruction and an attenuation map.

    The dataset directory is expected to contain a CSV table and a directory
    of DICOM images. Each row of the table names the two DICOM files of one
    sample in the columns ``file_recon_no_ac`` and ``file_mu_map``; both file
    names are resolved relative to the images directory.
    """

    def __init__(
        self, dataset_dir: str, csv_file: str = "meta.csv", images_dir: str = "images"
    ):
        """Read the table describing the dataset.

        Args:
            dataset_dir: Root directory of the dataset.
            csv_file: Name of the CSV table inside ``dataset_dir``.
            images_dir: Name of the DICOM image directory inside
                ``dataset_dir``.
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)

        # The DICOM files themselves are only read on access in __getitem__.
        self.table = pd.read_csv(self.csv_file)

    def __getitem__(self, index: int) -> Tuple[np.ndarray, np.ndarray]:
        """Load the sample stored in row ``index`` of the table.

        Returns:
            The pixel arrays ``(recon, mu_map)``, with the reconstruction
            cropped by :meth:`align` to the slice count of the mu map.
        """
        row = self.table.iloc[index]

        recon_file = os.path.join(self.dir_images, row["file_recon_no_ac"])
        mu_map_file = os.path.join(self.dir_images, row["file_mu_map"])

        recon = pydicom.dcmread(recon_file).pixel_array
        mu_map = pydicom.dcmread(mu_map_file).pixel_array
        recon, mu_map = self.align(recon, mu_map)

        return recon, mu_map

    def __len__(self) -> int:
        """Return the number of samples, i.e. rows of the table."""
        return len(self.table)

    def align(self, recon, mu_map):
        """Crop ``recon`` along axis 0 to the central slices matching ``mu_map``.

        The centers (``shape[0] // 2``) of both volumes are matched and the
        reconstruction is cut to exactly ``mu_map.shape[0]`` slices around it.

        Args:
            recon: Reconstruction volume, slices along axis 0.
            mu_map: Attenuation map volume, slices along axis 0.

        Returns:
            The tuple ``(cropped_recon, mu_map)``; ``mu_map`` is unchanged.

        Raises:
            ValueError: If ``recon`` has fewer slices than ``mu_map``.
        """
        n_recon = recon.shape[0]
        n_mu_map = mu_map.shape[0]

        # Raise explicitly instead of asserting: asserts vanish under ``-O``
        # and the negative slice start below would then silently return a
        # wrongly sized volume.
        if n_recon < n_mu_map:
            raise ValueError(
                "Alignment is based on the fact that the NoAC Recon has at least "
                f"as many slices {n_recon} as the attenuation map {n_mu_map}"
            )

        start = n_recon // 2 - n_mu_map // 2
        return recon[start : start + n_mu_map], mu_map


__all__ = [MuMapDataset.__name__]


def _to_grayscale(img: np.ndarray, min_val=None, max_val=None) -> np.ndarray:
    """Linearly rescale ``img`` from ``[min_val, max_val]`` to 8-bit gray values.

    ``min_val`` and ``max_val`` default to the minimum and maximum of ``img``.
    A constant image (empty value range) is mapped to all zeros instead of
    dividing by zero.
    """
    if min_val is None:
        min_val = img.min()
    if max_val is None:
        max_val = img.max()

    # Compute in float so that unsigned pixel types cannot wrap around.
    value_range = float(max_val) - float(min_val)
    if value_range == 0:
        return np.zeros(img.shape, np.uint8)

    scaled = (img.astype(np.float64) - float(min_val)) / value_range
    # Clipping is a no-op for values inside the range and keeps the uint8
    # cast well defined otherwise.
    return (np.clip(scaled, 0.0, 1.0) * 255).astype(np.uint8)


def main() -> None:
    """Cycle through all samples, showing recon and mu map slices side by side.

    Keys: ``n`` switches to the next sample, ``q`` quits.
    """
    # Imported here so that using the dataset does not require OpenCV.
    import cv2 as cv

    dataset = MuMapDataset("data/tmp")
    total = len(dataset)
    print(f"Images: {total}")

    wname = "Images"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1024, 512)

    for i in range(total):
        recon, mu_map = dataset[i]
        print(f"{i+1}/{total} - {recon.shape} - {mu_map.shape}")

        # Per-volume values that do not change between frames.
        recon_min, recon_max = recon.min(), recon.max()
        mu_map_min, mu_map_max = mu_map.min(), mu_map.max()
        space = np.full((recon.shape[1], 10), 239, np.uint8)

        ir = 0
        im = 0
        while True:
            to_show = np.hstack(
                (
                    _to_grayscale(recon[ir], min_val=recon_min, max_val=recon_max),
                    space,
                    _to_grayscale(mu_map[im], min_val=mu_map_min, max_val=mu_map_max),
                )
            )
            cv.imshow(wname, to_show)

            key = cv.waitKey(100)
            if key == ord("n"):
                break
            if key == ord("q"):
                return

            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]


if __name__ == "__main__":
    main()