import os
import sys
from typing import Optional

import cv2 as cv
import numpy as np
import pandas as pd
import pydicom
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.logging import get_logger

# Since DICOM images only allow images stored in short integer format, the
# Siemens scanner software multiplies values by a factor before storing so
# that no precision is lost. The scale can be found in this private DICOM tag.
DCM_TAG_PIXEL_SCALE_FACTOR = 0x00331038


def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).

    It is assumed that the second image has at most as many slices as the
    first. The first image is then shortened so that the centers of both
    images lie on top of each other.

    :param image_1: the image to be aligned
    :param image_2: the image to which image_1 is aligned
    :return: the aligned image_1, with as many slices as image_2
    """
    n_1 = image_1.shape[0]
    n_2 = image_2.shape[0]
    assert n_1 >= n_2, (
        f"Alignment requires image_1 to have at least as many slices ({n_1}) "
        f"as image_2 ({n_2})"
    )
    # number of slices of image_2 left of its central slice / from it onwards
    left = n_2 // 2
    right = n_2 - left
    # central slice of image_1
    c_1 = n_1 // 2
    # select the center of image_1 and the same amount to the left/right as image_2
    return image_1[(c_1 - left) : (c_1 + right)]


def _load_scaled_pixel_array(filename: str) -> np.ndarray:
    """
    Read a DICOM file and return its pixel data divided by the scale factor
    stored in the private tag ``DCM_TAG_PIXEL_SCALE_FACTOR``.
    """
    dcm = pydicom.dcmread(filename)
    return dcm.pixel_array / dcm[DCM_TAG_PIXEL_SCALE_FACTOR].value


def _to_tensor(image: np.ndarray) -> torch.Tensor:
    """Convert an image to a float32 tensor with an added leading dimension."""
    return torch.from_numpy(image.astype(np.float32)).unsqueeze(dim=0)


class MuMapDataset(Dataset):
    """
    Dataset of (reconstruction, mu-map) pairs listed in a CSV table.

    All images are read from DICOM files, pre-processed and normalized once
    in the constructor; ``__getitem__`` only applies the augmentation
    transform to the cached tensors.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        transform_normalization: Optional[Transform] = None,
        transform_augmentation: Optional[Transform] = None,
        logger=None,
    ):
        """
        :param dataset_dir: directory containing the CSV files and the image directory
        :param csv_file: CSV table (relative to dataset_dir) listing the image files
        :param split_file: split CSV (relative to dataset_dir) used with split_name
        :param split_name: if given, only the rows of this split are used
        :param images_dir: directory (relative to dataset_dir) containing the DICOM files
        :param bed_contours_file: contours file (relative to dataset_dir) used to erase
            the bed from the mu-maps; a falsy value keeps the bed
        :param discard_mu_map_slices: remove broken mu-map slices via ``discard_slices``
        :param align: center-align each reconstruction to its mu-map along the z-axis
        :param transform_normalization: transform applied once while pre-loading;
            defaults to a new ``Transform()``
        :param transform_augmentation: transform applied on every item access;
            defaults to a new ``Transform()``
        :param logger: logger to use; defaults to ``get_logger()``
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        # Build fresh defaults per instance instead of sharing one Transform
        # object created at function-definition time between all datasets.
        self.transform_normalization = (
            transform_normalization
            if transform_normalization is not None
            else Transform()
        )
        self.transform_augmentation = (
            transform_augmentation
            if transform_augmentation is not None
            else Transform()
        )
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align

        # caches filled by pre_load_images, keyed by the table's "id" column
        self.reconstructions = {}
        self.mu_maps = {}
        self.pre_load_images()

    def _load_mu_map(self, row: pd.Series) -> np.ndarray:
        """
        Read the mu-map of a table row, optionally discard slices and erase
        the bed contour.

        :param row: a row of ``self.table``
        :return: the pre-processed mu-map as a numpy array
        """
        mu_map = _load_scaled_pixel_array(
            os.path.join(self.dir_images, row[headers.file_mu_map])
        )
        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)
        if self.bed_contours:
            bed_contour = self.bed_contours[row["id"]]
            for _slice in range(mu_map.shape[0]):
                # contourIdx -1 draws all given contours; thickness -1 fills
                # them, so the bed region is set to 0.0
                mu_map[_slice] = cv.drawContours(
                    mu_map[_slice], [bed_contour], -1, 0.0, -1
                )
        return mu_map

    def pre_load_images(self):
        """
        Load, pre-process and normalize all images listed in the table.

        The results are stored in ``self.mu_maps`` and
        ``self.reconstructions``, keyed by the row's ``id``, as float32
        tensors with an added leading dimension.
        """
        self.logger.debug("Pre-loading images ...")
        for idx in range(len(self.table)):
            row = self.table.iloc[idx]
            _id = row["id"]

            mu_map = self._load_mu_map(row)

            recon = _load_scaled_pixel_array(
                os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
            )
            if self.align:
                recon = align_images(recon, mu_map)

            recon, mu_map = self.transform_normalization(
                _to_tensor(recon), _to_tensor(mu_map)
            )

            self.mu_maps[_id] = mu_map
            self.reconstructions[_id] = recon
        self.logger.debug("Pre-loading images done!")

    def __getitem__(self, index: int):
        """
        Return the augmented (reconstruction, mu-map) pair of the given row.

        :param index: positional index into the table
        """
        row = self.table.iloc[index]
        _id = row["id"]

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon, mu_map = self.transform_augmentation(recon, mu_map)
        return recon, mu_map

    def __len__(self):
        """Return the number of rows in the (possibly split) table."""
        return len(self.table)


__all__ = [MuMapDataset.__name__]


def main(dataset):
    """
    Interactively display all samples of a dataset.

    The reconstruction (left) and mu-map (right) are shown side by side and
    their slices are cycled automatically. Keys: ``n`` next sample, ``q``
    quit, ``p`` pause/resume cycling, left arrow step back, right arrow step
    forward, ``s`` save the current view as ``NNN.png`` in the working
    directory.

    :param dataset: a dataset returning (reconstruction, mu-map) tensor pairs
    """
    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)
    timeout = 100

    def to_display_image(image, _slice):
        # scale with min/max of the whole volume so that intensities are
        # comparable across slices
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # interpolation must be passed by keyword: the third positional
        # argument of cv.resize is the destination array
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        n_slices = str(image.shape[0])
        _text = f"{str(_slice):>{len(n_slices)}}/{n_slices}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        space = np.full((image_1.shape[0], 10), 239, np.uint8)
        return np.hstack((image_1, space, image_2))

    n_samples = len(dataset)
    # counter for saved screenshots; kept across samples so that later
    # screenshots do not overwrite earlier ones
    running = 0
    for i in range(n_samples):
        ir = 0
        im = 0

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
        print(f"{(i + 1):>{len(str(n_samples))}}/{n_samples}", end="\r")

        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        cv.waitKey(timeout)

        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)

            key = cv.waitKey(timeout)

            if key == ord("n"):
                break
            elif key == ord("q"):
                sys.exit(0)
            elif key == ord("p"):
                timeout = 0 if timeout > 0 else 100
            elif key == 83:  # right arrow key (code as reported by waitKey - TODO confirm per platform)
                # the loop already advances by one slice
                continue
            elif key == 81:  # left arrow key (code as reported by waitKey - TODO confirm per platform)
                # step back by two so that the +1 at the loop start yields -1
                ir = max(ir - 2, 0)
                im = max(im - 2, 0)
            elif key == ord("s"):
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1


if __name__ == "__main__":
    import argparse

    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        logger=logger,
    )
    main(dataset)