"""
Utility script to prepare raw data for further processing.
It is highly dependent on the way the raw DICOM data were exported.
The script creates a folder containing the file `meta.csv` and another
folder containing the DICOM images sorted by an id for each study.
"""
import argparse
from datetime import datetime, timedelta
from enum import Enum
import os
from typing import Callable, Dict, List

import numpy as np
import pandas as pd
import pydicom

from mu_map.file.dicom import DICOM, DICOMTime, parse_age
from mu_map.logging import add_logging_args, get_logger_by_args

# Marker that has to be part of the series description of the
# reconstructions and attenuation maps selected by this script.
STUDY_DESCRIPTION = "µ-map_study"


# Column names of the meta CSV file. Every attribute of the namespace has
# its own name as value and the attribute order defines the column order
# of a newly created table.
headers = argparse.Namespace(
    **{
        column: column
        for column in (
            "id",
            "patient_id",
            "age",
            "sex",
            "weight",
            "size",
            "patient_position",
            "protocol",
            "datetime_acquisition",
            "datetime_reconstruction",
            "pixel_spacing_x",
            "pixel_spacing_y",
            "pixel_spacing_z",
            "shape_x",
            "shape_y",
            "shape_z",
            "radiopharmaceutical",
            "radionuclide_dose",
            "radionuclide_code",
            "radionuclide_meaning",
            "energy_window_peak_lower",
            "energy_window_peak_upper",
            "energy_window_scatter_lower",
            "energy_window_scatter_upper",
            "detector_count",
            "collimator_type",
            "rotation_start",
            "rotation_step",
            "rotation_scan_arc",
            "file_projection",
            "file_recon_ac_sc",
            "file_recon_nac_sc",
            "file_recon_ac_nsc",
            "file_recon_nac_nsc",
            "file_mu_map",
        )
    }
)


def get_protocol(projection: DICOM) -> str:
    """
    Get the protocol (stress, rest) of a projection image by checking if
    it is part of the series description.

    The check is case-insensitive and "stress" is tested before "rest".

    Parameters
    ----------
    projection: DICOM
        DICOM image of the projection

    Returns
    -------
    str
        the protocol as a string (Stress or Rest)

    Raises
    ------
    ValueError
        if the series description contains neither "stress" nor "rest"
    """
    description = projection.SeriesDescription.lower()
    for protocol in ("Stress", "Rest"):
        if protocol.lower() in description:
            return protocol

    raise ValueError(f"Unkown protocol in projection {projection.SeriesDescription}")
def find_projections(
    dicom_images: List[DICOM],
) -> List[DICOM]:
    """
    Find all projections in a list of DICOM images belonging to a study.

    An image is considered a projection if its image type contains "TOMO"
    and its series description is one of the known protocol names.
    Images of type "TOMO" with any other series description are skipped
    with a warning.

    Parameters
    ----------
    dicom_images: list of DICOM
        DICOM images of a study

    Returns
    -------
    list of DICOM
        all projection images in the input list (may be empty)
    """
    allowed_descriptions = ("Stress", "Rest", "Stress_2")

    projections = []
    for dicom_image in dicom_images:
        if "TOMO" not in dicom_image.ImageType:
            continue

        # filter for allowed series descriptions
        if dicom_image.SeriesDescription not in allowed_descriptions:
            logger.warning(
                f"Skip projection with unknown protocol [{dicom_image.SeriesDescription}]"
            )
            continue

        projections.append(dicom_image)
    return projections

def is_recon_type(
    scatter_corrected: bool, attenuation_corrected: bool
) -> Callable[[DICOM], bool]:
    """
    Get a filter function for reconstructions that are (non-)scatter and/or (non-)attenuation corrected.

    The returned function checks whether a marker string describing the
    requested corrections is part of the series description of an image.

    Parameters
    ----------
    scatter_corrected: bool
        if the filter should only return true for scatter corrected reconstructions
    attenuation_corrected: bool
        if the filter should only return true for attenuation corrected reconstructions

    Returns
    -------
    Callable[DICOM, bool]
        a filter function that returns true if the DICOM image has the specified corrections
    """
    # keys are (scatter_corrected, attenuation_corrected)
    # NOTE(review): the double space in " SC  - NoAC " is kept unchanged;
    # presumably it mirrors the exported series descriptions - confirm
    filter_strings = {
        (True, True): " SC - AC ",
        (False, True): " NoSC - AC ",
        (True, False): " SC  - NoAC ",
        (False, False): " NoSC - NoAC ",
    }
    filter_str = filter_strings[(bool(scatter_corrected), bool(attenuation_corrected))]
    return lambda dicom: filter_str in dicom.SeriesDescription
