"""PyTorch dataset of reconstruction and attenuation map (mu map) pairs
loaded from DICOM files."""
import os
from typing import Optional

import cv2 as cv
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import (
    DEFAULT_BED_CONTOURS_FILENAME,
    load_contours,
    remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images, load_dcm_img
from mu_map.logging import get_logger


class MuMapDataset(Dataset):
    """Dataset of reconstructions and their corresponding attenuation
    maps (mu maps).

    Image pairs are loaded lazily from DICOM files listed in a CSV file
    and cached in memory after the first access.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Transform = Transform(),
        transform_augmentation: Transform = Transform(),
        logger=None,
    ):
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        self.transform_normalization = transform_normalization
        self.transform_augmentation = transform_augmentation
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read the CSV file and from that access the DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table[headers.id] = self.table[headers.id].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        # caches of already loaded and pre-processed image pairs, keyed by id
        self.reconstructions = {}
        self.mu_maps = {}

    def load_image(self, _id: int):
        """Load, pre-process and cache the reconstruction/mu-map pair for an id."""
        row = self.table[self.table[headers.id] == _id].iloc[0]
        _id = row[headers.id]

        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
        mu_map = load_dcm_img(mu_map_file)
        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)
        if self.bed_contours:
            if _id in self.bed_contours:
                bed_contour = self.bed_contours[_id]
                mu_map = remove_bed(mu_map, bed_contour)
            else:
                self.logger.warning(f"Could not find bed contour for id {_id}")

        recon_file = os.path.join(self.dir_images, row[self.header_recon])
        recon = load_dcm_img(recon_file)
        if self.align:
            recon, mu_map = align_images(recon, mu_map)

        # convert both volumes to single-channel float tensors
        mu_map = mu_map.astype(np.float32)
        mu_map = torch.from_numpy(mu_map)
        mu_map = mu_map.unsqueeze(dim=0)

        recon = recon.astype(np.float32)
        recon = torch.from_numpy(recon)
        recon = recon.unsqueeze(dim=0)

        recon, mu_map = self.transform_normalization(recon, mu_map)

        self.mu_maps[_id] = mu_map
        self.reconstructions[_id] = recon

    def __getitem__(self, index: int):
        row = self.table.iloc[index]
        _id = row[headers.id]
        return self.getitem_by_id(_id)

    def getitem_by_id(self, _id: int):
        """Get a reconstruction/mu-map pair by its id instead of its index."""
        if _id not in self.reconstructions:
            self.load_image(_id)

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon, mu_map = self.transform_augmentation(recon, mu_map)
        return recon, mu_map

    def __len__(self):
        return len(self.table)


__all__ = [MuMapDataset.__name__]
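
# Example usage (a minimal sketch, not part of this module; "data/" is a
# hypothetical dataset directory):
#
#   dataset = MuMapDataset("data/", split_name="train")
#   recon, mu_map = dataset[0]  # float tensors, typically (1, slices, y, x)
#
# Batching with a torch DataLoader additionally requires all volumes to share
# one shape, e.g. by passing a PadCropTranform as transform_normalization.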

def main(dataset, ids, paused=False):
    """Visualize the images of a MuMapDataset slice by slice.

    Key bindings: [n] next image pair, [p] toggle pause, [q] quit,
    [s] save the current view as a PNG; the arrow keys step through
    slices (key codes 81-84, as reported by OpenCV on some platforms).
    """
    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    TIMEOUT_PAUSED = 0
    TIMEOUT_RUNNING = 1000 // 15
    timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING

    def to_display_image(image, _slice):
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        # expand both grayscale images to three channels
        image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
        image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))
        # overlay: the color-mapped reconstruction blended over the mu map
        image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
        image_3_1 = image_2.copy()
        image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)

        space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
        return np.hstack((image_1, space, image_3, space, image_2))

    for i in range(len(dataset)):
        ir = 0
        im = 0

        row = dataset.table.iloc[i]
        _id = row[headers.id]
        if ids is not None and _id not in ids:
            continue

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - ID: {_id}", end="\r")

        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        key = cv.waitKey(timeout)

        running = 0
        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)

            key = cv.waitKey(timeout)
            if key == ord("n"):
                break
            elif key == ord("q"):
                exit(0)
            elif key == ord("p"):
                timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
            elif key == 82:  # up arrow key: advance the mu-map slice only
                ir = ir - 1
                continue
            elif key == 83:  # right arrow key: advance the reconstruction slice only
                im = im - 1
                continue
            elif key == 81:  # left arrow key: step the reconstruction slice back
                im = im - 1
                ir = max(ir - 2, 0)
            elif key == 84:  # down arrow key: step the mu-map slice back
                ir = ir - 1
                im = max(im - 2, 0)
            elif key == ord("s"):
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1
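
# A minimal sketch of launching the viewer from Python instead of via the CLI
# below ("data/" is a hypothetical dataset directory):
#
#   dataset = MuMapDataset("data/", split_name="validation")
#   main(dataset, ids=None, paused=True)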

if __name__ == "__main__":
    import argparse

    from mu_map.dataset.transform import PadCropTranform
    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    parser.add_argument(
        "--ids",
        type=int,
        nargs="*",
        help="only display certain ids",
    )
    parser.add_argument(
        "--paused",
        action="store_true",
        help="start in paused mode",
    )
    parser.add_argument(
        "--pad_crop",
        type=int,
        help="pad/crop images to this size",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)

    transform_normalization = (
        PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
    )

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        transform_normalization=transform_normalization,
        logger=logger,
    )
    main(dataset, args.ids, paused=args.paused)
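
# Example invocation (a sketch; "data/" and the id 42 are made-up values, and
# the script name depends on where this file lives in the project):
#
#   python <this_file>.py data/ --split validation --ids 42 --paused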