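"""PyTorch dataset of SPECT reconstructions paired with attenuation maps (mu maps).

Running this module as a script opens an OpenCV viewer to browse a dataset,
e.g. (hypothetical path, assuming this file lives at mu_map/dataset/default.py):
python -m mu_map.dataset.default /data/mu_map --split validation
"""
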
import os
from typing import Optional, Tuple

import cv2 as cv
import numpy as np
import pandas as pd
import pydicom
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers
from mu_map.data.remove_bed import (
    DEFAULT_BED_CONTOURS_FILENAME,
    load_contours,
    remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images, load_dcm_img
from mu_map.logging import get_logger

    
class MuMapDataset(Dataset):
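    """Dataset of SPECT reconstructions paired with attenuation maps (mu maps).

    Image pairs are resolved from a metadata CSV file, optionally restricted
    to a split (train/validation/test), preprocessed (discarding broken mu map
    slices, bed removal, center alignment, normalization), and cached in
    memory after the first access.
    """
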
    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Transform = Transform(),
        transform_augmentation: Transform = Transform(),
        logger=None,
    ):
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        self.transform_normalization = transform_normalization
        self.transform_augmentation = transform_augmentation
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )
    
        # read the CSV file and from it resolve which DICOM files to load
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table[headers.id] = self.table[headers.id].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        self.reconstructions = {}
        self.mu_maps = {}
    
    
    def load_image(self, _id: int):
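        """Load the reconstruction and mu map for a given id from disk.

        Both DICOM images are preprocessed (optional slice discarding, bed
        removal and center alignment), converted to single-channel float32
        tensors, normalized, and stored in the in-memory caches.
        """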
        row = self.table[self.table[headers.id] == _id].iloc[0]
        _id = row[headers.id]

        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
        mu_map = load_dcm_img(mu_map_file)
        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)
        if self.bed_contours:
            if _id in self.bed_contours:
                bed_contour = self.bed_contours[_id]
                mu_map = remove_bed(mu_map, bed_contour)
            else:
                self.logger.warning(f"Could not find bed contour for id {_id}")

        recon_file = os.path.join(self.dir_images, row[self.header_recon])
        recon = load_dcm_img(recon_file)
        if self.align:
            recon, mu_map = align_images(recon, mu_map)

        mu_map = mu_map.astype(np.float32)
        mu_map = torch.from_numpy(mu_map)
        mu_map = mu_map.unsqueeze(dim=0)

        recon = recon.astype(np.float32)
        recon = torch.from_numpy(recon)
        recon = recon.unsqueeze(dim=0)

        recon, mu_map = self.transform_normalization(recon, mu_map)

        self.mu_maps[_id] = mu_map
        self.reconstructions[_id] = recon
    
    def __getitem__(self, index: int):
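        """Return the (reconstruction, mu map) pair at a table index."""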
        row = self.table.iloc[index]
        _id = row[headers.id]
        return self.getitem_by_id(_id)
    
    def getitem_by_id(self, _id: int):
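        """Return the augmented (reconstruction, mu map) pair for an id, loading it on first access."""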
    
        if _id not in self.reconstructions:
            self.load_image(_id)

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon, mu_map = self.transform_augmentation(recon, mu_map)
        return recon, mu_map
    
    def __len__(self):
        return len(self.table)


__all__ = [MuMapDataset.__name__]

    
def main(dataset, ids, paused=False):
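    """Display the image pairs of a dataset in an OpenCV window.

    Reconstruction and mu map are shown side by side together with a colored
    overlay, cycling through slices at roughly 15 FPS unless paused.

    Keys: n = next image pair, q = quit, p = toggle pause,
    arrow keys = step through slices, s = save the current view as a PNG.
    """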
    
    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)
    space = np.full((1024, 10), 239, np.uint8)

    TIMEOUT_PAUSED = 0
    TIMEOUT_RUNNING = 1000 // 15
    timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING
    
    def to_display_image(image, _slice):
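        """Convert a slice of a 3D image into a grayscale display image with a slice counter."""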
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image
    
    def combine_images(images, slices):
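        """Stack both grayscale slices and their colored overlay horizontally into one display image."""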
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])

        image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
        image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))

        image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
        image_3_1 = image_2.copy()
        image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)

        space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
        return np.hstack((image_1, space, image_3, space, image_2))
    
    
    running = 0  # counter used to enumerate saved screenshots
    for i in range(len(dataset)):
        ir = 0
        im = 0

        row = dataset.table.iloc[i]
        _id = row[headers.id]

        if ids is not None and _id not in ids:
            continue

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - ID: {_id}", end="\r")

        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        key = cv.waitKey(timeout)

        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)

            key = cv.waitKey(timeout)
            if key == ord("n"):  # show the next image pair
                break
            elif key == ord("q"):  # quit the viewer
                exit(0)
            elif key == ord("p"):  # toggle pause
                timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
            elif key == 82:  # up arrow key
                ir = ir - 1
                continue
            elif key == 83:  # right arrow key
                continue
            elif key == 81:  # left arrow key
                ir = max(ir - 2, 0)
                im = max(im - 2, 0)
            elif key == 84:  # down arrow key
                ir = ir - 1
                im = max(im - 2, 0)
            elif key == ord("s"):  # save the current view as a PNG
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1
    

if __name__ == "__main__":
    import argparse

    from mu_map.dataset.transform import PadCropTranform
    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    parser.add_argument(
        "--ids",
        type=int,
        nargs="*",
        help="only display certain ids",
    )
    parser.add_argument(
        "--paused",
        action="store_true",
        help="start in paused mode",
    )
    parser.add_argument(
        "--pad_crop",
        type=int,
        help="pad/crop images to this size",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)

    transform_normalization = (
        PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
    )

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        transform_normalization=transform_normalization,
        logger=logger,
    )
    main(dataset, args.ids, paused=args.paused)