From 7e9d42da2d27b403b5c10208ff1bec95522ea2d0 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 25 Aug 2022 16:39:10 +0200
Subject: [PATCH] adapt preparse script to new study descriptions

---
 mu_map/data/prepare.py | 517 +++++++++++++++++++++--------------------
 1 file changed, 268 insertions(+), 249 deletions(-)

diff --git a/mu_map/data/prepare.py b/mu_map/data/prepare.py
index a6ac8f2..2037fd7 100644
--- a/mu_map/data/prepare.py
+++ b/mu_map/data/prepare.py
@@ -11,6 +11,9 @@ import pydicom
 from mu_map.logging import add_logging_args, get_logger_by_args
 
 
+STUDY_DESCRIPTION = "µ-map_study"
+
+
 class MyocardialProtocol(Enum):
     Stress = 1
     Rest = 2
@@ -44,6 +47,7 @@ headers.collimator_type = "collimator_type"
 headers.rotation_start = "rotation_start"
 headers.rotation_step = "rotation_step"
 headers.rotation_scan_arc = "rotation_scan_arc"
+headers.file_projection = "file_projection"
 headers.file_recon_ac = "file_recon_ac"
 headers.file_recon_no_ac = "file_recon_no_ac"
 headers.file_mu_map = "file_mu_map"
