From ae7a78ba551d907ae257c5999a71c90df9bb49fb Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 1 Sep 2022 10:50:43 +0200
Subject: [PATCH] update prepare script to deal with new structure (scatter
 corrections and attenuation corrections)

---
 mu_map/data/prepare.py | 505 ++++++++++++++++++++++-------------------
 1 file changed, 270 insertions(+), 235 deletions(-)

diff --git a/mu_map/data/prepare.py b/mu_map/data/prepare.py
index 3f5adfc..e123746 100644
--- a/mu_map/data/prepare.py
+++ b/mu_map/data/prepare.py
@@ -48,8 +48,10 @@ headers.rotation_start = "rotation_start"
 headers.rotation_step = "rotation_step"
 headers.rotation_scan_arc = "rotation_scan_arc"
 headers.file_projection = "file_projection"
-headers.file_recon_ac = "file_recon_ac"
-headers.file_recon_no_ac = "file_recon_no_ac"
+headers.file_recon_ac_sc = "file_recon_ac_sc"
+headers.file_recon_nac_sc = "file_recon_nac_sc"
+headers.file_recon_ac_nsc = "file_recon_ac_nsc"
+headers.file_recon_nac_nsc = "file_recon_nac_nsc"
 headers.file_mu_map = "file_mu_map"
 
 
@@ -93,7 +95,7 @@ def parse_age(patient_age: str) -> int:
     return int(_num)
 
 