def find_reconstruction(
    dicom_images: List[DICOM],
    projection: DICOM,
    scatter_corrected: bool,
    attenuation_corrected: bool,
) -> DICOM:
    """
    Find the reconstruction in a list of DICOM images of a study belonging to a projection.

    A reconstruction has to be of image type "RECON TOMO", share the
    protocol and the acquisition time of the projection, contain the study
    marker in its series description and have the requested corrections.
    If multiple reconstructions match, a warning is logged and the one
    with the latest series time is returned.

    Parameters
    ----------
    dicom_images: list of DICOM
        DICOM images belonging to the study
    projection: DICOM
        the DICOM image of the projection the reconstruction belongs to
    scatter_corrected: bool
        if it should be searched for a scatter corrected reconstruction
    attenuation_corrected: bool
        if it should be searched for an attenuation corrected reconstruction

    Returns
    -------
    DICOM
        the according reconstruction

    Raises
    ------
    ValueError
        if no matching reconstruction is available
    """
    protocol = get_protocol(projection)
    has_requested_corrections = is_recon_type(scatter_corrected, attenuation_corrected)

    # the conditions are evaluated in order and stop at the first failure
    candidates = [
        image
        for image in dicom_images
        if "RECON TOMO" in image.ImageType
        and protocol in image.SeriesDescription
        and STUDY_DESCRIPTION in image.SeriesDescription
        and "CT" not in image.SeriesDescription  # remove µ-maps
        and has_requested_corrections(image)
        and DICOMTime.Acquisition.to_datetime(image)
        == DICOMTime.Acquisition.to_datetime(projection)
    ]

    if len(candidates) == 0:
        raise ValueError(
            f"No reconstruction with SC={scatter_corrected}, AC={attenuation_corrected} available"
        )

    if len(candidates) > 1:
        logger.warning(
            f"Multiple reconstructions ({len(candidates)}) with SC={scatter_corrected}, AC={attenuation_corrected} for projection {projection.SeriesDescription} of patient {projection.PatientID}"
        )

    # prefer the most recent series; on ties the first one in the list wins
    return max(candidates, key=lambda image: DICOMTime.Series.to_datetime(image))


def find_attenuation_map(
    dicom_images: List[DICOM],
    projection: DICOM,
    reconstructions: List[DICOM],
    max_time_diff: int = 30,
) -> DICOM:
    """
    Find an attenuation map in a list of DICOM images of a study belonging to a projection and reconstructions.

    An attenuation map has to be of image type "RECON TOMO", share the
    protocol of the projection, contain the study marker as well as
    " µ-map]" in its series description and have a series time close to
    at least one of the reconstructions. If multiple attenuation maps
    match, a warning is logged and the one with the latest series time is
    returned.

    Parameters
    ----------
    dicom_images: list of DICOM
        the list of DICOM images belonging to the study
    projection: DICOM
        the DICOM image of the projection
    reconstructions: list of DICOM
        DICOM images of reconstructions belonging to the projection
    max_time_diff: int, optional
        filter out DICOM files which differ more than this value (in
        seconds) in series time to the reconstructions

    Returns
    -------
    DICOM
        the according attenuation map

    Raises
    ------
    ValueError
        if no matching attenuation map is available
    """
    protocol = get_protocol(projection)
    recon_times = [DICOMTime.Series.to_datetime(recon) for recon in reconstructions]

    def is_close_to_a_reconstruction(image: DICOM) -> bool:
        # use the absolute difference of the full duration: `timedelta.seconds`
        # would ignore whole days and wrap around for negative differences
        series_time = DICOMTime.Series.to_datetime(image)
        return any(
            abs((series_time - recon_time).total_seconds()) < max_time_diff
            for recon_time in recon_times
        )

    # the conditions are evaluated in order and stop at the first failure
    candidates = [
        image
        for image in dicom_images
        if "RECON TOMO" in image.ImageType
        and protocol in image.SeriesDescription
        and STUDY_DESCRIPTION in image.SeriesDescription
        and " µ-map]" in image.SeriesDescription
        and is_close_to_a_reconstruction(image)
    ]

    if len(candidates) == 0:
        raise ValueError("No Attenuation map available")

    if len(candidates) > 1:
        logger.warning(
            f"Multiple attenuation maps ({len(candidates)}) for projection {projection.SeriesDescription} of patient {projection.PatientID}"
        )

    # prefer the most recent series; on ties the first one in the list wins
    return max(candidates, key=lambda image: DICOMTime.Series.to_datetime(image))
