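    """
    Prepare a dataset from DICOM directories.

    The script scans one or more DICOMDIR exports for myocardial scintigraphy studies and,
    per patient and protocol (stress/rest), extracts the SPECT projection, the reconstructions
    (with/without attenuation and scatter correction) and the attenuation map. The DICOM files
    are copied into a dataset directory and their meta-information is collected in a CSV file.

    Example invocation (paths are placeholders):
        python prepare.py /path/to/dicom/export --dataset_dir dataset
    """
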
    import argparse
    from datetime import datetime, timedelta
    from enum import Enum
    import os
    
    from typing import List, Dict
    
    
    import numpy as np
    import pandas as pd
    import pydicom
    
    
    from mu_map.logging import add_logging_args, get_logger_by_args
    
    
    STUDY_DESCRIPTION = "µ-map_study"
    
    
    
    class MyocardialProtocol(Enum):
        Stress = 1
        Rest = 2
    
    
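    # column names of the meta-information CSV file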
    headers = argparse.Namespace()
    headers.id = "id"
    headers.patient_id = "patient_id"
    headers.age = "age"
    headers.weight = "weight"
    headers.size = "size"
    headers.protocol = "protocol"
    headers.datetime_acquisition = "datetime_acquisition"
    headers.datetime_reconstruction = "datetime_reconstruction"
    headers.pixel_spacing_x = "pixel_spacing_x"
    headers.pixel_spacing_y = "pixel_spacing_y"
    headers.pixel_spacing_z = "pixel_spacing_z"
    headers.shape_x = "shape_x"
    headers.shape_y = "shape_y"
    headers.shape_z = "shape_z"
    headers.radiopharmaceutical = "radiopharmaceutical"
    headers.radionuclide_dose = "radionuclide_dose"
    headers.radionuclide_code = "radionuclide_code"
    headers.radionuclide_meaning = "radionuclide_meaning"
    headers.energy_window_peak_lower = "energy_window_peak_lower"
    headers.energy_window_peak_upper = "energy_window_peak_upper"
    headers.energy_window_scatter_lower = "energy_window_scatter_lower"
    headers.energy_window_scatter_upper = "energy_window_scatter_upper"
    headers.detector_count = "detector_count"
    headers.collimator_type = "collimator_type"
    headers.rotation_start = "rotation_start"
    headers.rotation_step = "rotation_step"
    headers.rotation_scan_arc = "rotation_scan_arc"
    
    headers.file_projection = "file_projection"
    
    headers.file_recon_ac_sc = "file_recon_ac_sc"
    headers.file_recon_nac_sc = "file_recon_nac_sc"
    headers.file_recon_ac_nsc = "file_recon_ac_nsc"
    headers.file_recon_nac_nsc = "file_recon_nac_nsc"
    
    headers.file_mu_map = "file_mu_map"
    
    
    def parse_series_time(dicom_image: pydicom.dataset.FileDataset) -> datetime:
        """
        Parse the date and time of a DICOM series object into a datetime object.
    
        :param dicom_image: the dicom file to parse the series date and time from
        :return: the corresponding Python datetime object.
        """
        _date = dicom_image.SeriesDate
        _time = dicom_image.SeriesTime
        return datetime(
            year=int(_date[0:4]),
            month=int(_date[4:6]),
            day=int(_date[6:8]),
            hour=int(_time[0:2]),
            minute=int(_time[2:4]),
            second=int(_time[4:6]),
            microsecond=int(_time.split(".")[1].ljust(6, "0")) if "." in _time else 0,
        )
    
    
    def parse_age(patient_age: str) -> int:
        """
        Parse an age string as defined in the DICOM standard into an integer representing the age in years.
    
        :param patient_age: age string as defined in the DICOM standard
        :return: the age in years as a number
        """
        assert (
            type(patient_age) == str
        ), f"patient age needs to be a string and not {type(patient_age)}"
        assert (
            len(patient_age) == 4
        ), f"patient age [{patient_age}] has to be four characters long"
        _num, _format = patient_age[:3], patient_age[3]
        assert (
            _format == "Y"
        ), f"currently, only patient ages in years [Y] is supported, not [{_format}]"
        return int(_num)
    
    
    
    def get_projections(
        dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
    ) -> List[pydicom.dataset.FileDataset]:
        """
        Extract the SPECT projections from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.

        :param dicom_images: list of DICOM images of a study
        :param protocol: the protocol for which the projection images should be extracted
        :return: the extracted DICOM images
        """
    
        _filter = filter(lambda image: "TOMO" in image.ImageType, dicom_images)
        _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
        dicom_images = list(_filter)
    
        if len(dicom_images) == 0:
            raise ValueError(f"No projection for protocol {protocol.name} available")

        return dicom_images


    def get_reconstructions(
        dicom_images: List[pydicom.dataset.FileDataset],
        protocol: MyocardialProtocol,
        attenuation_corrected: bool,
        scatter_corrected: bool,
    ) -> List[pydicom.dataset.FileDataset]:
        """
        Extract SPECT reconstructions from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.
        The attenuation_corrected and scatter_corrected flags select which of the four reconstruction variants is extracted.
        All matching reconstructions are returned.
    
        :param dicom_images: list of DICOM images of a study
        :param protocol: the protocol for which the reconstructions should be extracted
        :param attenuation_corrected: if True, extract an attenuation-corrected image, otherwise a non-attenuation-corrected one
        :param scatter_corrected: if True, extract an image to which scatter correction was applied
        :return: the extracted DICOM images
        """
    
        _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
    
        _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
    
        _filter = filter(
            lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter
        )

        if scatter_corrected and attenuation_corrected:
            filter_str = " SC - AC "
        elif not scatter_corrected and attenuation_corrected:
            filter_str = " NoSC - AC "
        elif scatter_corrected and not attenuation_corrected:
            filter_str = " SC  - NoAC "
        elif not scatter_corrected and not attenuation_corrected:
            filter_str = " NoSC - NoAC "
        _filter = filter(lambda image: filter_str in image.SeriesDescription, _filter)
    
    
        # for SPECT reconstructions created in clinical studies this value exists and is set to 'APEX_TO_BASE'
        # for the reconstructions with attenuation maps it does not exist
    
        # _filter = filter(
        # lambda image: not hasattr(image, "SliceProgressionDirection"), _filter
        # )
    
        dicom_images = list(_filter)
    
        if len(dicom_images) == 0:
    
            raise ValueError(
                f"'{filter_str}' Reconstruction for protocol {protocol.name} is not available"
            )

        return dicom_images


    def get_attenuation_maps(
        dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
    ) -> List[pydicom.dataset.FileDataset]:
        """
        Extract an attenuation map from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.
        If there are multiple attenuation maps, all of them are returned.

        :param dicom_images: list of DICOM images of a study
        :param protocol: the protocol for which the attenuation maps should be extracted
        :return: the extracted DICOM images
        """
    
        _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
    
        _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
    
        _filter = filter(
            lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter
        )
        _filter = filter(lambda image: " µ-map]" in image.SeriesDescription, _filter)

        dicom_images = list(_filter)
        if len(dicom_images) == 0:
            raise ValueError(
                f"Attenuation map for protocol {protocol.name} is not available"
            )

        return dicom_images
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser(
            description="Prepare a dataset from DICOM directories",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "dicom_dirs",
            type=str,
            nargs="+",
            help="paths to DICOMDIR files or directories containing one of them",
        )
    
        parser.add_argument(
            "--dataset_dir",
            type=str,
            required=True,
            help="directory where images, meta-information and the logs are stored",
        )
        parser.add_argument(
            "--images_dir",
            type=str,
            default="images",
            help="sub-directory of --dataset_dir where images are stored",
        )
        parser.add_argument(
            "--meta_csv",
            type=str,
            default="meta.csv",
            help="CSV file under --dataset_dir where meta-information is stored",
        )
        add_logging_args(
            parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
        )
    
        args = parser.parse_args()
    
        args.dicom_dirs = [
            (os.path.dirname(_file) if os.path.isfile(_file) else _file)
            for _file in args.dicom_dirs
        ]
        args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    
        args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
        args.logfile = os.path.join(args.dataset_dir, args.logfile)
    
        if not os.path.exists(args.dataset_dir):
            os.mkdir(args.dataset_dir)
    
        if not os.path.exists(args.images_dir):
            os.mkdir(args.images_dir)
    
        global logger
        logger = get_logger_by_args(args)
    
        try:
            patients = []
            dicom_dir_by_patient: Dict[str, str] = {}
            for dicom_dir in args.dicom_dirs:
                dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
                for patient in dataset.patient_records:
                    assert (
                        patient.PatientID not in dicom_dir_by_patient
                    ), f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
                    dicom_dir_by_patient[patient.PatientID] = dicom_dir
                    patients.append(patient)
    
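            # continue numbering after the largest id already present in an existing meta CSV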
            _id = 1
            if os.path.exists(args.meta_csv):
                data = pd.read_csv(args.meta_csv)
                _id = int(data[headers.id].max()) + 1
            else:
                data = pd.DataFrame(dict([(key, []) for key in vars(headers).keys()]))
    
            for i, patient in enumerate(patients, start=1):
                logger.debug(f"Process patient {str(i):>3}/{len(patients)}:")
    
                # get all myocardial scintigraphy studies
                studies = list(
                    filter(
                        lambda child: child.DirectoryRecordType == "STUDY",
                        # and child.StudyDescription == "Myokardszintigraphie",
                        # the description filter is disabled because there is a study without
                        # this description and only such studies are exported anyway
                        patient.children,
                    )
                )

                # extract all dicom images
                dicom_images = []
                for study in studies:
                    series = list(
                        filter(
                            lambda child: child.DirectoryRecordType == "SERIES",
                            study.children,
                        )
                    )
                    for _series in series:
                        images = list(
                            filter(
                                lambda child: child.DirectoryRecordType == "IMAGE",
                                _series.children,
                            )
                        )
    
                        # all SPECT data is stored as a single 3D array which means that it is a series with a single image
                        # this is not the case for CTs, which are skipped here
                        if len(images) != 1:
                            continue
    
                        images = list(
                            map(
                                lambda image: pydicom.dcmread(
                                    os.path.join(
                                        dicom_dir_by_patient[patient.PatientID],
                                        *image.ReferencedFileID,
                                    ),
                                    stop_before_pixels=True,
                                ),
                                images,
                            )
                        )

                        if len(images) == 0:
                            continue
    
                        dicom_images.append(images[0])
    
                    for protocol in MyocardialProtocol:
                        if (
                            len(
                                data[
                                    (data[headers.patient_id] == int(patient.PatientID))
                                    & (data[headers.protocol] == protocol.name)
                                ]
                            )
                            > 0
                        ):
                            logger.info(
                                f"Skip {patient.PatientID}:{protocol.name} since it is already contained in the dataset"
                            )
                            continue
    
    
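                        # each entry describes one image type to extract: the extraction function
                        # and its keyword arguments, whether the result is a reconstruction (used
                        # for the consistency checks below), the filename prefix for saving and
                        # the CSV column that stores the resulting filename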
                        extractions = [
                            {
                                "function": get_projections,
                                "kwargs": {},
                                "reconstruction": False,
                                "prefix": "projection",
                                "header": headers.file_projection,
                            },
                            {
                                "function": get_reconstructions,
                                "kwargs": {
                                    "attenuation_corrected": True,
                                    "scatter_corrected": True,
                                },
                                "reconstruction": True,
                                "prefix": "recon_ac_sc",
                                "header": headers.file_recon_ac_sc,
                            },
                            {
                                "function": get_reconstructions,
                                "kwargs": {
                                    "attenuation_corrected": True,
                                    "scatter_corrected": False,
                                },
                                "reconstruction": True,
                                "prefix": "recon_ac_nsc",
                                "header": headers.file_recon_ac_nsc,
                            },
                            {
                                "function": get_reconstructions,
                                "kwargs": {
                                    "attenuation_corrected": False,
                                    "scatter_corrected": True,
                                },
                                "reconstruction": True,
                                "prefix": "recon_nac_sc",
                                "header": headers.file_recon_nac_sc,
                            },
                            {
                                "function": get_reconstructions,
                                "kwargs": {
                                    "attenuation_corrected": False,
                                    "scatter_corrected": False,
                                },
                                "reconstruction": True,
                                "prefix": "recon_nac_nsc",
                                "header": headers.file_recon_nac_nsc,
                            },
                            {
                                "function": get_attenuation_maps,
                                "kwargs": {},
                                "reconstruction": True,
                                "prefix": "mu_map",
                                "header": headers.file_mu_map,
                            },
                        ]
    
    
                        try:
                            for extrac in extractions:
                                _images = extrac["function"](
                                    dicom_images, protocol=protocol, **extrac["kwargs"]
                                )
                                _images.sort(key=parse_series_time)
                                extrac["images"] = _images
    
                        except ValueError as e:
    
                            logger.info(
                                f"Skip {patient.PatientID}:{protocol.name} because {e}"
                            )
                            continue

                        num_images = min(
                            list(map(lambda extrac: len(extrac["images"]), extractions))
                        )

                        # ATTENTION: this is a special filter for mu maps which could have been saved from previous test runs of the workflow
                        # this filter only keeps the most recent ones
                        if num_images < len(extractions[-1]["images"]):
                            _len = len(extractions[-1]["images"])
                            extractions[-1]["images"] = extractions[-1]["images"][
                                (_len - num_images) :
                            ]

    
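                        # the j-th projection, reconstructions and attenuation map are treated as
                        # belonging to the same acquisition and are stored under a common id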
                        for j in range(num_images):
                            _recon_images = filter(
                                lambda extrac: extrac["reconstruction"], extractions
                            )
                            recon_images = list(
                                map(lambda extrac: extrac["images"][j], _recon_images)
                            )
    
                            # extract date times and assert that they are equal for all reconstruction images
                            datetimes = list(map(parse_series_time, recon_images))
                            _datetimes = sorted(datetimes, reverse=True)
                            _datetimes_delta = list(
                                map(lambda dt: _datetimes[0] - dt, _datetimes)
                            )
                            # note: the images sometimes receive slightly different timestamps,
                            # possibly depending on saving and computation time;
                            # thus, a tolerance of 10 minutes is allowed here
                            _equal = all(
                                map(lambda dt: dt < timedelta(minutes=10), _datetimes_delta)
                            )
                            assert (
                                _equal
                            ), f"Not all dates and times of the reconstructions are equal: {datetimes}"
    
                            # extract pixel spacings and assert that they are equal for all reconstruction images
                            _map_lists = map(
                                lambda image: [*image.PixelSpacing, image.SliceThickness],
                                recon_images,
                            )
                            _map_lists = map(
                                lambda pixel_spacing: list(map(float, pixel_spacing)),
                                _map_lists,
                            )
                            _map_ndarrays = map(
                                lambda pixel_spacing: np.array(pixel_spacing), _map_lists
                            )
                            pixel_spacings = list(_map_ndarrays)
                            _equal = all(
                                map(
                                    lambda pixel_spacing: (
                                        pixel_spacing == pixel_spacings[0]
                                    ).all(),
                                    pixel_spacings,
                                )
                            )
                            assert (
                                _equal
                            ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
                            pixel_spacing = pixel_spacings[0]
    
                            # use the shape with the fewest slices, all other images will be aligned to that
                            _map_lists = map(
                                lambda image: [
                                    image.Rows,
                                    image.Columns,
                                    image.NumberOfSlices,
                                ],
                                recon_images,
                            )
                            _map_lists = map(
                                lambda shape: list(map(int, shape)), _map_lists
                            )
                            _map_ndarrays = map(lambda shape: np.array(shape), _map_lists)
                            shapes = list(_map_ndarrays)
                            shapes.sort(key=lambda shape: shape[2])
                            shape = shapes[0]
    
                            projection_image = extractions[0]["images"][j]
                            # extract and sort energy windows
                            energy_windows = (
                                projection_image.EnergyWindowInformationSequence
                            )
                            energy_windows = map(
                                lambda ew: ew.EnergyWindowRangeSequence[0], energy_windows
                            )
                            energy_windows = map(
                                lambda ew: (
                                    float(ew.EnergyWindowLowerLimit),
                                    float(ew.EnergyWindowUpperLimit),
                                ),
                                energy_windows,
                            )
                            energy_windows = list(energy_windows)
                            energy_windows.sort(key=lambda ew: ew[0], reverse=True)
    
                            row = {
                                headers.id: _id,
                                headers.patient_id: projection_image.PatientID,
                                headers.age: parse_age(projection_image.PatientAge),
                                headers.weight: float(projection_image.PatientWeight),
                                headers.size: float(projection_image.PatientSize),
                                headers.protocol: protocol.name,
                                headers.datetime_acquisition: parse_series_time(
                                    projection_image
                                ),
                                headers.datetime_reconstruction: datetimes[0],
                                headers.pixel_spacing_x: pixel_spacing[0],
                                headers.pixel_spacing_y: pixel_spacing[1],
                                headers.pixel_spacing_z: pixel_spacing[2],
                                headers.shape_x: shape[0],
                                headers.shape_y: shape[1],
                                headers.shape_z: shape[2],
                                headers.radiopharmaceutical: projection_image.RadiopharmaceuticalInformationSequence[
                                    0
                                ].Radiopharmaceutical,
                                headers.radionuclide_dose: projection_image.RadiopharmaceuticalInformationSequence[
                                    0
                                ].RadionuclideTotalDose,
                                headers.radionuclide_code: projection_image.RadiopharmaceuticalInformationSequence[
                                    0
                                ]
                                .RadionuclideCodeSequence[0]
                                .CodeValue,
                                headers.radionuclide_meaning: projection_image.RadiopharmaceuticalInformationSequence[
                                    0
                                ]
                                .RadionuclideCodeSequence[0]
                                .CodeMeaning,
                                headers.energy_window_peak_lower: energy_windows[0][0],
                                headers.energy_window_peak_upper: energy_windows[0][1],
                                headers.energy_window_scatter_lower: energy_windows[1][0],
                                headers.energy_window_scatter_upper: energy_windows[1][1],
                                headers.detector_count: len(
                                    projection_image.DetectorInformationSequence
                                ),
                                headers.collimator_type: projection_image.DetectorInformationSequence[
                                    0
                                ].CollimatorType,
                                headers.rotation_start: float(
                                    projection_image.RotationInformationSequence[
                                        0
                                    ].StartAngle
                                ),
                                headers.rotation_step: float(
                                    projection_image.RotationInformationSequence[
                                        0
                                    ].AngularStep
                                ),
                                headers.rotation_scan_arc: float(
                                    projection_image.RotationInformationSequence[0].ScanArc
                                ),
                            }
    
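                            # re-read the full DICOM files (including pixel data) and copy them
                            # into the dataset's image directory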
                            _filename_base = f"{_id:04d}-{protocol.name.lower()}"
                            _ext = "dcm"
                            _images = list(
                                map(lambda extrac: extrac["images"][j], extractions)
                            )
                            for _image, extrac in zip(_images, extractions):
                                image = pydicom.dcmread(_image.filename)
                                filename = f"{_filename_base}-{extrac['prefix']}.{_ext}"
                                pydicom.dcmwrite(
                                    os.path.join(args.images_dir, filename), image
                                )
                                row[extrac["header"]] = filename
    
                            _id += 1
                            row = pd.DataFrame(row, index=[0])
                            data = pd.concat((data, row), ignore_index=True)
    
    
            data.to_csv(args.meta_csv, index=False)
        except Exception as e:
            logger.error(e)
            raise e