import os
import sys
from typing import Optional, Tuple

import cv2 as cv
import pandas as pd
import pydicom
import numpy as np
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import (
    DEFAULT_BED_CONTOURS_FILENAME,
    load_contours,
    remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images, load_dcm_img
from mu_map.logging import get_logger


class MuMapDataset(Dataset):
    """Dataset of ``(reconstruction, mu_map)`` pairs loaded from DICOM files.

    A CSV table inside ``dataset_dir`` lists one row per sample. The columns
    referenced via ``headers`` name the DICOM files (relative to
    ``images_dir``) of the mu map and of the reconstruction. Images are loaded
    lazily on first access, preprocessed, passed through
    ``transform_normalization`` and cached in memory.
    ``transform_augmentation`` is applied on every access.

    :param dataset_dir: root directory of the dataset
    :param csv_file: name of the meta CSV file inside ``dataset_dir``
    :param split_file: name of the split CSV file inside ``dataset_dir``
    :param split_name: if given, only the rows of this split are used
    :param images_dir: directory inside ``dataset_dir`` containing the images
    :param bed_contours_file: file inside ``dataset_dir`` with bed contours
        used to remove the bed from mu maps; ``None``/empty disables removal
    :param discard_mu_map_slices: apply ``discard_slices`` to loaded mu maps
    :param align: align reconstruction and mu map with ``align_images``
    :param scatter_correction: use the reconstruction column
        ``headers.file_recon_nac_sc`` instead of ``headers.file_recon_nac_nsc``
    :param transform_normalization: applied once per sample before caching;
        defaults to a new ``Transform()``
    :param transform_augmentation: applied on every access; defaults to a new
        ``Transform()``
    :param logger: logger to use; defaults to ``get_logger()``
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Optional[Transform] = None,
        transform_augmentation: Optional[Transform] = None,
        logger=None,
    ):
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        # create default transforms per instance instead of sharing a single
        # object created once as a default argument
        self.transform_normalization = (
            transform_normalization
            if transform_normalization is not None
            else Transform()
        )
        self.transform_augmentation = (
            transform_augmentation
            if transform_augmentation is not None
            else Transform()
        )
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        # normalize the id column to int so that look-ups by integer id match
        self.table[headers.id] = self.table[headers.id].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        # caches of preprocessed tensors, keyed by sample id
        self.reconstructions = {}
        self.mu_maps = {}

    def load_image(self, _id: int):
        """Load, preprocess and cache the reconstruction and mu map of a sample.

        The results are stored in ``self.reconstructions[_id]`` and
        ``self.mu_maps[_id]`` as float32 tensors with an added leading
        dimension of size 1, after ``transform_normalization`` was applied.

        :param _id: value of the ``headers.id`` column of the sample
        :raises IndexError: if no row of the table has this id
        """
        row = self.table[self.table[headers.id] == _id].iloc[0]
        _id = row[headers.id]

        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
        mu_map = load_dcm_img(mu_map_file)
        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)
        if self.bed_contours:
            if _id in self.bed_contours:
                bed_contour = self.bed_contours[_id]
                mu_map = remove_bed(mu_map, bed_contour)
            else:
                # use the instance logger: a module-level `logger` only
                # exists when this file is executed as a script
                self.logger.warning(f"Could not find bed contour for id {_id}")

        recon_file = os.path.join(self.dir_images, row[self.header_recon])
        recon = load_dcm_img(recon_file)
        if self.align:
            recon, mu_map = align_images(recon, mu_map)

        # float32 tensors with an added leading (channel) dimension
        mu_map = torch.from_numpy(mu_map.astype(np.float32)).unsqueeze(dim=0)
        recon = torch.from_numpy(recon.astype(np.float32)).unsqueeze(dim=0)

        recon, mu_map = self.transform_normalization(recon, mu_map)

        self.mu_maps[_id] = mu_map
        self.reconstructions[_id] = recon

    def __getitem__(self, index: int):
        """Return the ``(recon, mu_map)`` pair of the ``index``-th table row."""
        row = self.table.iloc[index]
        _id = row[headers.id]
        return self.getitem_by_id(_id)

    def getitem_by_id(self, _id: int):
        """Return the ``(recon, mu_map)`` pair of the sample with id ``_id``.

        The images are loaded and cached on first access;
        ``transform_augmentation`` is applied to the cached tensors on every
        call.
        """
        if _id not in self.reconstructions:
            self.load_image(_id)

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        # NOTE(review): if an augmentation modifies its inputs in place, the
        # cached tensors change as well - confirm the transforms are pure
        recon, mu_map = self.transform_augmentation(recon, mu_map)
        return recon, mu_map

    def __len__(self):
        return len(self.table)


__all__ = [MuMapDataset.__name__]


def main(dataset, ids, paused=False):
    """Interactively display the samples of a dataset with OpenCV.

    For each sample, the reconstruction (left), an overlay of the mu map with
    the color-mapped reconstruction (middle) and the mu map (right) are shown,
    and the displayed slices advance automatically.

    Keys:
      n - next sample, q - quit, p - toggle pause,
      s - save the current view as ``NNN.png`` (numbered across samples),
      up - hold the reconstruction, right - hold the mu map,
      left - step the reconstruction back, down - step the mu map back.

    :param dataset: a ``MuMapDataset``
    :param ids: if not ``None``, only samples with these ids are displayed
    :param paused: start in paused mode (wait for a key press between slices)
    """
    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    TIMEOUT_PAUSED = 0
    TIMEOUT_RUNNING = 1000 // 15
    timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING

    # arrow key codes as returned by cv.waitKey
    # NOTE(review): these values depend on the GUI backend - confirm per platform
    KEY_LEFT, KEY_UP, KEY_RIGHT, KEY_DOWN = 81, 82, 83, 84

    def to_display_image(image, _slice):
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # the third positional argument of cv.resize is `dst`, so the
        # interpolation flag must be passed by keyword to take effect
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])

        # grayscale -> 3 channels so both can be stacked next to the overlay
        image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
        image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))

        # overlay: second image blended with the color-mapped first image
        image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
        image_3_1 = image_2.copy()
        image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)

        space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
        return np.hstack((image_1, space, image_3, space, image_2))

    def next_slice(index, step, num_slices):
        # forward steps wrap around, backward steps stop at the first slice
        if step < 0:
            return max(index + step, 0)
        return (index + step) % num_slices

    ids = set(ids) if ids is not None else None
    # screenshot counter; kept across samples so that files are not overwritten
    running = 0
    for i in range(len(dataset)):
        row = dataset.table.iloc[i]
        _id = row[headers.id]
        if ids is not None and _id not in ids:
            continue

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - ID: {_id}", end="\r")

        ir = 0
        im = 0
        while True:
            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)
            key = cv.waitKey(timeout)

            if key == ord("n"):
                break
            elif key == ord("q"):
                sys.exit(0)
            elif key == ord("p"):
                timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
            elif key == ord("s"):
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1

            # by default both volumes advance by one slice
            step_r, step_m = 1, 1
            if key == KEY_UP:  # hold the reconstruction
                step_r = 0
            elif key == KEY_RIGHT:  # hold the mu map
                step_m = 0
            elif key == KEY_LEFT:  # hold the mu map, step the reconstruction back
                step_r, step_m = -1, 0
            elif key == KEY_DOWN:  # hold the reconstruction, step the mu map back
                step_r, step_m = 0, -1
            ir = next_slice(ir, step_r, recon.shape[0])
            im = next_slice(im, step_m, mu_map.shape[0])


if __name__ == "__main__":
    import argparse

    from mu_map.dataset.transform import PadCropTranform
    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    parser.add_argument(
        "--ids",
        type=int,
        nargs="*",
        help="only display certain ids",
    )
    parser.add_argument(
        "--paused",
        action="store_true",
        help="start in paused mode",
    )
    parser.add_argument(
        "--pad_crop",
        type=int,
        help="pad crop images to this size",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)
    transform_normalization = (
        PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
    )

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        transform_normalization=transform_normalization,
        logger=logger,
    )
    main(dataset, args.ids, paused=args.paused)