@@ -128,22 +132,25 @@ def get_reconstruction(
     """
     _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
     _filter = filter(
-        lambda image: protocol.name in image.SeriesDescription, dicom_images
+        lambda image: protocol.name in image.SeriesDescription, _filter
+    )
+    _filter = filter(
+        lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter
     )
 
     if corrected:
         _filter = filter(
             lambda image: "AC" in image.SeriesDescription
             and "NoAC" not in image.SeriesDescription,
-            dicom_images,
+            _filter,
         )
     else:
-        _filter = filter(lambda image: "NoAC" in image.SeriesDescription, dicom_images)
+        _filter = filter(lambda image: "NoAC" in image.SeriesDescription, _filter)
 
     # for SPECT reconstructions created in clinical studies this value exists and is set to 'APEX_TO_BASE'
     # for the reconstructions with attenuation maps it does not exist
     _filter = filter(
-        lambda image: not hasattr(image, "SliceProgressionDirection"), dicom_images
+        lambda image: not hasattr(image, "SliceProgressionDirection"), _filter
     )
 
     dicom_images = list(_filter)
@@ -171,9 +178,10 @@ def get_attenuation_map(
     """
     _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
     _filter = filter(
-        lambda image: protocol.name in image.SeriesDescription, dicom_images
+        lambda image: protocol.name in image.SeriesDescription, _filter
     )
-    _filter = filter(lambda image: "µ-map" in image.SeriesDescription, dicom_images)
+    _filter = filter(lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter)
+    _filter = filter(lambda image: " µ-map]" in image.SeriesDescription, _filter)
     dicom_images = list(_filter)
     dicom_images.sort(key=lambda image: parse_series_time(image), reverse=True)
 
@@ -260,263 +268,274 @@ if __name__ == "__main__":
     global logger
     logger = get_logger_by_args(args)
 
-    patients = []
-    dicom_dir_by_patient: Dict[str, str] = {}
-    for dicom_dir in args.dicom_dirs:
-        dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
-        for patient in dataset.patient_records:
-            assert (
-                patient.PatientID not in dicom_dir_by_patient
-            ), f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
-            dicom_dir_by_patient[patient.PatientID] = dicom_dir
-            patients.append(patient)
-
-    _id = 1
-    if os.path.exists(args.meta_csv):
-        data = pd.read_csv(args.meta_csv)
-        _id = int(data[headers.id].max())
-    else:
-        data = pd.DataFrame(dict([(key, []) for key in vars(headers).keys()]))
-
-    for i, patient in enumerate(patients, start=1):
-        logger.debug(f"Process patient {str(i):>3}/{len(patients)}:")
-
-        # get all myocardial scintigraphy studies
-        studies = list(
-            filter(
-                lambda child: child.DirectoryRecordType == "STUDY"
-                and child.StudyDescription == "Myokardszintigraphie",
-                patient.children,
-            )
-        )
-
-        # extract all dicom images
-        dicom_images = []
-        for study in studies:
-            series = list(
+    try:
+        patients = []
+        dicom_dir_by_patient: Dict[str, str] = {}
+        for dicom_dir in args.dicom_dirs:
+            dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
+            for patient in dataset.patient_records:
+                assert (
+                    patient.PatientID not in dicom_dir_by_patient
+                ), f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
+                dicom_dir_by_patient[patient.PatientID] = dicom_dir
+                patients.append(patient)
+
+        _id = 1
+        if os.path.exists(args.meta_csv):
+            data = pd.read_csv(args.meta_csv)
+            _id = int(data[headers.id].max())
+        else:
+            data = pd.DataFrame(dict([(key, []) for key in vars(headers).keys()]))
+
+        for i, patient in enumerate(patients, start=1):
+            logger.debug(f"Process patient {str(i):>3}/{len(patients)}:")
+
+            # get all myocardial scintigraphy studies
+            studies = list(
                 filter(
-                    lambda child: child.DirectoryRecordType == "SERIES", study.children
+                    lambda child: child.DirectoryRecordType == "STUDY",
+                    # and child.StudyDescription == "Myokardszintigraphie",
+                    patient.children,
                 )
             )
-            for _series in series:
-                images = list(
+
+            # extract all dicom images
+            dicom_images = []
+            for study in studies:
+                series = list(
                     filter(
-                        lambda child: child.DirectoryRecordType == "IMAGE",
-                        _series.children,
+                        lambda child: child.DirectoryRecordType == "SERIES", study.children
                     )
                 )
+                for _series in series:
+                    images = list(
+                        filter(
+                            lambda child: child.DirectoryRecordType == "IMAGE",
+                            _series.children,
+                        )
+                    )
 
-                # all SPECT data is stored as a single 3D array which means that it is a series with a single image
-                # this is not the case for CTs, which are skipped here
-                if len(images) != 1:
-                    continue
-
-                images = list(
-                    map(
-                        lambda image: pydicom.dcmread(
-                            os.path.join(
-                                dicom_dir_by_patient[patient.PatientID],
-                                *image.ReferencedFileID,
+                    # all SPECT data is stored as a single 3D array which means that it is a series with a single image
+                    # this is not the case for CTs, which are skipped here
+                    if len(images) != 1:
+                        continue
+
+                    images = list(
+                        map(
+                            lambda image: pydicom.dcmread(
+                                os.path.join(
+                                    dicom_dir_by_patient[patient.PatientID],
+                                    *image.ReferencedFileID,
+                                ),
+                                stop_before_pixels=True,
                             ),
-                            stop_before_pixels=True,
-                        ),
-                        images,
+                            images,
+                        )
                     )
-                )
-
-                if len(images) == 0:
-                    continue
-
-                dicom_images.append(images[0])
 
-            for protocol in MyocardialProtocol:
-                if (
-                    len(
-                        data[
-                            (data[headers.patient_id] == patient.PatientID)
-                            & (data[headers.protocol] == protocol.name)
-                        ]
+                    if len(images) == 0:
+                        continue
+
+                    dicom_images.append(images[0])
+
+                for protocol in MyocardialProtocol:
+                    if (
+                        len(
+                            data[
+                                (data[headers.patient_id] == patient.PatientID)
+                                & (data[headers.protocol] == protocol.name)
+                            ]
+                        )
+                        > 0
+                    ):
+                        logger.info(
+                            f"Skip {patient.PatientID}:{protocol.name} since it is already contained in the dataset"
+                        )
+                        continue
+
+                    try:
+                        projection_image = get_projection(dicom_images, protocol=protocol)
+                        recon_ac = get_reconstruction(
+                            dicom_images, protocol=protocol, corrected=True
+                        )
+                        recon_noac = get_reconstruction(
+                            dicom_images, protocol=protocol, corrected=False
+                        )
+                        attenuation_map = get_attenuation_map(
+                            dicom_images, protocol=protocol
+                        )
+                    except ValueError as e:
+                        logger.info(f"Skip {patient.PatientID}:{protocol.name} because {e}")
+                        continue
+
+                    recon_images = [recon_ac, recon_noac, attenuation_map]
+
+                    # extract date times and assert that they are equal for all reconstruction images
+                    datetimes = list(map(parse_series_time, recon_images))
+                    _datetimes = sorted(datetimes, reverse=True)
+                    _datetimes_delta = list(map(lambda dt: _datetimes[0] - dt, _datetimes))
+                    _equal = all(
+                        map(lambda dt: dt < timedelta(seconds=300), _datetimes_delta)
                     )
-                    > 0
-                ):
-                    logger.info(
-                        f"Skip {patient.PatientID}:{protocol.name} since it is already contained in the dataset"
+                    assert (
+                        _equal
+                    ), f"Not all dates and times of the reconstructions are equal: {datetimes}"
+
+                    # extract pixel spacings and assert that they are equal for all reconstruction images
+                    _map_lists = map(
+                        lambda image: [*image.PixelSpacing, image.SliceThickness],
+                        recon_images,
                     )
-                    continue
-
-                try:
-                    projection_image = get_projection(dicom_images, protocol=protocol)
-                    recon_ac = get_reconstruction(
-                        dicom_images, protocol=protocol, corrected=True
+                    _map_lists = map(
+                        lambda pixel_spacing: list(map(float, pixel_spacing)), _map_lists
                     )
-                    recon_noac = get_reconstruction(
-                        dicom_images, protocol=protocol, corrected=False
+                    _map_ndarrays = map(
+                        lambda pixel_spacing: np.array(pixel_spacing), _map_lists
                     )
-                    attenuation_map = get_attenuation_map(
-                        dicom_images, protocol=protocol
+                    pixel_spacings = list(_map_ndarrays)
+                    _equal = all(
+                        map(
+                            lambda pixel_spacing: (
+                                pixel_spacing == pixel_spacings[0]
+                            ).all(),
+                            pixel_spacings,
+                        )
                     )
-                except ValueError as e:
-                    logger.info(f"Skip {patient.PatientID}:{protocol.name} because {e}")
-                    continue
-
-                recon_images = [recon_ac, recon_noac, attenuation_map]
-
-                # extract date times and assert that they are equal for all reconstruction images
-                datetimes = list(map(parse_series_time, recon_images))
-                _datetimes = sorted(datetimes, reverse=True)
-                _datetimes_delta = list(map(lambda dt: _datetimes[0] - dt, _datetimes))
-                _equal = all(
-                    map(lambda dt: dt < timedelta(seconds=300), _datetimes_delta)
-                )
-                assert (
-                    _equal
-                ), f"Not all dates and times of the reconstructions are equal: {datetimes}"
-
-                # extract pixel spacings and assert that they are equal for all reconstruction images
-                _map_lists = map(
-                    lambda image: [*image.PixelSpacing, image.SliceThickness],
-                    recon_images,
-                )
-                _map_lists = map(
-                    lambda pixel_spacing: list(map(float, pixel_spacing)), _map_lists
-                )
-                _map_ndarrays = map(
-                    lambda pixel_spacing: np.array(pixel_spacing), _map_lists
-                )
-                pixel_spacings = list(_map_ndarrays)
-                _equal = all(
-                    map(
-                        lambda pixel_spacing: (
-                            pixel_spacing == pixel_spacings[0]
-                        ).all(),
-                        pixel_spacings,
+                    assert (
+                        _equal
+                    ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
+                    pixel_spacing = pixel_spacings[0]
+
+                    # extract shapes and assert that they are equal for all reconstruction images
+                    _map_lists = map(
+                        lambda image: [image.Rows, image.Columns, image.NumberOfSlices],
+                        recon_images,
+                    )
+                    _map_lists = map(lambda shape: list(map(int, shape)), _map_lists)
+                    _map_ndarrays = map(lambda shape: np.array(shape), _map_lists)
+                    shapes = list(_map_ndarrays)
+                    _equal = all(map(lambda shape: (shape == shapes[0]).all(), shapes))
+                    # assert _equal, f"Not all shapes of the reconstructions are equal: {shapes}"
+                    # print(shapes)
+                    shape = shapes[0]
+
+                    # exctract and sort energy windows
+                    energy_windows = projection_image.EnergyWindowInformationSequence
+                    energy_windows = map(
+                        lambda ew: ew.EnergyWindowRangeSequence[0], energy_windows
+                    )
+                    energy_windows = map(
+                        lambda ew: (
+                            float(ew.EnergyWindowLowerLimit),
+                            float(ew.EnergyWindowUpperLimit),
+                        ),
+                        energy_windows,
+                    )
+                    energy_windows = list(energy_windows)
+                    energy_windows.sort(key=lambda ew: ew[0], reverse=True)
+
+                    # re-read images with pixel-level data and save accordingly
+                    projection_image = pydicom.dcmread(projection_image.filename)
+                    recon_ac = pydicom.dcmread(recon_ac.filename)
+                    recon_noac = pydicom.dcmread(recon_noac.filename)
+                    attenuation_map = pydicom.dcmread(attenuation_map.filename)
+
+                    _filename_base = f"{_id:04d}-{protocol.name.lower()}"
+                    _ext = "dcm"
+                    _filename_projection = f"{_filename_base}-{args.prefix_projection}.{_ext}"
+                    _filename_recon_ac = f"{_filename_base}-{args.prefix_recon_ac}.{_ext}"
+                    _filename_recon_no_ac = f"{_filename_base}-{args.prefix_recon_no_ac}.{_ext}"
+                    _filename_mu_map = f"{_filename_base}-{args.prefix_mu_map}.{_ext}"
+                    pydicom.dcmwrite(
+                        os.path.join(
+                            args.images_dir,
+                            _filename_projection,
+                        ),
+                        projection_image,
+                    )
+                    pydicom.dcmwrite(
+                        os.path.join(
+                            args.images_dir,
+                            _filename_recon_ac,
+                        ),
+                        recon_ac,
+                    )
+                    pydicom.dcmwrite(
+                        os.path.join(
+                            args.images_dir,
+                            _filename_recon_no_ac,
+                        ),
+                        recon_noac,
+                    )
+                    pydicom.dcmwrite(
+                        os.path.join(
+                            args.images_dir,
+                            _filename_mu_map,
+                        ),
+                        attenuation_map,
                     )
-                )
-                assert (
-                    _equal
-                ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
-                pixel_spacing = pixel_spacings[0]
-
-                # extract shapes and assert that they are equal for all reconstruction images
-                _map_lists = map(
-                    lambda image: [image.Rows, image.Columns, image.NumberOfSlices],
-                    recon_images,
-                )
-                _map_lists = map(lambda shape: list(map(int, shape)), _map_lists)
-                _map_ndarrays = map(lambda shape: np.array(shape), _map_lists)
-                shapes = list(_map_ndarrays)
-                _equal = all(map(lambda shape: (shape == shapes[0]).all(), shapes))
-                # assert _equal, f"Not all shapes of the reconstructions are equal: {shapes}"
-                # print(shapes)
-                shape = shapes[0]
-
-                # exctract and sort energy windows
-                energy_windows = projection_image.EnergyWindowInformationSequence
-                energy_windows = map(
-                    lambda ew: ew.EnergyWindowRangeSequence[0], energy_windows
-                )
-                energy_windows = map(
-                    lambda ew: (
-                        float(ew.EnergyWindowLowerLimit),
-                        float(ew.EnergyWindowUpperLimit),
-                    ),
-                    energy_windows,
-                )
-                energy_windows = list(energy_windows)
-                energy_windows.sort(key=lambda ew: ew[0], reverse=True)
-
-                # re-read images with pixel-level data and save accordingly
-                projection_image = pydicom.dcmread(projection_image.filename)
-                recon_ac = pydicom.dcmread(recon_ac.filename)
-                recon_noac = pydicom.dcmread(recon_noac.filename)
-                attenuation_map = pydicom.dcmread(attenuation_map.filename)
-
-                pydicom.dcmwrite(
-                    os.path.join(
-                        args.images_dir,
-                        f"{_id:04d}-{protocol.name.lower()}-{args.prefix_projection}.dcm",
-                    ),
-                    projection_image,
-                )
-                pydicom.dcmwrite(
-                    os.path.join(
-                        args.images_dir,
-                        f"{_id:04d}-{protocol.name.lower()}-{args.prefix_recon_ac}.dcm",
-                    ),
-                    recon_ac,
-                )
-                pydicom.dcmwrite(
-                    os.path.join(
-                        args.images_dir,
-                        f"{_id:04d}-{protocol.name.lower()}-{args.prefix_recon_no_ac}.dcm",
-                    ),
-                    recon_noac,
-                )
-                pydicom.dcmwrite(
-                    os.path.join(
-                        args.images_dir,
-                        f"{_id:04d}-{protocol.name.lower()}-{args.prefix_mu_map}.dcm",
-                    ),
-                    attenuation_map,
-                )
 
-                row = {
-                    headers.id: _id,
-                    headers.patient_id: projection_image.PatientID,
-                    headers.age: parse_age(projection_image.PatientAge),
-                    headers.weight: float(projection_image.PatientWeight),
-                    headers.size: float(projection_image.PatientSize),
-                    headers.protocol: protocol.name,
-                    headers.datetime_acquisition: parse_series_time(projection_image),
-                    headers.datetime_reconstruction: datetimes[0],
-                    headers.pixel_spacing_x: pixel_spacing[0],
-                    headers.pixel_spacing_y: pixel_spacing[1],
-                    headers.pixel_spacing_z: pixel_spacing[2],
-                    headers.shape_x: shape[0],
-                    headers.shape_y: shape[1],
-                    headers.shape_z: shape[2],
-                    headers.radiopharmaceutical: projection_image.RadiopharmaceuticalInformationSequence[
-                        0
-                    ].Radiopharmaceutical,
-                    headers.radionuclide_dose: projection_image.RadiopharmaceuticalInformationSequence[
-                        0
-                    ].RadionuclideTotalDose,
-                    headers.radionuclide_code: projection_image.RadiopharmaceuticalInformationSequence[
-                        0
-                    ]
-                    .RadionuclideCodeSequence[0]
-                    .CodeValue,
-                    headers.radionuclide_meaning: projection_image.RadiopharmaceuticalInformationSequence[
-                        0
-                    ]
-                    .RadionuclideCodeSequence[0]
-                    .CodeMeaning,
-                    headers.energy_window_peak_lower: energy_windows[0][0],
-                    headers.energy_window_peak_upper: energy_windows[0][1],
-                    headers.energy_window_scatter_lower: energy_windows[1][0],
-                    headers.energy_window_scatter_upper: energy_windows[1][1],
-                    headers.detector_count: len(
-                        projection_image.DetectorInformationSequence
-                    ),
-                    headers.collimator_type: projection_image.DetectorInformationSequence[
-                        0
-                    ].CollimatorType,
-                    headers.rotation_start: float(
-                        projection_image.RotationInformationSequence[0].StartAngle
-                    ),
-                    headers.rotation_step: float(
-                        projection_image.RotationInformationSequence[0].AngularStep
-                    ),
-                    headers.rotation_scan_arc: float(
-                        projection_image.RotationInformationSequence[0].ScanArc
-                    ),
-                    headers.file_recon_ac: "filename_recon_ac.dcm",
-                    headers.file_recon_no_ac: "filename_recon_no_ac.dcm",
-                    headers.file_mu_map: "filanem_mu_map.dcm",
-                }
-                _id += 1
-
-                row = pd.DataFrame(row, index=[0])
-                data = pd.concat((data, row), ignore_index=True)
-
-data.to_csv(args.meta_csv, index=False)
+                    row = {
+                        headers.id: _id,
+                        headers.patient_id: projection_image.PatientID,
+                        headers.age: parse_age(projection_image.PatientAge),
+                        headers.weight: float(projection_image.PatientWeight),
+                        headers.size: float(projection_image.PatientSize),
+                        headers.protocol: protocol.name,
+                        headers.datetime_acquisition: parse_series_time(projection_image),
+                        headers.datetime_reconstruction: datetimes[0],
+                        headers.pixel_spacing_x: pixel_spacing[0],
+                        headers.pixel_spacing_y: pixel_spacing[1],
+                        headers.pixel_spacing_z: pixel_spacing[2],
+                        headers.shape_x: shape[0],
+                        headers.shape_y: shape[1],
+                        headers.shape_z: shape[2],
+                        headers.radiopharmaceutical: projection_image.RadiopharmaceuticalInformationSequence[
+                            0
+                        ].Radiopharmaceutical,
+                        headers.radionuclide_dose: projection_image.RadiopharmaceuticalInformationSequence[
+                            0
+                        ].RadionuclideTotalDose,
+                        headers.radionuclide_code: projection_image.RadiopharmaceuticalInformationSequence[
+                            0
+                        ]
+                        .RadionuclideCodeSequence[0]
+                        .CodeValue,
+                        headers.radionuclide_meaning: projection_image.RadiopharmaceuticalInformationSequence[
+                            0
+                        ]
+                        .RadionuclideCodeSequence[0]
+                        .CodeMeaning,
+                        headers.energy_window_peak_lower: energy_windows[0][0],
+                        headers.energy_window_peak_upper: energy_windows[0][1],
+                        headers.energy_window_scatter_lower: energy_windows[1][0],
+                        headers.energy_window_scatter_upper: energy_windows[1][1],
+                        headers.detector_count: len(
+                            projection_image.DetectorInformationSequence
+                        ),
+                        headers.collimator_type: projection_image.DetectorInformationSequence[
+                            0
+                        ].CollimatorType,
+                        headers.rotation_start: float(
+                            projection_image.RotationInformationSequence[0].StartAngle
+                        ),
+                        headers.rotation_step: float(
+                            projection_image.RotationInformationSequence[0].AngularStep
+                        ),
+                        headers.rotation_scan_arc: float(
+                            projection_image.RotationInformationSequence[0].ScanArc
+                        ),
+                        headers.file_projection: _filename_projection,
+                        headers.file_recon_ac: _filename_recon_ac,
+                        headers.file_recon_no_ac: _filename_recon_no_ac,
+                        headers.file_mu_map: _filename_mu_map,
+                    }
+                    _id += 1
+
+                    row = pd.DataFrame(row, index=[0])
+                    data = pd.concat((data, row), ignore_index=True)
+
+        data.to_csv(args.meta_csv, index=False)
+    except Exception as e:
+        logger.error(e)
+        raise e
-- 
GitLab