import os
from typing import List, Optional, Tuple

import cv2 as cv
import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import (
    DEFAULT_BED_CONTOURS_FILENAME,
    load_contours,
    remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images
from mu_map.file.dicom import load_dcm_img
from mu_map.logging import get_logger


class MuMapDataset(Dataset):
    """
    A dataset mapping reconstructions to attenuation maps (mu maps).

    The dataset is lazy: creation is fast because images are only read into
    memory when they are first accessed. Thus, after the first iteration,
    accessing images becomes a lot faster.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Transform = Transform(),
        transform_augmentation: Transform = Transform(),
        logger=None,
    ):
        """
        Create a new mu map dataset.

        Parameters
        ----------
        dataset_dir: str
            directory of the dataset to be loaded
        csv_file: str
            name of the csv file in the dataset directory containing meta
            information (created by mu_map.data.prepare)
        split_file: str
            csv file defining a split of the dataset in train/validation/test
            (created by mu_map.data.split)
        split_name: str, optional
            the name of the split which is loaded
        images_dir: str
            directory under `dataset_dir` containing the actual images in
            DICOM format
        bed_contours_file: str, optional
            json file containing contours around the bed for each mu map
            (see mu_map.data.remove_bed)
        discard_mu_map_slices: bool
            remove defective slices from mu maps (have to be labeled by
            mu_map.data.review_mu_map)
        align: bool
            center align reconstructions and mu maps
        scatter_correction: bool
            use scatter corrected reconstructions
        transform_normalization: Transform
            transform used for image normalization which is applied once when
            the image is loaded
        transform_augmentation: Transform
            transform used for augmentation which is applied every time
            `__getitem__` is called
        logger: Logger, optional
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        self.transform_normalization = transform_normalization
        self.transform_augmentation = transform_augmentation
        self.logger = (
            logger if logger is not None else get_logger(name=MuMapDataset.__name__)
        )

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read the CSV file and from that access the DICOM files
        self.split_name = split_name
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table[headers.id] = self.table[headers.id].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        # caches for images that were already loaded and pre-processed
        self.reconstructions = {}
        self.mu_maps = {}
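
    # Note on the two-stage transform design: `transform_normalization` is
    # applied once in `load_image` and its result is cached, while
    # `transform_augmentation` is re-applied on every `__getitem__` call.
    # A minimal sketch of a compatible transform (hypothetical example, not
    # part of mu_map.dataset.transform):
    #
    #     class ScaleByMax(Transform):
    #         def __call__(self, recon, mu_map):
    #             return recon / recon.max(), mu_map
    #
    #     dataset = MuMapDataset("data/initial", transform_normalization=ScaleByMax())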

    def copy(self, split_name: str, **kwargs) -> "MuMapDataset":
        """
        Create a copy of the dataset and modify parameters.

        Parameters
        ----------
        split_name: str
            the split with which the copy is created
        kwargs:
            modify parameters by name; currently, only `bed_contours_file`
            is supported
        """
        if "bed_contours_file" not in kwargs:
            kwargs["bed_contours_file"] = os.path.basename(self.bed_contours_file)

        return MuMapDataset(
            dataset_dir=self.dir,
            csv_file=os.path.basename(self.csv_file),
            split_file=os.path.basename(self.split_file),
            split_name=split_name,
            images_dir=os.path.basename(self.dir_images),
            bed_contours_file=kwargs["bed_contours_file"],
            discard_mu_map_slices=self.discard_mu_map_slices,
            align=self.align,
            scatter_correction=self.scatter_correction,
            transform_normalization=self.transform_normalization,
            transform_augmentation=self.transform_augmentation,
            logger=self.logger,
        )

    def load_image(self, _id: int):
        """
        Load an image into memory.

        This function also performs all of the pre-processing (discard slices,
        remove bed, alignment, ...). Afterwards, the reconstruction and mu map
        are available from the local dicts: `self.reconstructions` and
        `self.mu_maps`.

        Parameters
        ----------
        _id: int
            the id of the image to be loaded
        """
        row = self.table[self.table[headers.id] == _id].iloc[0]
        _id = row[headers.id]

        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
        mu_map = load_dcm_img(mu_map_file, direction=1)

        recon_file = os.path.join(self.dir_images, row[self.header_recon])
        recon = load_dcm_img(recon_file, direction=1)

        if self.discard_mu_map_slices:
            mu_map, recon = discard_slices(row, mu_map, recon)

        if self.bed_contours:
            if _id in self.bed_contours:
                bed_contour = self.bed_contours[_id]
                mu_map = remove_bed(mu_map, bed_contour)
            else:
                self.logger.warning(f"Could not find bed contour for id {_id}")

        if self.align:
            recon, mu_map = align_images(recon, mu_map)

        # convert both volumes to float tensors with a leading channel dimension
        mu_map = mu_map.astype(np.float32)
        mu_map = torch.from_numpy(mu_map)
        mu_map = mu_map.unsqueeze(dim=0)

        recon = recon.astype(np.float32)
        recon = torch.from_numpy(recon)
        recon = recon.unsqueeze(dim=0)

        recon, mu_map = self.transform_normalization(recon, mu_map)

        self.mu_maps[_id] = mu_map
        self.reconstructions[_id] = recon

    def pre_load_images(self):
        """
        Load all images into memory.
        """
        for _id in self.table[headers.id]:
            self.load_image(_id)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a reconstruction and mu map pair by index.

        This method retrieves the image id at this index and calls
        `get_item_by_id`.

        Parameters
        ----------
        index: int
        """
        row = self.table.iloc[index]
        _id = row[headers.id]
        return self.get_item_by_id(_id)

    def get_item_by_id(self, _id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a reconstruction and mu map pair by their id.

        This method loads the images if they are not yet in memory and applies
        the augmentation transform before returning them.

        Parameters
        ----------
        _id: int

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            a pair of a reconstruction and the according mu map
        """
        if _id not in self.reconstructions:
            self.load_image(_id)

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon, mu_map = self.transform_augmentation(recon, mu_map)
        return recon, mu_map

    def __len__(self) -> int:
        """
        Get the number of elements in this dataset.
        """
        return len(self.table)


__all__ = [MuMapDataset.__name__]
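
# Usage sketch (hypothetical paths; assumes a dataset directory prepared by
# mu_map.data.prepare and a split file created by mu_map.data.split):
#
#     from torch.utils.data import DataLoader
#
#     dataset = MuMapDataset("data/initial", split_name="train")
#     recon, mu_map = dataset[0]  # tensors with a leading channel dimension
#
#     # Batching requires equally sized volumes, e.g. via a PadCropTranform
#     # as the normalization transform (see the __main__ block below).
#     loader = DataLoader(dataset, batch_size=4, shuffle=True)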


def main(dataset: MuMapDataset, ids: Optional[List[int]] = None, paused: bool = False):
    """
    Display reconstructions and mu maps in a dataset.

    Key bindings: [n] next image, [q] quit, [p] toggle pause, [s] save the
    current frame, arrow keys step through slices.

    Parameters
    ----------
    dataset: MuMapDataset
        the dataset of which elements are displayed
    ids: list of int, optional
        only display these ids
    paused: bool
        start display in paused mode
    """
    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    timeout_paused = 0
    timeout_running = 1000 // 15
    timeout = timeout_paused if paused else timeout_running

    def to_display_image(image, _slice):
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
        image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))

        # overlay the color-mapped reconstruction on top of the mu map
        image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
        image_3_1 = image_2.copy()
        image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)

        space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
        return np.hstack((image_1, space, image_3, space, image_2))

    for i in range(len(dataset)):
        ir = 0
        im = 0

        row = dataset.table.iloc[i]
        _id = row[headers.id]
        if ids is not None and _id not in ids:
            continue

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - ID: {_id}", end="\r")

        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        key = cv.waitKey(timeout)

        running = 0
        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)

            key = cv.waitKey(timeout)
            if key == ord("n"):
                break
            elif key == ord("q"):
                exit(0)
            elif key == ord("p"):
                timeout = timeout_paused if timeout > 0 else timeout_running
            elif key == 82:  # up arrow key
                ir = ir - 1
                continue
            elif key == 83:  # right arrow key
                im = im - 1
                continue
            elif key == 81:  # left arrow key
                im = im - 1
                ir = max(ir - 2, 0)
            elif key == 84:  # down arrow key
                ir = ir - 1
                im = max(im - 2, 0)
            elif key == ord("s"):
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1
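
# Programmatic use of the viewer (mirrors the __main__ block below; the
# dataset path and id are hypothetical examples):
#
#     dataset = MuMapDataset("data/initial", split_name="validation")
#     main(dataset, ids=[42], paused=True)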

if __name__ == "__main__":
    import argparse

    from mu_map.dataset.transform import PadCropTranform
    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    parser.add_argument(
        "--ids",
        type=int,
        nargs="*",
        help="only display certain ids",
    )
    parser.add_argument(
        "--paused",
        action="store_true",
        help="start in paused mode",
    )
    parser.add_argument(
        "--pad_crop",
        type=int,
        help="pad/crop images to this size",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)

    transform_normalization = (
        PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
    )

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        transform_normalization=transform_normalization,
        logger=logger,
    )
    main(dataset, args.ids, paused=args.paused)
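
# Example invocation (hypothetical dataset directory; substitute the actual
# module path of this file in your package):
#
#     python -m mu_map.dataset data/initial --split test --pad_crop 32 --paused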