def get_relevant_images(patient: DICOM, dicom_dir: str) -> List[DICOM]:
    """
    Get all relevant images of a patient.

    The directory records of the patient are traversed (study, series,
    image) and the file of every series consisting of exactly one image
    is read without its pixel data.

    Parameters
    ----------
    patient: DICOM
        DICOM dataset of a patient
    dicom_dir: str
        the directory of the DICOM files

    Returns
    -------
    list of DICOM
        all relevant DICOM images
    """

    def children_of_type(record, record_type: str) -> list:
        # select the child directory records of the given type
        return [
            child
            for child in record.children
            if child.DirectoryRecordType == record_type
        ]

    # extract all DICOM images
    dicom_images = []
    for study in children_of_type(patient, "STUDY"):
        for series in children_of_type(study, "SERIES"):
            images = children_of_type(series, "IMAGE")

            # all SPECT data is stored as a single 3D array which means that it is a series with a single image
            # this is not the case for CTs, which are skipped here
            if len(images) != 1:
                continue

            # resolve the file relative to the given directory instead of
            # depending on a global lookup table created by the caller
            dicom_images.append(
                pydicom.dcmread(
                    os.path.join(dicom_dir, *images[0].ReferencedFileID),
                    stop_before_pixels=True,
                )
            )
    return dicom_images


