default.py
# needed so the `-> MuMapDataset` annotation inside the class resolves
from __future__ import annotations

import os
from typing import List, Optional, Tuple

import cv2 as cv
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import (
    DEFAULT_BED_CONTOURS_FILENAME,
    load_contours,
    remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images
from mu_map.file.dicom import load_dcm_img
from mu_map.logging import get_logger
    
class MuMapDataset(Dataset):
    """
    A dataset mapping reconstructions to attenuation maps (mu maps).

    The dataset is lazy: dataset creation is fast and images are only read
    into memory when they are first accessed. Thus, after the first
    iteration, accessing images becomes a lot faster.
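
    Example (illustrative; the dataset path is a placeholder):
        >>> dataset = MuMapDataset("data/dataset")  # fast: no images read yet
        >>> recon, mu_map = dataset[0]              # first access loads and caches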
        """
    
    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Transform = Transform(),
        transform_augmentation: Transform = Transform(),
        logger=None,
    ):
            """
            Create a new mu map dataset.
    
            Parameters
            ----------
            dataset_dir: str
                directory of the dataset to be loaded
            csv_file: str
                name of the csv file in the dataset directory containing meta information (created by mu_map.data.prepare)
            split_file: str
                csv file defining a split of the dataset in train/validation/test (created by mu_map.data.split)
            split_name: str, optional
                the name of the split which is loaded
            images_dir: str
                directory under `dataset_dir` containing the actual images in DICOM format
            bed_contours_file: str, optional
                json file containing contours around the bed for each mu map (see mu_map.data.remove_bed)
            discard_mu_map_slices: bool
                remove defective slices from mu maps (have to be labeled by mu_map.data.review_mu_map)
            align: bool
                center align reconstructions and mu maps
            scatter_correction: bool
                use scatter corrected reconstructions
            transform_normalization: Transform
                transform used for image normalization which is applied once when the image is loaded
            transform_augmentation: Transform
                transform used for augmentation which is applied every time `__getitem__` is called
        logger: Logger, optional
            logger used by the dataset; if not given, a default one is created
        """
    
        super().__init__()
    
    
            self.dir = dataset_dir
            self.dir_images = os.path.join(dataset_dir, images_dir)
            self.csv_file = os.path.join(dataset_dir, csv_file)
    
            self.split_file = os.path.join(dataset_dir, split_file)
    
    
            self.transform_normalization = transform_normalization
            self.transform_augmentation = transform_augmentation
    
            self.logger = (
                logger if logger is not None else get_logger(name=MuMapDataset.__name__)
            )
    
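        # bed contours map an image id to a contour that is used by
        # mu_map.data.remove_bed to mask out the patient bed from the mu map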
            self.bed_contours_file = (
                os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
            )
            self.bed_contours = (
                load_contours(self.bed_contours_file) if bed_contours_file else None
            )
    
        # read the CSV file through which the DICOM images are accessed
        self.split_name = split_name
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        # make sure ids are ints so that lookups by id work consistently
        self.table[headers.id] = self.table[headers.id].apply(int)
    
        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
        # select the reconstruction column depending on scatter correction
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        # caches mapping an image id to the loaded reconstruction/mu map
        self.reconstructions = {}
        self.mu_maps = {}
    
        def copy(self, split_name: str, **kwargs) -> MuMapDataset:
    
            """
            Create a copy of the dataset and modify parameters.
    
            Parameters
            ----------
        split_name: str
            the name of the split with which the copy is created
    
            kwargs:
                Modify parameters by name.
                Currently, only `bed_contours_file` is supported.
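
        Example (illustrative):
            >>> dataset_val = dataset_train.copy("validation")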
    
        """
        if "bed_contours_file" not in kwargs:
            kwargs["bed_contours_file"] = (
                os.path.basename(self.bed_contours_file)
                if self.bed_contours_file
                else None
            )
    
    
        return MuMapDataset(
            dataset_dir=self.dir,
            csv_file=os.path.basename(self.csv_file),
            split_file=os.path.basename(self.split_file),
            split_name=split_name,
            images_dir=os.path.basename(self.dir_images),
            bed_contours_file=kwargs["bed_contours_file"],
            discard_mu_map_slices=self.discard_mu_map_slices,
            align=self.align,
            scatter_correction=self.scatter_correction,
            transform_normalization=self.transform_normalization,
            transform_augmentation=self.transform_augmentation,
            logger=self.logger,
        )
    
    
        def load_image(self, _id: int):
    
            """
            Load an image into memory.
    
        This function also performs all pre-processing (discarding slices,
        removing the bed, alignment, ...). Afterwards, the reconstruction and
        mu map are available from the caches `self.reconstructions` and
        `self.mu_maps`.
    
            Parameters
            ----------
            _id: int
                the id of the image to be loaded
            """
    
            row = self.table[self.table[headers.id] == _id].iloc[0]
            _id = row[headers.id]
    
            mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
    
            mu_map = load_dcm_img(mu_map_file, direction=1)
    
            recon_file = os.path.join(self.dir_images, row[self.header_recon])
            recon = load_dcm_img(recon_file, direction=1)
    
    
            if self.discard_mu_map_slices:
    
                mu_map, recon = discard_slices(row, mu_map, recon)
    
            if self.bed_contours:
                if _id in self.bed_contours:
                    bed_contour = self.bed_contours[_id]
                    mu_map = remove_bed(mu_map, bed_contour)
                else:
                self.logger.warning(f"Could not find bed contour for id {_id}")
    
            if self.align:
                recon, mu_map = align_images(recon, mu_map)
    
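        # convert the numpy arrays to float32 torch tensors with a channel dimension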
            mu_map = mu_map.astype(np.float32)
            mu_map = torch.from_numpy(mu_map)
            mu_map = mu_map.unsqueeze(dim=0)
    
            recon = recon.astype(np.float32)
            recon = torch.from_numpy(recon)
            recon = recon.unsqueeze(dim=0)
    
            recon, mu_map = self.transform_normalization(recon, mu_map)
    
            self.mu_maps[_id] = mu_map
            self.reconstructions[_id] = recon
    
        def pre_load_images(self):
    
            """
            Load all images into memory.
            """
    
            for _id in self.table[headers.id]:
                self.load_image(_id)
    
    
        def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
            """
        Get a reconstruction and mu map pair by index.

        This method retrieves the image id for this index and calls `get_item_by_id`.

        Parameters
        ----------
        index: int
            index of the image pair in the dataset's table
            """
    
            row = self.table.iloc[index]
    
            _id = row[headers.id]
    
            return self.get_item_by_id(_id)
    
        def get_item_by_id(self, _id: int) -> Tuple[torch.Tensor, torch.Tensor]:
            """
        Get a reconstruction and mu map pair by their id.

        This method loads the images if they are not yet in memory and applies
        the augmentation transform before returning them.

        Parameters
        ----------
        _id: int
            the id of the image pair to retrieve

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            a pair of a reconstruction and the corresponding mu map
            """
    
            if _id not in self.reconstructions:
                self.load_image(_id)
    
    
            recon = self.reconstructions[_id]
            mu_map = self.mu_maps[_id]
    
            recon, mu_map = self.transform_augmentation(recon, mu_map)
    
    
            return recon, mu_map
    
        def __len__(self) -> int:
            """
            Get the number of elements in this dataset.
            """
    
            return len(self.table)
    
    
    
    __all__ = [MuMapDataset.__name__]
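
# Typical downstream usage (illustrative sketch; the path, split name, and
# batch size are placeholders, not defaults of this project):
#
#   from torch.utils.data import DataLoader
#   dataset = MuMapDataset("data/dataset", split_name="train")
#   loader = DataLoader(dataset, batch_size=8, shuffle=True)
#   for recon, mu_map in loader:
#       ...  # train a model predicting the mu map from the reconstruction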
    
def main(dataset: MuMapDataset, ids: Optional[List[int]] = None, paused: bool = False):
        """
        Display reconstructions and mu maps in a dataset.
    
        Parameters
        ----------
        dataset: MuMapDataset
            the dataset of which elements are displayed
        ids: list of int, optional
            only display these ids
        paused: bool
            start display in paused mode 
        """
    
        from mu_map.util import to_grayscale, COLOR_WHITE
    
        wname = "Dataset"
    
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)
    
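    # cv.waitKey timeouts: 0 blocks until a key press (paused mode),
    # while 1000 // 15 targets roughly 15 frames per second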
        timeout_paused = 0
        timeout_running = 1000 // 15
    
        timeout = timeout_paused if paused else timeout_running
    
        def to_display_image(image, _slice):
            _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
            _image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
            _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
    
            _image = cv.putText(
                _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
            )
    
            return _image
    
        def combine_images(images, slices):
            image_1 = to_display_image(images[0], slices[0])
            image_2 = to_display_image(images[1], slices[1])
    
    
            image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
            image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))
    
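        # overlay the colormapped reconstruction on the grayscale mu map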
            image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
            image_3_1 = image_2.copy()
            image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)
    
            space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
            return np.hstack((image_1, space, image_3, space, image_2))
    
    
    running = 0  # counts screenshots saved with the "s" key
    for i in range(len(dataset)):
        ir = 0  # slice index into the reconstruction
        im = 0  # slice index into the mu map
    
    
            row = dataset.table.iloc[i]
            _id = row[headers.id]
    
    
            if ids is not None and _id not in ids:
                continue
    
    
            recon, mu_map = dataset[i]
    
            recon = recon.squeeze().numpy()
            mu_map = mu_map.squeeze().numpy()
    
            print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - ID: {_id}", end="\r")
    
    
            cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
            key = cv.waitKey(timeout)
    
        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)

            key = cv.waitKey(timeout)

            if key == ord("n"):  # show the next image pair
                break
            elif key == ord("q"):  # quit the viewer
                cv.destroyAllWindows()
                return
            elif key == ord("p"):  # toggle between paused and running
                timeout = timeout_paused if timeout > 0 else timeout_running
            elif key == 82:  # up arrow key: advance the mu map relative to the recon
                ir = ir - 1
                continue
            elif key == 83:  # right arrow key: step one slice forward
                continue
            elif key == 81:  # left arrow key: step one slice back
                ir = max(ir - 2, 0)
                im = max(im - 2, 0)
            elif key == 84:  # down arrow key: move the mu map back relative to the recon
                ir = ir - 1
                im = max(im - 2, 0)
            elif key == ord("s"):  # save a screenshot of the current view
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1
    
if __name__ == "__main__":
    import argparse

    from mu_map.dataset.transform import PadCropTranform
    from mu_map.logging import add_logging_args, get_logger_by_args
        parser = argparse.ArgumentParser(
            description="Visualize the images of a MuMapDataset",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "dataset_dir", type=str, help="the directory from which the dataset is loaded"
        )
    
        parser.add_argument(
            "--split",
            type=str,
            choices=["train", "validation", "test"],
            help="choose the split of the data for the dataset",
        )
    
        parser.add_argument(
            "--unaligned",
            action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
        )
        parser.add_argument(
            "--show_bed",
            action="store_true",
            help="do not remove the bed contour from the mu map",
        )
        parser.add_argument(
            "--full_mu_map",
            action="store_true",
            help="do not remove broken slices of the mu map",
        )
    
        parser.add_argument(
            "--ids",
            type=int,
            nargs="*",
            help="only display certain ids",
        )
        parser.add_argument(
            "--paused",
            action="store_true",
            help="start in paused mode",
        )
        parser.add_argument(
            "--pad_crop",
            type=int,
            help="pad crop images to this size",
        )
    
        add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
        args = parser.parse_args()
    
        align = not args.unaligned
        discard_mu_map_slices = not args.full_mu_map
        bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
        logger = get_logger_by_args(args)
    
    
        transform_normalization = (
            PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
        )
    
    
    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        transform_normalization=transform_normalization,
        logger=logger,
    )
    main(dataset, args.ids, paused=args.paused)