diff --git a/mu_map/data/datasets.py b/mu_map/data/datasets.py index f591ce63a75c509fffa683e7e9f05e6f77ca4222..5fd89402cdad67a02d81dc1ddaae58639b17f5c8 100644 --- a/mu_map/data/datasets.py +++ b/mu_map/data/datasets.py @@ -6,9 +6,63 @@ import numpy as np from torch.utils.data import Dataset +HEADER_DISC_FIRST = "discard_first" +HEADER_DISC_LAST = "discard_last" + + +def discard_slices(row, μ_map): + """ + Discard slices based on the flags in the row of th according table. + The row is expected to contain the flags 'discard_first' and 'discard_last'. + + :param row: the row of meta configuration file of a dataset + :param μ_map: the μ_map + :return: the μ_map with according slices removed + """ + _res = μ_map + + if row[HEADER_DISC_FIRST]: + _res = _res[1:] + + if row[HEADER_DISC_LAST]: + _res = _res[:-1] + + return _res + + +def align_images(image_1: np.ndarray, image_2: np.ndarray): + """ + Align one image to another on the first axis (z-axis). + It is assumed that the second image has less slices than the first. + Then, the first image is shortened in a way that the centers of both images lie on top of each other. + + :param image_1: the image to be aligned + :param image_2: the image to which image_1 is aligned + :return: the aligned image_1 + """ + assert ( + image_1.shape[0] > image_2.shape[0] + ), f"Alignment is based on the fact that image 1 has more slices {image_1.shape[0]} than image_2 {image_.shape[0]}" + + # central slice of image 2 + c_2 = image_2.shape[0] // 2 + # image to the left and right of the center + left = c_2 + right = image_2.shape[0] - c_2 + + # central slice of image 1 + c_1 = image_1.shape[0] // 2 + # select center and same amount to the left/right as image_2 + return image_1[(c_1 - left) : (c_1 + right)] + + class MuMapDataset(Dataset): def __init__( - self, dataset_dir: str, csv_file: str = "meta.csv", images_dir: str = "images" + self, + dataset_dir: str, + csv_file: str = "meta.csv", + images_dir: str = "images", + discard_μ_map_slices: bool = True, ): super().__init__() @@ -16,8 +70,10 @@ class MuMapDataset(Dataset): self.dir_images = os.path.join(dataset_dir, images_dir) self.csv_file = os.path.join(dataset_dir, csv_file) + # read CSV file and from that access DICOM files self.table = pd.read_csv(self.csv_file) - # read csv file and from that access dicom files + + self.discard_μ_map_slices = discard_μ_map_slices def __getitem__(self, index: int): row = self.table.iloc[index] @@ -27,25 +83,17 @@ class MuMapDataset(Dataset): recon = pydicom.dcmread(recon_file).pixel_array mu_map = pydicom.dcmread(mu_map_file).pixel_array - recon, mu_map = self.align(recon, mu_map) + + if self.discard_μ_map_slices: + mu_map = discard_slices(row, mu_map) + + recon = align_images(recon, mu_map) return recon, mu_map def __len__(self): return len(self.table) - def align(self, recon, mu_map): - assert recon.shape[0] > mu_map.shape[0], f"Alignment is based on the fact that the NoAC Recon has more slices {recon.shape[0]} than the attenuation map {mu_map.shape[0]}" - - cm = mu_map.shape[0] // 2 - - left = cm - right = mu_map.shape[0] - cm - - cr = recon.shape[0] // 2 - recon = recon[(cr - left):(cr + right)] - return recon, mu_map - __all__ = [MuMapDataset.__name__]