if __name__ == "__main__":
    # Script entry point: read the DICOMDIR of every given directory, find
    # projection, reconstructions and attenuation map for each study, copy
    # these images into the dataset directory and store their
    # meta-information in a CSV file.
    parser = argparse.ArgumentParser(
        description="Prepare a dataset from DICOM directories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dicom_dirs",
        type=str,
        nargs="+",
        help="paths to DICOMDIR files or directories containing one of them",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="directory where images, meta-information and the logs are stored",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory of --dataset_dir where images are stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        default="meta.csv",
        help="CSV file under --dataset_dir where meta-information is stored",
    )
    add_logging_args(
        parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
    )
    args = parser.parse_args()

    # if a DICOMDIR file is given directly, use the directory containing it
    args.dicom_dirs = [
        (os.path.dirname(_file) if os.path.isfile(_file) else _file)
        for _file in args.dicom_dirs
    ]
    # images, meta-information and logs are all located under the dataset directory
    args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
    args.logfile = os.path.join(args.dataset_dir, args.logfile)

    os.makedirs(args.dataset_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)

    # the logger is a module-level name which is used by the functions above
    global logger
    logger = get_logger_by_args(args)

    try:
        # collect the patient records of all DICOM directories and remember
        # the directory each patient originates from
        patients = []
        dicom_dir_by_patient: Dict[str, str] = {}
        for dicom_dir in args.dicom_dirs:
            dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
            for patient in dataset.patient_records:
                if patient.PatientID in dicom_dir_by_patient:
                    raise ValueError(
                        f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
                    )
                dicom_dir_by_patient[patient.PatientID] = dicom_dir
                patients.append(patient)

        _id = 1
        if os.path.exists(args.meta_csv):
            data = pd.read_csv(args.meta_csv)
            if len(data) > 0:
                # continue after the highest id in use so that neither ids
                # nor image files of earlier runs are reused
                _id = int(data[headers.id].max()) + 1
        else:
            data = pd.DataFrame({key: [] for key in vars(headers)})

        for i, patient in enumerate(patients, start=1):
            logger.debug(
                f"Process patient {str(i):>3}/{len(patients)} - {patient.PatientName.given_name}, {patient.PatientName.family_name}:"
            )

            dicom_images = get_relevant_images(
                patient, dicom_dir_by_patient[patient.PatientID]
            )

            projections = find_projections(dicom_images)
            logger.debug(
                f"- Found {len(projections)}: projections with protocols {list(map(lambda p: get_protocol(p), projections))}"
            )

            for projection in projections:
                protocol = get_protocol(projection)

                # headers, file name postfixes and correction flags are
                # listed in the same order
                recon_headers = [
                    headers.file_recon_nac_nsc,
                    headers.file_recon_nac_sc,
                    headers.file_recon_ac_nsc,
                    headers.file_recon_ac_sc,
                ]
                recon_postfixes = [
                    "recon_nac_nsc",
                    "recon_nac_sc",
                    "recon_ac_nsc",
                    "recon_ac_sc",
                ]
                reconstructions = [
                    find_reconstruction(
                        dicom_images,
                        projection,
                        scatter_corrected=sc,
                        attenuation_corrected=ac,
                    )
                    for ac, sc in [
                        (False, False),
                        (False, True),
                        (True, False),
                        (True, True),
                    ]
                ]
                mu_map = find_attenuation_map(dicom_images, projection, reconstructions)
                volumes = [*reconstructions, mu_map]

                # extract pixel spacings and make sure that they are equal for all reconstruction images
                pixel_spacings = [
                    np.array(
                        [
                            float(value)
                            for value in (*image.PixelSpacing, image.SliceThickness)
                        ]
                    )
                    for image in volumes
                ]
                if not all(
                    (spacing == pixel_spacings[0]).all() for spacing in pixel_spacings
                ):
                    raise ValueError(
                        f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
                    )
                pixel_spacing = pixel_spacings[0]

                # use the shape with the fewest slices, all other images will be aligned to that
                shapes = [
                    np.array(
                        [
                            int(image.Rows),
                            int(image.Columns),
                            int(image.NumberOfSlices),
                        ]
                    )
                    for image in volumes
                ]
                shape = min(shapes, key=lambda _shape: _shape[2])

                # extract energy windows and sort them by their lower limit
                # in descending order: the first one is stored as the peak
                # window and the second one as the scatter window
                energy_windows = [
                    (
                        float(ew.EnergyWindowRangeSequence[0].EnergyWindowLowerLimit),
                        float(ew.EnergyWindowRangeSequence[0].EnergyWindowUpperLimit),
                    )
                    for ew in projection.EnergyWindowInformationSequence
                ]
                energy_windows.sort(key=lambda ew: ew[0], reverse=True)

                radiopharmaceutical = projection.RadiopharmaceuticalInformationSequence[0]
                radionuclide = radiopharmaceutical.RadionuclideCodeSequence[0]
                detector = projection.DetectorInformationSequence[0]
                rotation = projection.RotationInformationSequence[0]

                row = {
                    headers.id: _id,
                    headers.patient_id: projection.PatientID,
                    headers.age: parse_age(projection.PatientAge),
                    headers.sex: projection.PatientSex,
                    headers.weight: float(projection.PatientWeight),
                    headers.size: float(projection.PatientSize),
                    headers.patient_position: projection.PatientPosition,
                    headers.protocol: protocol,
                    headers.datetime_acquisition: DICOMTime.Series.to_datetime(
                        projection
                    ),
                    headers.datetime_reconstruction: DICOMTime.Series.to_datetime(
                        reconstructions[0]
                    ),
                    headers.pixel_spacing_x: pixel_spacing[0],
                    headers.pixel_spacing_y: pixel_spacing[1],
                    headers.pixel_spacing_z: pixel_spacing[2],
                    headers.shape_x: shape[0],
                    headers.shape_y: shape[1],
                    headers.shape_z: shape[2],
                    headers.radiopharmaceutical: radiopharmaceutical.Radiopharmaceutical,
                    headers.radionuclide_dose: radiopharmaceutical.RadionuclideTotalDose,
                    headers.radionuclide_code: radionuclide.CodeValue,
                    headers.radionuclide_meaning: radionuclide.CodeMeaning,
                    headers.energy_window_peak_lower: energy_windows[0][0],
                    headers.energy_window_peak_upper: energy_windows[0][1],
                    headers.energy_window_scatter_lower: energy_windows[1][0],
                    headers.energy_window_scatter_upper: energy_windows[1][1],
                    headers.detector_count: len(projection.DetectorInformationSequence),
                    headers.collimator_type: detector.CollimatorType,
                    headers.rotation_start: float(rotation.StartAngle),
                    headers.rotation_step: float(rotation.AngularStep),
                    headers.rotation_scan_arc: float(rotation.ScanArc),
                }

                _filename_base = f"{_id:04d}-{protocol.lower()}"
                _ext = "dcm"
                img_headers = [
                    headers.file_projection,
                    *recon_headers,
                    headers.file_mu_map,
                ]
                img_postfixes = ["projection", *recon_postfixes, "mu_map"]
                images = [projection, *reconstructions, mu_map]
                for header, image, postfix in zip(img_headers, images, img_postfixes):
                    # read the file again because the images handled so far
                    # were loaded without their pixel data
                    _image = pydicom.dcmread(image.filename)
                    filename = f"{_filename_base}-{postfix}.{_ext}"
                    pydicom.dcmwrite(os.path.join(args.images_dir, filename), _image)
                    # store the file name in the according column of the row
                    row[header] = filename

                _id += 1
                row = pd.DataFrame(row, index=[0])
                data = pd.concat((data, row), ignore_index=True)

        data.to_csv(args.meta_csv, index=False)
    except Exception as e:
        # make sure that errors end up in the log file before the script stops
        logger.error(e)
        raise