import os
from typing import Optional

import cv2 as cv
import numpy as np
import pandas as pd
import pydicom
from torch.utils.data import Dataset

from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices


def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).
    It is assumed that the second image has fewer slices than the first.
    The first image is then shortened so that the centers of both images
    lie on top of each other.

    :param image_1: the image to be aligned
    :param image_2: the image to which image_1 is aligned
    :return: the aligned image_1
    """
    assert (
        image_1.shape[0] > image_2.shape[0]
    ), f"Alignment is based on the fact that image_1 has more slices {image_1.shape[0]} than image_2 {image_2.shape[0]}"

    # central slice of image 2
    c_2 = image_2.shape[0] // 2
    # number of slices to the left and right of the center
    left = c_2
    right = image_2.shape[0] - c_2

    # central slice of image 1
    c_1 = image_1.shape[0] // 2
    # select the center and the same number of slices to the left/right as image_2
    return image_1[(c_1 - left) : (c_1 + right)]


class MuMapDataset(Dataset):
    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
    ):
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read the CSV file which points to the DICOM files
        self.table = pd.read_csv(self.csv_file)
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices

    def __getitem__(self, index: int):
        row = self.table.iloc[index]

        recon_file = os.path.join(self.dir_images, row["file_recon_no_ac"])
        mu_map_file = os.path.join(self.dir_images, row["file_mu_map"])

        recon = pydicom.dcmread(recon_file).pixel_array
        mu_map = pydicom.dcmread(mu_map_file).pixel_array

        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)

        if self.bed_contours:
            # remove the bed by filling its contour with zeros in every slice
            bed_contour = self.bed_contours[row["id"]]
            for i in range(mu_map.shape[0]):
                mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)

        recon = align_images(recon, mu_map)

        return recon, mu_map

    def __len__(self):
        return len(self.table)


__all__ = [MuMapDataset.__name__]

if __name__ == "__main__":
    dataset = MuMapDataset("data/tmp")

    wname = "Images"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1024, 512)
    space = np.full((128, 10), 239, np.uint8)

    def to_grayscale(img: np.ndarray, min_val=None, max_val=None):
        if min_val is None:
            min_val = img.min()
        if max_val is None:
            max_val = img.max()

        _img = (img - min_val) / (max_val - min_val)
        _img = (_img * 255).astype(np.uint8)
        return _img

    for i in range(len(dataset)):
        ir = 0
        im = 0

        recon, mu_map = dataset[i]
        print(f"{i+1}/{len(dataset)} - {recon.shape} - {mu_map.shape}")

        to_show = np.hstack(
            (
                to_grayscale(recon[ir], min_val=recon.min(), max_val=recon.max()),
                space,
                to_grayscale(mu_map[im], min_val=mu_map.min(), max_val=mu_map.max()),
            )
        )
        cv.imshow(wname, to_show)
        key = cv.waitKey(100)

        # cycle through the slices until the user presses "n" (next) or "q" (quit)
        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = np.hstack(
                (
                    to_grayscale(recon[ir], min_val=recon.min(), max_val=recon.max()),
                    space,
                    to_grayscale(
                        mu_map[im], min_val=mu_map.min(), max_val=mu_map.max()
                    ),
                )
            )
            cv.imshow(wname, to_show)

            key = cv.waitKey(100)
            if key == ord("n"):
                break
            if key == ord("q"):
                exit(0)