-def get_projection(
+def get_projections(
     dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
 ) -> pydicom.dataset.FileDataset:
     """
@@ -107,18 +109,17 @@ def get_projection(
     _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
     dicom_images = list(_filter)
 
-    if len(dicom_images) != 1:
-        raise ValueError(
-            f"No or multiple projections {len(dicom_images)} for protocol {protocol.name} available"
-        )
+    if len(dicom_images) == 0:
+        raise ValueError(f"No projection for protocol {protocol.name} available")
 
-    return dicom_images[0]
+    return dicom_images
 
 
-def get_reconstruction(
+def get_reconstructions(
     dicom_images: List[pydicom.dataset.FileDataset],
     protocol: MyocardialProtocol,
-    corrected: bool = True,
+    scatter_corrected: bool,
+    attenuation_corrected: bool,
 ) -> pydicom.dataset.FileDataset:
     """
     Extract a SPECT reconstruction from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.
@@ -127,45 +128,42 @@ def get_reconstruction(
 
     :param dicom_images: list of DICOM images of a study
     :param protocol: the protocol for which the projection images should be extracted
-    :param corrected: extract an attenuation or non-attenuation corrected image
+    :param attenuation_corrected: extract an attenuation or non-attenuation corrected image
+    :param scatter_corrected: extract the image to which scatter correction was applied
     :return: the extracted DICOM image
     """
     _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
-    _filter = filter(
-        lambda image: protocol.name in image.SeriesDescription, _filter
-    )
+    _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
     _filter = filter(
         lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter
     )
 
-    if corrected:
-        _filter = filter(
-            lambda image: "AC" in image.SeriesDescription
-            and "NoAC" not in image.SeriesDescription,
-            _filter,
-        )
-    else:
-        _filter = filter(lambda image: "NoAC" in image.SeriesDescription, _filter)
+    if scatter_corrected and attenuation_corrected:
+        filter_str = " SC - AC "
+    elif not scatter_corrected and attenuation_corrected:
+        filter_str = " NoSC - AC "
+    elif scatter_corrected and not attenuation_corrected:
+        filter_str = " SC  - NoAC "
+    elif not scatter_corrected and not attenuation_corrected:
+        filter_str = " NoSC - NoAC "
+    _filter = filter(lambda image: filter_str in image.SeriesDescription, _filter)
 
     # for SPECT reconstructions created in clinical studies this value exists and is set to 'APEX_TO_BASE'
     # for the reconstructions with attenuation maps it does not exist
-    _filter = filter(
-        lambda image: not hasattr(image, "SliceProgressionDirection"), _filter
-    )
+    # _filter = filter(
+    # lambda image: not hasattr(image, "SliceProgressionDirection"), _filter
+    # )
 
     dicom_images = list(_filter)
-    dicom_images.sort(key=lambda image: parse_series_time(image), reverse=True)
-
     if len(dicom_images) == 0:
-        _str = "AC" if corrected else "NoAC"
         raise ValueError(
-            f"{_str} Reconstruction for protocol {protocol.name} is not available"
+            f"'{filter_str}' Reconstruction for protocol {protocol.name} is not available"
         )
 
-    return dicom_images[0]
+    return dicom_images
 
 
-def get_attenuation_map(
+def get_attenuation_maps(
     dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
 ) -> pydicom.dataset.FileDataset:
     """
@@ -177,20 +175,19 @@ def get_attenuation_map(
     :return: the extracted DICOM image
     """
     _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
+    _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
     _filter = filter(
-        lambda image: protocol.name in image.SeriesDescription, _filter
+        lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter
     )
-    _filter = filter(lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter)
     _filter = filter(lambda image: " µ-map]" in image.SeriesDescription, _filter)
-    dicom_images = list(_filter)
-    dicom_images.sort(key=lambda image: parse_series_time(image), reverse=True)
 
+    dicom_images = list(_filter)
     if len(dicom_images) == 0:
         raise ValueError(
             f"Attenuation map for protocol {protocol.name} is not available"
         )
 
-    return dicom_images[0]
+    return dicom_images
 
 
 if __name__ == "__main__":
@@ -222,30 +219,6 @@ if __name__ == "__main__":
         default="meta.csv",
         help="CSV file under --dataset_dir where meta-information is stored",
     )
-    parser.add_argument(
-        "--prefix_projection",
-        type=str,
-        default="projection",
-        help="prefix used to store DICOM images of projections - format <id>-<protocol>-<prefix>.dcm",
-    )
-    parser.add_argument(
-        "--prefix_mu_map",
-        type=str,
-        default="mu_map",
-        help="prefix used to store DICOM images of attenuation maps - format <id>-<protocol>-<prefix>.dcm",
-    )
-    parser.add_argument(
-        "--prefix_recon_ac",
-        type=str,
-        default="recon_ac",
-        help="prefix used to store DICOM images of reconstructions with attenuation correction - format <id>-<protocol>-<prefix>.dcm",
-    )
-    parser.add_argument(
-        "--prefix_recon_no_ac",
-        type=str,
-        default="recon_no_ac",
-        help="prefix used to store DICOM images of reconstructions without attenuation correction - format <id>-<protocol>-<prefix>.dcm",
-    )
     add_logging_args(
         parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
     )
@@ -294,7 +267,7 @@ if __name__ == "__main__":
             studies = list(
                 filter(
                     lambda child: child.DirectoryRecordType == "STUDY",
-                    # and child.StudyDescription == "Myokardszintigraphie",
+                    # and child.StudyDescription == "Myokardszintigraphie", # filter is disabled because there is a study without this description and only such studies are exported anyway
                     patient.children,
                 )
             )
@@ -304,7 +277,8 @@ if __name__ == "__main__":
             for study in studies:
                 series = list(
                     filter(
-                        lambda child: child.DirectoryRecordType == "SERIES", study.children
+                        lambda child: child.DirectoryRecordType == "SERIES",
+                        study.children,
                     )
                 )
                 for _series in series:
@@ -342,7 +316,7 @@ if __name__ == "__main__":
                     if (
                         len(
                             data[
-                                (data[headers.patient_id] == patient.PatientID)
+                                (data[headers.patient_id] == int(patient.PatientID))
                                 & (data[headers.protocol] == protocol.name)
                             ]
                         )
@@ -353,185 +327,246 @@ if __name__ == "__main__":
                         )
                         continue
 
+                    extractions = [
+                        {
+                            "function": get_projections,
+                            "kwargs": {},
+                            "reconstruction": False,
+                            "prefix": "projection",
+                            "header": headers.file_projection,
+                        },
+                        {
+                            "function": get_reconstructions,
+                            "kwargs": {
+                                "attenuation_corrected": True,
+                                "scatter_corrected": True,
+                            },
+                            "reconstruction": True,
+                            "prefix": "recon_ac_sc",
+                            "header": headers.file_recon_ac_sc,
+                        },
+                        {
+                            "function": get_reconstructions,
+                            "kwargs": {
+                                "attenuation_corrected": True,
+                                "scatter_corrected": False,
+                            },
+                            "reconstruction": True,
+                            "prefix": "recon_ac_nsc",
+                            "header": headers.file_recon_ac_nsc,
+                        },
+                        {
+                            "function": get_reconstructions,
+                            "kwargs": {
+                                "attenuation_corrected": False,
+                                "scatter_corrected": True,
+                            },
+                            "reconstruction": True,
+                            "prefix": "recon_nac_sc",
+                            "header": headers.file_recon_nac_sc,
+                        },
+                        {
+                            "function": get_reconstructions,
+                            "kwargs": {
+                                "attenuation_corrected": False,
+                                "scatter_corrected": False,
+                            },
+                            "reconstruction": True,
+                            "prefix": "recon_nac_nsc",
+                            "header": headers.file_recon_nac_nsc,
+                        },
+                        {
+                            "function": get_attenuation_maps,
+                            "kwargs": {},
+                            "reconstruction": True,
+                            "prefix": "mu_map",
+                            "header": headers.file_mu_map,
+                        },
+                    ]
+
                     try:
-                        projection_image = get_projection(dicom_images, protocol=protocol)
-                        recon_ac = get_reconstruction(
-                            dicom_images, protocol=protocol, corrected=True
-                        )
-                        recon_noac = get_reconstruction(
-                            dicom_images, protocol=protocol, corrected=False
-                        )
-                        attenuation_map = get_attenuation_map(
-                            dicom_images, protocol=protocol
-                        )
+                        for extrac in extractions:
+                            _images = extrac["function"](
+                                dicom_images, protocol=protocol, **extrac["kwargs"]
+                            )
+                            _images.sort(key=parse_series_time)
+                            extrac["images"] = _images
                     except ValueError as e:
-                        logger.info(f"Skip {patient.PatientID}:{protocol.name} because {e}")
+                        logger.info(
+                            f"Skip {patient.PatientID}:{protocol.name} because {e}"
+                        )
                         continue
 
-                    recon_images = [recon_ac, recon_noac, attenuation_map]
-
-                    # extract date times and assert that they are equal for all reconstruction images
-                    datetimes = list(map(parse_series_time, recon_images))
-                    _datetimes = sorted(datetimes, reverse=True)
-                    _datetimes_delta = list(map(lambda dt: _datetimes[0] - dt, _datetimes))
-                    _equal = all(
-                        map(lambda dt: dt < timedelta(seconds=300), _datetimes_delta)
-                    )
-                    assert (
-                        _equal
-                    ), f"Not all dates and times of the reconstructions are equal: {datetimes}"
-
-                    # extract pixel spacings and assert that they are equal for all reconstruction images
-                    _map_lists = map(
-                        lambda image: [*image.PixelSpacing, image.SliceThickness],
-                        recon_images,
-                    )
-                    _map_lists = map(
-                        lambda pixel_spacing: list(map(float, pixel_spacing)), _map_lists
-                    )
-                    _map_ndarrays = map(
-                        lambda pixel_spacing: np.array(pixel_spacing), _map_lists
-                    )
-                    pixel_spacings = list(_map_ndarrays)
-                    _equal = all(
-                        map(
-                            lambda pixel_spacing: (
-                                pixel_spacing == pixel_spacings[0]
-                            ).all(),
-                            pixel_spacings,
-                        )
-                    )
-                    assert (
-                        _equal
-                    ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
-                    pixel_spacing = pixel_spacings[0]
-
-                    # use the shape with the fewest slices, all other images will be aligned to that
-                    _map_lists = map(
-                        lambda image: [image.Rows, image.Columns, image.NumberOfSlices],
-                        recon_images,
-                    )
-                    _map_lists = map(lambda shape: list(map(int, shape)), _map_lists)
-                    _map_ndarrays = map(lambda shape: np.array(shape), _map_lists)
-                    shapes = list(_map_ndarrays)
-                    shapes.sort(key=lambda shape: shape[2])
-                    shape = shapes[0]
-
-                    # exctract and sort energy windows
-                    energy_windows = projection_image.EnergyWindowInformationSequence
-                    energy_windows = map(
-                        lambda ew: ew.EnergyWindowRangeSequence[0], energy_windows
-                    )
-                    energy_windows = map(
-                        lambda ew: (
-                            float(ew.EnergyWindowLowerLimit),
-                            float(ew.EnergyWindowUpperLimit),
-                        ),
-                        energy_windows,
-                    )
-                    energy_windows = list(energy_windows)
-                    energy_windows.sort(key=lambda ew: ew[0], reverse=True)
-
-                    # re-read images with pixel-level data and save accordingly
-                    projection_image = pydicom.dcmread(projection_image.filename)
-                    recon_ac = pydicom.dcmread(recon_ac.filename)
-                    recon_noac = pydicom.dcmread(recon_noac.filename)
-                    attenuation_map = pydicom.dcmread(attenuation_map.filename)
-
-                    _filename_base = f"{_id:04d}-{protocol.name.lower()}"
-                    _ext = "dcm"
-                    _filename_projection = f"{_filename_base}-{args.prefix_projection}.{_ext}"
-                    _filename_recon_ac = f"{_filename_base}-{args.prefix_recon_ac}.{_ext}"
-                    _filename_recon_no_ac = f"{_filename_base}-{args.prefix_recon_no_ac}.{_ext}"
-                    _filename_mu_map = f"{_filename_base}-{args.prefix_mu_map}.{_ext}"
-                    pydicom.dcmwrite(
-                        os.path.join(
-                            args.images_dir,
-                            _filename_projection,
-                        ),
-                        projection_image,
-                    )
-                    pydicom.dcmwrite(
-                        os.path.join(
-                            args.images_dir,
-                            _filename_recon_ac,
-                        ),
-                        recon_ac,
-                    )
-                    pydicom.dcmwrite(
-                        os.path.join(
-                            args.images_dir,
-                            _filename_recon_no_ac,
-                        ),
-                        recon_noac,
-                    )
-                    pydicom.dcmwrite(
-                        os.path.join(
-                            args.images_dir,
-                            _filename_mu_map,
-                        ),
-                        attenuation_map,
+                    num_images = min(
+                        list(map(lambda extrac: len(extrac["images"]), extractions))
                     )
 
-                    row = {
-                        headers.id: _id,
-                        headers.patient_id: projection_image.PatientID,
-                        headers.age: parse_age(projection_image.PatientAge),
-                        headers.weight: float(projection_image.PatientWeight),
-                        headers.size: float(projection_image.PatientSize),
-                        headers.protocol: protocol.name,
-                        headers.datetime_acquisition: parse_series_time(projection_image),
-                        headers.datetime_reconstruction: datetimes[0],
-                        headers.pixel_spacing_x: pixel_spacing[0],
-                        headers.pixel_spacing_y: pixel_spacing[1],
-                        headers.pixel_spacing_z: pixel_spacing[2],
-                        headers.shape_x: shape[0],
-                        headers.shape_y: shape[1],
-                        headers.shape_z: shape[2],
-                        headers.radiopharmaceutical: projection_image.RadiopharmaceuticalInformationSequence[
-                            0
-                        ].Radiopharmaceutical,
-                        headers.radionuclide_dose: projection_image.RadiopharmaceuticalInformationSequence[
-                            0
-                        ].RadionuclideTotalDose,
-                        headers.radionuclide_code: projection_image.RadiopharmaceuticalInformationSequence[
-                            0
-                        ]
-                        .RadionuclideCodeSequence[0]
-                        .CodeValue,
-                        headers.radionuclide_meaning: projection_image.RadiopharmaceuticalInformationSequence[
-                            0
+                    # ATTENTION: special case for mu maps - additional ones may have been saved by previous test runs of the workflow
+                    # images are sorted by ascending series time, so keeping the last `num_images` entries keeps only the most recent ones
+                    if num_images < len(extractions[-1]["images"]):
+                        _len = len(extractions[-1]["images"])
+                        extractions[-1]["images"] = extractions[-1]["images"][
+                            (_len - num_images) :
                         ]
-                        .RadionuclideCodeSequence[0]
-                        .CodeMeaning,
-                        headers.energy_window_peak_lower: energy_windows[0][0],
-                        headers.energy_window_peak_upper: energy_windows[0][1],
-                        headers.energy_window_scatter_lower: energy_windows[1][0],
-                        headers.energy_window_scatter_upper: energy_windows[1][1],
-                        headers.detector_count: len(
-                            projection_image.DetectorInformationSequence
-                        ),
-                        headers.collimator_type: projection_image.DetectorInformationSequence[
-                            0
-                        ].CollimatorType,
-                        headers.rotation_start: float(
-                            projection_image.RotationInformationSequence[0].StartAngle
-                        ),
-                        headers.rotation_step: float(
-                            projection_image.RotationInformationSequence[0].AngularStep
-                        ),
-                        headers.rotation_scan_arc: float(
-                            projection_image.RotationInformationSequence[0].ScanArc
-                        ),
-                        headers.file_projection: _filename_projection,
-                        headers.file_recon_ac: _filename_recon_ac,
-                        headers.file_recon_no_ac: _filename_recon_no_ac,
-                        headers.file_mu_map: _filename_mu_map,
-                    }
-                    _id += 1
-
-                    row = pd.DataFrame(row, index=[0])
-                    data = pd.concat((data, row), ignore_index=True)
+
+                    for j in range(num_images):
+                        _recon_images = filter(
+                            lambda extrac: extrac["reconstruction"], extractions
+                        )
+                        recon_images = list(
+                            map(lambda extrac: extrac["images"][j], _recon_images)
+                        )
+
+                        # extract date times and assert that they are equal for all reconstruction images
+                        datetimes = list(map(parse_series_time, recon_images))
+                        _datetimes = sorted(datetimes, reverse=True)
+                        _datetimes_delta = list(
+                            map(lambda dt: _datetimes[0] - dt, _datetimes)
+                        )
+                        # note: somehow the images receive slightly different timestamps, maybe this depends on time to save and computation time
+                        # thus, a 10 minute interval is allowed here
+                        _equal = all(
+                            map(lambda dt: dt < timedelta(minutes=10), _datetimes_delta)
+                        )
+                        assert (
+                            _equal
+                        ), f"Not all dates and times of the reconstructions are equal: {datetimes}"
+
+                        # extract pixel spacings and assert that they are equal for all reconstruction images
+                        _map_lists = map(
+                            lambda image: [*image.PixelSpacing, image.SliceThickness],
+                            recon_images,
+                        )
+                        _map_lists = map(
+                            lambda pixel_spacing: list(map(float, pixel_spacing)),
+                            _map_lists,
+                        )
+                        _map_ndarrays = map(
+                            lambda pixel_spacing: np.array(pixel_spacing), _map_lists
+                        )
+                        pixel_spacings = list(_map_ndarrays)
+                        _equal = all(
+                            map(
+                                lambda pixel_spacing: (
+                                    pixel_spacing == pixel_spacings[0]
+                                ).all(),
+                                pixel_spacings,
+                            )
+                        )
+                        assert (
+                            _equal
+                        ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
+                        pixel_spacing = pixel_spacings[0]
+
+                        # use the shape with the fewest slices, all other images will be aligned to that
+                        _map_lists = map(
+                            lambda image: [
+                                image.Rows,
+                                image.Columns,
+                                image.NumberOfSlices,
+                            ],
+                            recon_images,
+                        )
+                        _map_lists = map(
+                            lambda shape: list(map(int, shape)), _map_lists
+                        )
+                        _map_ndarrays = map(lambda shape: np.array(shape), _map_lists)
+                        shapes = list(_map_ndarrays)
+                        shapes.sort(key=lambda shape: shape[2])
+                        shape = shapes[0]
+
+                        projection_image = extractions[0]["images"][j]
+                        # extract and sort energy windows
+                        energy_windows = (
+                            projection_image.EnergyWindowInformationSequence
+                        )
+                        energy_windows = map(
+                            lambda ew: ew.EnergyWindowRangeSequence[0], energy_windows
+                        )
+                        energy_windows = map(
+                            lambda ew: (
+                                float(ew.EnergyWindowLowerLimit),
+                                float(ew.EnergyWindowUpperLimit),
+                            ),
+                            energy_windows,
+                        )
+                        energy_windows = list(energy_windows)
+                        energy_windows.sort(key=lambda ew: ew[0], reverse=True)
+
+                        row = {
+                            headers.id: _id,
+                            headers.patient_id: projection_image.PatientID,
+                            headers.age: parse_age(projection_image.PatientAge),
+                            headers.weight: float(projection_image.PatientWeight),
+                            headers.size: float(projection_image.PatientSize),
+                            headers.protocol: protocol.name,
+                            headers.datetime_acquisition: parse_series_time(
+                                projection_image
+                            ),
+                            headers.datetime_reconstruction: datetimes[0],
+                            headers.pixel_spacing_x: pixel_spacing[0],
+                            headers.pixel_spacing_y: pixel_spacing[1],
+                            headers.pixel_spacing_z: pixel_spacing[2],
+                            headers.shape_x: shape[0],
+                            headers.shape_y: shape[1],
+                            headers.shape_z: shape[2],
+                            headers.radiopharmaceutical: projection_image.RadiopharmaceuticalInformationSequence[
+                                0
+                            ].Radiopharmaceutical,
+                            headers.radionuclide_dose: projection_image.RadiopharmaceuticalInformationSequence[
+                                0
+                            ].RadionuclideTotalDose,
+                            headers.radionuclide_code: projection_image.RadiopharmaceuticalInformationSequence[
+                                0
+                            ]
+                            .RadionuclideCodeSequence[0]
+                            .CodeValue,
+                            headers.radionuclide_meaning: projection_image.RadiopharmaceuticalInformationSequence[
+                                0
+                            ]
+                            .RadionuclideCodeSequence[0]
+                            .CodeMeaning,
+                            headers.energy_window_peak_lower: energy_windows[0][0],
+                            headers.energy_window_peak_upper: energy_windows[0][1],
+                            headers.energy_window_scatter_lower: energy_windows[1][0],
+                            headers.energy_window_scatter_upper: energy_windows[1][1],
+                            headers.detector_count: len(
+                                projection_image.DetectorInformationSequence
+                            ),
+                            headers.collimator_type: projection_image.DetectorInformationSequence[
+                                0
+                            ].CollimatorType,
+                            headers.rotation_start: float(
+                                projection_image.RotationInformationSequence[
+                                    0
+                                ].StartAngle
+                            ),
+                            headers.rotation_step: float(
+                                projection_image.RotationInformationSequence[
+                                    0
+                                ].AngularStep
+                            ),
+                            headers.rotation_scan_arc: float(
+                                projection_image.RotationInformationSequence[0].ScanArc
+                            ),
+                        }
+
+                        _filename_base = f"{_id:04d}-{protocol.name.lower()}"
+                        _ext = "dcm"
+                        _images = list(
+                            map(lambda extrac: extrac["images"][j], extractions)
+                        )
+                        for _image, extrac in zip(_images, extractions):
+                            image = pydicom.dcmread(_image.filename)
+                            filename = f"{_filename_base}-{extrac['prefix']}.{_ext}"
+                            pydicom.dcmwrite(
+                                os.path.join(args.images_dir, filename), image
+                            )
+                            row[extrac["header"]] = filename
+
+                        _id += 1
+                        row = pd.DataFrame(row, index=[0])
+                        data = pd.concat((data, row), ignore_index=True)
 
         data.to_csv(args.meta_csv, index=False)
     except Exception as e:
-- 
GitLab