Skip to content
Snippets Groups Projects
prepare.py 21.7 KiB
Newer Older
  • Learn to ignore specific revisions
    import argparse
    from datetime import datetime, timedelta
    from enum import Enum
    import os
    from typing import Callable, Dict, List

    import numpy as np
    import pandas as pd
    import pydicom

    from mu_map.data.dicom_util import DICOMTime, parse_age
    from mu_map.logging import add_logging_args, get_logger_by_args
    
    
    STUDY_DESCRIPTION = "µ-map_study"


    # Column names of the meta-information CSV. Each attribute holds a string equal
    # to its own name. The keyword order below is the attribute order, and therefore
    # the column order of a table created from ``vars(headers)``.
    headers = argparse.Namespace(
        id="id",
        patient_id="patient_id",
        age="age",
        weight="weight",
        size="size",
        protocol="protocol",
        datetime_acquisition="datetime_acquisition",
        datetime_reconstruction="datetime_reconstruction",
        pixel_spacing_x="pixel_spacing_x",
        pixel_spacing_y="pixel_spacing_y",
        pixel_spacing_z="pixel_spacing_z",
        shape_x="shape_x",
        shape_y="shape_y",
        shape_z="shape_z",
        radiopharmaceutical="radiopharmaceutical",
        radionuclide_dose="radionuclide_dose",
        radionuclide_code="radionuclide_code",
        radionuclide_meaning="radionuclide_meaning",
        energy_window_peak_lower="energy_window_peak_lower",
        energy_window_peak_upper="energy_window_peak_upper",
        energy_window_scatter_lower="energy_window_scatter_lower",
        energy_window_scatter_upper="energy_window_scatter_upper",
        detector_count="detector_count",
        collimator_type="collimator_type",
        rotation_start="rotation_start",
        rotation_step="rotation_step",
        rotation_scan_arc="rotation_scan_arc",
        # file name of the stored projection
        file_projection="file_projection",
        # file names of the stored reconstructions (attenuation / scatter corrected or not)
        file_recon_ac_sc="file_recon_ac_sc",
        file_recon_nac_sc="file_recon_nac_sc",
        file_recon_ac_nsc="file_recon_ac_nsc",
        file_recon_nac_nsc="file_recon_nac_nsc",
        # file name of the stored attenuation map
        file_mu_map="file_mu_map",
    )
    
    
    
    def get_protocol(projection: pydicom.dataset.FileDataset) -> str:
        """
        Get the protocol (stress, rest) of a projection image by checking if
        it is part of the series description.

        The comparison is case-insensitive, and "stress" is checked before "rest".

        :param projection: pydicom image of the projection
        :returns: the protocol as a string (Stress or Rest)
        :raises ValueError: if the series description contains neither protocol
        """
        description = projection.SeriesDescription.lower()

        if "stress" in description:
            return "Stress"

        if "rest" in description:
            return "Rest"

        raise ValueError(
            f"Unknown protocol in projection {projection.SeriesDescription}"
        )
    
    def find_projections(
        dicom_images: List[pydicom.dataset.FileDataset],
    ) -> List[pydicom.dataset.FileDataset]:
        """
        Find all projections in a list of DICOM images belonging to a study.

        A projection is an image whose ``ImageType`` contains the value "TOMO"
        and whose series description is one of "Stress", "Rest" or "Stress_2".
        "TOMO" images with any other series description are skipped, and a
        warning is logged through the module-level ``logger``. That logger is
        assigned in the ``__main__`` section of this script.

        :param dicom_images: list of DICOM images of a study
        :return: the projections, in the same order as in ``dicom_images``
        """
        projections = []
        for dicom_image in dicom_images:
            if "TOMO" not in dicom_image.ImageType:
                continue

            # only keep projections whose series description is an allowed protocol
            if dicom_image.SeriesDescription not in ["Stress", "Rest", "Stress_2"]:
                logger.warning(
                    f"Skip projection with unknown protocol [{dicom_image.SeriesDescription}]"
                )
                continue

            projections.append(dicom_image)

        return projections
    
    def is_recon_type(
        scatter_corrected: bool, attenuation_corrected: bool
    ) -> Callable[[pydicom.dataset.FileDataset], bool]:
        """
        Get a filter function for reconstructions that are (non-)scatter and/or (non-)attenuation corrected.

        The returned function checks whether a fixed marker string (e.g. " SC - AC ")
        is contained in the series description of a DICOM image.

        :param scatter_corrected: if the filter should only return true for scatter corrected reconstructions
        :param attenuation_corrected: if the filter should only return true for attenuation corrected reconstructions
        :returns: a filter function
        """
        # NOTE(review): the double space in " SC  - NoAC " is kept exactly as before;
        # presumably it mirrors the exported series descriptions - confirm before changing.
        filter_str = {
            (True, True): " SC - AC ",
            (False, True): " NoSC - AC ",
            (True, False): " SC  - NoAC ",
            (False, False): " NoSC - NoAC ",
        }[(bool(scatter_corrected), bool(attenuation_corrected))]

        return lambda dicom: filter_str in dicom.SeriesDescription
    
    
    
    def find_reconstruction(
        dicom_images: List[pydicom.dataset.FileDataset],
        projection: pydicom.dataset.FileDataset,
        scatter_corrected: bool,
        attenuation_corrected: bool,
    ) -> pydicom.dataset.FileDataset:
        """
        Find a reconstruction in a list of dicom images of a study belonging to a projection.

        A candidate must meet all of the following:
        - its ``ImageType`` contains "RECON TOMO";
        - its series description contains the projection's protocol and
          ``STUDY_DESCRIPTION``, but not "CT" (which would be a µ-map);
        - it matches the requested scatter/attenuation correction;
        - it has the same acquisition time as the projection.

        :param dicom_images: the list of dicom images belonging to the study
        :param projection: the dicom image of the projection
        :param scatter_corrected: if it should be searched for a scatter corrected reconstruction
        :param attenuation_corrected: if it should be searched for an attenuation corrected reconstruction
        :returns: the according reconstruction; if several match, the one with the
            latest series time (a warning is logged in that case)
        :raises ValueError: if no matching reconstruction exists
        """
        protocol = get_protocol(projection)
        matches_recon_type = is_recon_type(scatter_corrected, attenuation_corrected)

        candidates = [
            image
            for image in dicom_images
            if "RECON TOMO" in image.ImageType
            and protocol in image.SeriesDescription
            and STUDY_DESCRIPTION in image.SeriesDescription
            and "CT" not in image.SeriesDescription  # remove µ-maps
            and matches_recon_type(image)
            and DICOMTime.Acquisition.to_datetime(image)
            == DICOMTime.Acquisition.to_datetime(projection)
        ]

        if not candidates:
            raise ValueError(
                f"No reconstruction with SC={scatter_corrected}, AC={attenuation_corrected} available"
            )

        if len(candidates) > 1:
            logger.warning(
                f"Multiple reconstructions ({len(candidates)}) with SC={scatter_corrected}, AC={attenuation_corrected} for projection {projection.SeriesDescription} of patient {projection.PatientID}"
            )
            # descending series time: the newest reconstruction comes first
            candidates.sort(
                key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True
            )

        return candidates[0]
    
    def find_attenuation_map(
        dicom_images: List[pydicom.dataset.FileDataset],
        projection: pydicom.dataset.FileDataset,
        reconstructions: List[pydicom.dataset.FileDataset],
        max_time_diff_s: float = 30,
    ) -> pydicom.dataset.FileDataset:
        """
        Find the attenuation map (µ-map) in a list of dicom images of a study
        belonging to a projection and its reconstructions.

        A candidate must meet all of the following:
        - its ``ImageType`` contains "RECON TOMO";
        - its series description contains the projection's protocol,
          ``STUDY_DESCRIPTION`` and " µ-map]";
        - its series time lies in ``[0, max_time_diff_s)`` seconds after the
          series time of at least one of the reconstructions.

        :param dicom_images: the list of dicom images belonging to the study
        :param projection: the dicom image of the projection
        :param reconstructions: dicom images of reconstructions belonging to the projection
        :param max_time_diff_s: maximal time in seconds between the series time of a
            reconstruction and the series time of its µ-map
        :returns: the according attenuation map; if several match, the one with the
            latest series time (a warning is logged in that case)
        :raises ValueError: if no matching attenuation map exists
        """
        # NOTE(review): the original referenced an undefined MAX_TIME_DIFF_S; the
        # default of 30 s is an assumption - confirm against the export workflow.
        protocol = get_protocol(projection)
        recon_times = [DICOMTime.Series.to_datetime(recon) for recon in reconstructions]

        def _created_after_recon(image) -> bool:
            # True if the image series follows a reconstruction series within the window.
            # total_seconds() is used because timedelta.seconds ignores whole days and
            # wraps negative differences.
            image_time = DICOMTime.Series.to_datetime(image)
            return any(
                0 <= (image_time - recon_time).total_seconds() < max_time_diff_s
                for recon_time in recon_times
            )

        candidates = [
            image
            for image in dicom_images
            if "RECON TOMO" in image.ImageType
            and protocol in image.SeriesDescription
            and STUDY_DESCRIPTION in image.SeriesDescription
            and " µ-map]" in image.SeriesDescription
            and _created_after_recon(image)
        ]

        if not candidates:
            raise ValueError("No Attenuation map available")

        if len(candidates) > 1:
            logger.warning(
                f"Multiple attenuation maps ({len(candidates)}) for projection {projection.SeriesDescription} of patient {projection.PatientID}"
            )
            # descending series time: the newest µ-map comes first
            # NOTE(review): an earlier comment said "oldest first" - confirm the intent.
            candidates.sort(
                key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True
            )

        return candidates[0]
    
    def get_relevant_images(
        patient: pydicom.dataset.FileDataset, dicom_dir: str
    ) -> List[pydicom.dataset.FileDataset]:
        """
        Get all relevant images of a patient.

        The DICOMDIR records are walked from patient to STUDY to SERIES to IMAGE.
        Only series that consist of exactly one image are kept, and those images
        are read without pixel data (``stop_before_pixels=True``).

        :param patient: pydicom dataset of a patient (a DICOMDIR patient record)
        :param dicom_dir: the directory of the DICOM files; referenced file IDs are
            resolved relative to it
        :return: all relevant dicom images (without pixel data)
        """
        # get all myocardial scintigraphy studies
        # (a filter on StudyDescription == "Myokardszintigraphie" is disabled because there
        # is a study without this description and only such studies are exported anyway)
        studies = [
            child for child in patient.children if child.DirectoryRecordType == "STUDY"
        ]

        # extract all dicom images
        dicom_images = []
        for study in studies:
            for series in study.children:
                if series.DirectoryRecordType != "SERIES":
                    continue

                images = [
                    child
                    for child in series.children
                    if child.DirectoryRecordType == "IMAGE"
                ]

                # all SPECT data is stored as a single 3D array which means that it is a series with a single image
                # this is not the case for CTs, which are skipped here
                if len(images) != 1:
                    continue

                path = os.path.join(dicom_dir, *images[0].ReferencedFileID)
                dicom_images.append(pydicom.dcmread(path, stop_before_pixels=True))

        return dicom_images
    
    
    
    if __name__ == "__main__":
        parser = argparse.ArgumentParser(
            description="Prepare a dataset from DICOM directories",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
        parser.add_argument(
            "dicom_dirs",
            type=str,
            nargs="+",
            help="paths to DICOMDIR files or directories containing one of them",
        )
        parser.add_argument(
            "--dataset_dir",
            type=str,
            required=True,
            help="directory where images, meta-information and the logs are stored",
        )
        parser.add_argument(
            "--images_dir",
            type=str,
            default="images",
            help="sub-directory of --dataset_dir where images are stored",
        )
        parser.add_argument(
            "--meta_csv",
            type=str,
            default="meta.csv",
            help="CSV file under --dataset_dir where meta-information is stored",
        )
        add_logging_args(
            parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
        )

        args = parser.parse_args()

        # accept both DICOMDIR files and the directories containing them
        args.dicom_dirs = [
            (os.path.dirname(_file) if os.path.isfile(_file) else _file)
            for _file in args.dicom_dirs
        ]
        args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
        args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
        args.logfile = os.path.join(args.dataset_dir, args.logfile)

        # creates --dataset_dir as well; it must exist before the logger opens its logfile
        os.makedirs(args.images_dir, exist_ok=True)

        # module-level logger, also used by the helper functions of this script
        logger = get_logger_by_args(args)

        try:
            patients = []
            dicom_dir_by_patient: Dict[str, str] = {}
            for dicom_dir in args.dicom_dirs:
                dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
                for patient in dataset.patient_records:
                    if patient.PatientID in dicom_dir_by_patient:
                        raise ValueError(
                            f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
                        )
                    dicom_dir_by_patient[patient.PatientID] = dicom_dir
                    patients.append(patient)

            if os.path.exists(args.meta_csv):
                data = pd.read_csv(args.meta_csv)
                # continue after the largest existing id so that existing rows and
                # image files are not overwritten
                _id = 1 if data.empty else int(data[headers.id].max()) + 1
                # NOTE(review): patients already in the CSV are processed again and
                # appended as new rows - confirm whether they should be skipped.
            else:
                data = pd.DataFrame({key: [] for key in vars(headers)})
                _id = 1

            # column headers and file postfixes of the reconstructions, in the same
            # order as the (attenuation corrected, scatter corrected) combinations
            recon_headers = [
                headers.file_recon_nac_nsc,
                headers.file_recon_nac_sc,
                headers.file_recon_ac_nsc,
                headers.file_recon_ac_sc,
            ]
            recon_postfixes = [
                "recon_nac_nsc",
                "recon_nac_sc",
                "recon_ac_nsc",
                "recon_ac_sc",
            ]
            recon_types = [(False, False), (False, True), (True, False), (True, True)]

            for i, patient in enumerate(patients, start=1):
                logger.debug(
                    f"Process patient {str(i):>3}/{len(patients)} - {patient.PatientName.given_name}, {patient.PatientName.family_name}:"
                )

                dicom_images = get_relevant_images(
                    patient, dicom_dir_by_patient[patient.PatientID]
                )
                projections = find_projections(dicom_images)
                logger.debug(
                    f"- Found {len(projections)} projections with protocols {[get_protocol(p) for p in projections]}"
                )

                for projection in projections:
                    protocol = get_protocol(projection)

                    reconstructions = [
                        find_reconstruction(
                            dicom_images,
                            projection,
                            scatter_corrected=sc,
                            attenuation_corrected=ac,
                        )
                        for ac, sc in recon_types
                    ]
                    mu_map = find_attenuation_map(
                        dicom_images, projection, reconstructions
                    )

                    # extract pixel spacings and check that they are equal for all reconstruction images
                    pixel_spacings = [
                        np.array(
                            [*map(float, image.PixelSpacing), float(image.SliceThickness)]
                        )
                        for image in [*reconstructions, mu_map]
                    ]
                    if not all(
                        (_spacing == pixel_spacings[0]).all() for _spacing in pixel_spacings
                    ):
                        raise ValueError(
                            f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
                        )
                    pixel_spacing = pixel_spacings[0]

                    # use the shape with the fewest slices, all other images will be aligned to that
                    shapes = [
                        np.array(
                            [int(image.Rows), int(image.Columns), int(image.NumberOfSlices)]
                        )
                        for image in [*reconstructions, mu_map]
                    ]
                    shape = min(shapes, key=lambda _shape: _shape[2])

                    # (lower, upper) limits of the energy windows, sorted by descending lower
                    # limit: the first one is stored as peak window, the second as scatter window
                    energy_windows = [
                        (
                            float(ew.EnergyWindowRangeSequence[0].EnergyWindowLowerLimit),
                            float(ew.EnergyWindowRangeSequence[0].EnergyWindowUpperLimit),
                        )
                        for ew in projection.EnergyWindowInformationSequence
                    ]
                    energy_windows.sort(key=lambda ew: ew[0], reverse=True)

                    radiopharmaceutical_info = (
                        projection.RadiopharmaceuticalInformationSequence[0]
                    )
                    radionuclide_code = radiopharmaceutical_info.RadionuclideCodeSequence[0]
                    rotation_info = projection.RotationInformationSequence[0]

                    row = {
                        headers.id: _id,
                        headers.patient_id: projection.PatientID,
                        headers.age: parse_age(projection.PatientAge),
                        headers.weight: float(projection.PatientWeight),
                        headers.size: float(projection.PatientSize),
                        headers.protocol: protocol,
                        headers.datetime_acquisition: DICOMTime.Series.to_datetime(
                            projection
                        ),
                        headers.datetime_reconstruction: DICOMTime.Series.to_datetime(
                            reconstructions[0]
                        ),
                        headers.pixel_spacing_x: pixel_spacing[0],
                        headers.pixel_spacing_y: pixel_spacing[1],
                        headers.pixel_spacing_z: pixel_spacing[2],
                        headers.shape_x: shape[0],
                        headers.shape_y: shape[1],
                        headers.shape_z: shape[2],
                        headers.radiopharmaceutical: radiopharmaceutical_info.Radiopharmaceutical,
                        headers.radionuclide_dose: radiopharmaceutical_info.RadionuclideTotalDose,
                        headers.radionuclide_code: radionuclide_code.CodeValue,
                        headers.radionuclide_meaning: radionuclide_code.CodeMeaning,
                        headers.energy_window_peak_lower: energy_windows[0][0],
                        headers.energy_window_peak_upper: energy_windows[0][1],
                        headers.energy_window_scatter_lower: energy_windows[1][0],
                        headers.energy_window_scatter_upper: energy_windows[1][1],
                        headers.detector_count: len(projection.DetectorInformationSequence),
                        headers.collimator_type: projection.DetectorInformationSequence[
                            0
                        ].CollimatorType,
                        headers.rotation_start: float(rotation_info.StartAngle),
                        headers.rotation_step: float(rotation_info.AngularStep),
                        headers.rotation_scan_arc: float(rotation_info.ScanArc),
                    }

                    _filename_base = f"{_id:04d}-{protocol.lower()}"
                    _ext = "dcm"

                    img_headers = [
                        headers.file_projection,
                        *recon_headers,
                        headers.file_mu_map,
                    ]
                    img_postfixes = ["projection", *recon_postfixes, "mu_map"]
                    images = [projection, *reconstructions, mu_map]
                    for header, image, postfix in zip(img_headers, images, img_postfixes):
                        # re-read the complete file: get_relevant_images loads images
                        # without pixel data
                        _image = pydicom.dcmread(image.filename)
                        filename = f"{_filename_base}-{postfix}.{_ext}"
                        pydicom.dcmwrite(os.path.join(args.images_dir, filename), _image)
                        # record the stored file in its meta-information column
                        row[header] = filename

                    _id += 1
                    data = pd.concat(
                        (data, pd.DataFrame(row, index=[0])), ignore_index=True
                    )

            data.to_csv(args.meta_csv, index=False)
        except Exception as e:
            logger.error(e)
            raise