"""
Dataset of reconstructions and their mu maps which are read from DICOM files
listed in a CSV file, plus a small script to visualize the dataset.
"""
import os
from typing import Optional, Tuple

import cv2 as cv
import pandas as pd
import pydicom
import numpy as np
import torch
from torch.utils.data import Dataset

# NOTE(review): the module paths of the following three project imports were
# missing in the reviewed source (`from import headers`, ...). The paths below
# are reconstructed and must be confirmed against the package layout.
from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices


def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).

    It is assumed that the second image does not have more slices than the
    first. The first image is shortened in a way that the centers of both
    images lie on top of each other. If both images have the same number of
    slices, all slices of the first image are returned.

    :param image_1: the image to be aligned
    :param image_2: the image to which image_1 is aligned
    :return: the aligned image_1 (a slice of image_1 along the first axis)
    :raises AssertionError: if image_1 has fewer slices than image_2
    """
    assert image_1.shape[0] >= image_2.shape[0], (
        "Alignment is based on the fact that image 1 has at least as many slices "
        f"{image_1.shape[0]} as image_2 {image_2.shape[0]}"
    )

    # central slice of image 2
    c_2 = image_2.shape[0] // 2
    # number of slices to the left and right of the center
    left = c_2
    right = image_2.shape[0] - c_2

    # central slice of image 1
    c_1 = image_1.shape[0] // 2

    # select center and same amount to the left/right as image_2
    return image_1[(c_1 - left) : (c_1 + right)]


class MuMapDataset(Dataset):
    """
    Dataset of (reconstruction, mu map) pairs.

    The rows of a CSV file name the DICOM files of a reconstruction and a
    mu map. All images are loaded into memory when the dataset is created
    and are served as float32 tensors with an additional leading dimension.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
    ):
        """
        Create the dataset and pre-load all of its images.

        :param dataset_dir: root directory of the dataset
        :param csv_file: name of the CSV file (relative to dataset_dir) listing the images
        :param images_dir: name of the directory (relative to dataset_dir) containing the DICOM files
        :param bed_contours_file: name of the file (relative to dataset_dir) from which bed
                                  contours are loaded, if None (or empty) the bed is not removed
        :param discard_mu_map_slices: if True, `discard_slices` is applied to every mu map
        :param align: if True, the reconstruction is center-aligned to the mu map on the first axis
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align

        # images are cached by the id of their row in the CSV file
        self.reconstructions = {}
        self.mu_maps = {}
        self.pre_load_images()

    def pre_load_images(self):
        """
        Read, pre-process and cache the reconstruction and mu map of every row.

        Mu maps optionally get slices discarded and the bed contour filled
        with zeros. Reconstructions are optionally aligned to their mu map.
        Both are stored as float32 tensors with a leading dimension of size 1.
        """
        print("Pre-loading images ...", end="\r")
        for index in range(len(self.table)):
            row = self.table.iloc[index]
            _id = row["id"]

            mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
            mu_map = pydicom.dcmread(mu_map_file).pixel_array
            if self.discard_mu_map_slices:
                # presumably removes broken slices of the mu map - see `discard_slices`
                mu_map = discard_slices(row, mu_map)
            if self.bed_contours:
                bed_contour = self.bed_contours[_id]
                # a negative thickness fills the contour, here with the value 0
                for z in range(mu_map.shape[0]):
                    mu_map[z] = cv.drawContours(mu_map[z], [bed_contour], -1, 0.0, -1)

            recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
            recon = pydicom.dcmread(recon_file).pixel_array
            if self.align:
                recon = align_images(recon, mu_map)

            mu_map = torch.from_numpy(mu_map.astype(np.float32)).unsqueeze(dim=0)
            self.mu_maps[_id] = mu_map

            recon = torch.from_numpy(recon.astype(np.float32)).unsqueeze(dim=0)
            self.reconstructions[_id] = recon
        print("Pre-loading images done!")

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get the pre-loaded images of a row of the CSV file.

        :param index: positional index of the row in the table
        :return: the reconstruction and the mu map as tensors
        """
        _id = self.table.iloc[index]["id"]
        return self.reconstructions[_id], self.mu_maps[_id]

    def __len__(self) -> int:
        """
        :return: the number of rows in the CSV file
        """
        return len(self.table)


__all__ = [MuMapDataset.__name__]

if __name__ == "__main__":
    import argparse

    from mu_map.util import to_grayscale, COLOR_WHITE

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction an mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
    )

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)
    # milliseconds to wait between slices, 0 means wait for a key press (paused)
    timeout = 100

    def to_display_image(image, _slice):
        """
        Render one slice of an image as a 1024x1024 grayscale image with a slice label.

        :param image: the image (stack of slices)
        :param _slice: index of the slice to render
        :return: the rendered slice
        """
        # normalize with the value range of the whole image, not just the slice
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # interpolation has to be a keyword argument, the third positional
        # parameter of cv.resize is the destination image
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        """
        Render a slice of two images next to each other, separated by a light gray bar.

        :param images: pair of images
        :param slices: pair of slice indices, one for each image
        :return: the combined image
        """
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        space = np.full((image_1.shape[0], 10), 239, np.uint8)
        return np.hstack((image_1, space, image_2))

    # number of saved images, counted over all dataset entries so that
    # saved files do not overwrite each other
    running = 0
    for i in range(len(dataset)):
        ir = 0
        im = 0

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")

        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        cv.waitKey(timeout)

        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)

            key = cv.waitKey(timeout)
            if key == ord("n"):
                # next image of the dataset
                break
            elif key == ord("q"):
                raise SystemExit(0)
            elif key == ord("p"):
                # toggle pause
                timeout = 0 if timeout > 0 else 100
            elif key == 83:  # right arrow key
                continue
            elif key == 81:  # left arrow key
                # step back by two as the loop advances by one again
                ir = max(ir - 2, 0)
                im = max(im - 2, 0)
            elif key == ord("s"):
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1