From a800094d38af1f4821368f2303b0a74817839a61 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 25 Aug 2022 11:15:15 +0200
Subject: [PATCH] add logging and descriptions to prepare

---
 mu_map/data/prepare.py | 254 +++++++++++++++++++++++++++++------------
 1 file changed, 181 insertions(+), 73 deletions(-)

diff --git a/mu_map/data/prepare.py b/mu_map/data/prepare.py
index a735802..a6ac8f2 100644
--- a/mu_map/data/prepare.py
+++ b/mu_map/data/prepare.py
@@ -2,12 +2,14 @@ import argparse
 from datetime import datetime, timedelta
 from enum import Enum
 import os
-from typing import List
+from typing import List, Dict
 
 import numpy as np
 import pandas as pd
 import pydicom
 
+from mu_map.logging import add_logging_args, get_logger_by_args
+
 
 class MyocardialProtocol(Enum):
     Stress = 1
@@ -47,7 +49,6 @@ headers.file_recon_no_ac = "file_recon_no_ac"
 headers.file_mu_map = "file_mu_map"
 
 
-
 def parse_series_time(dicom_image: pydicom.dataset.FileDataset) -> datetime:
     """
     Parse the date and time of a DICOM series object into a datetime object.
@@ -88,7 +89,9 @@ def parse_age(patient_age: str) -> int:
     return int(_num)
 
 
-def get_projection(dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol) -> pydicom.dataset.FileDataset:
+def get_projection(
+    dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
+) -> pydicom.dataset.FileDataset:
     """
     Extract the SPECT projection from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.
 
@@ -96,17 +99,23 @@ def get_projection(dicom_images: List[pydicom.dataset.FileDataset], protocol: My
     :param protocol: the protocol for which the projection images should be extracted
     :return: the extracted DICOM image
     """
-    dicom_images = filter(lambda image: "TOMO" in image.ImageType, dicom_images)
-    dicom_images = filter(lambda image: protocol.name in image.SeriesDescription, dicom_images)
-    dicom_images = list(dicom_images)
+    _filter = filter(lambda image: "TOMO" in image.ImageType, dicom_images)
+    _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
+    dicom_images = list(_filter)
 
     if len(dicom_images) != 1:
-        raise ValueError(f"No or multiple projections {len(dicom_images)} for protocol {protocol.name} available")
+        raise ValueError(
+            f"No or multiple projections {len(dicom_images)} for protocol {protocol.name} available"
+        )
 
     return dicom_images[0]
 
 
-def get_reconstruction(dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol, corrected:bool=True) -> pydicom.dataset.FileDataset:
+def get_reconstruction(
+    dicom_images: List[pydicom.dataset.FileDataset],
+    protocol: MyocardialProtocol,
+    corrected: bool = True,
+) -> pydicom.dataset.FileDataset:
     """
     Extract a SPECT reconstruction from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.
     The corrected flag can be used to either extract an attenuation corrected or a non-attenuation corrected image.
@@ -117,38 +126,41 @@ def get_reconstruction(dicom_images: List[pydicom.dataset.FileDataset], protocol
     :param corrected: extract an attenuation or non-attenuation corrected image
     :return: the extracted DICOM image
     """
-    dicom_images = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
-    dicom_images = filter(lambda image: protocol.name in image.SeriesDescription, dicom_images)
+    # Chain the filters: each stage must consume the previous `_filter`,
+    # otherwise only the criterion of the last filter would be applied.
+    _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
+    _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
 
     if corrected:
-        dicom_images = filter(
+        _filter = filter(
             lambda image: "AC" in image.SeriesDescription
             and "NoAC" not in image.SeriesDescription,
-            dicom_images,
+            _filter,
         )
-        dicom_images = list(dicom_images)
     else:
-        dicom_images = filter(
-            lambda image: "NoAC" in image.SeriesDescription, dicom_images
-        )
+        _filter = filter(lambda image: "NoAC" in image.SeriesDescription, _filter)
 
     # for SPECT reconstructions created in clinical studies this value exists and is set to 'APEX_TO_BASE'
     # for the reconstructions with attenuation maps it does not exist
-    dicom_images = filter(
-        lambda image: not hasattr(image, "SliceProgressionDirection"), dicom_images
+    _filter = filter(
+        lambda image: not hasattr(image, "SliceProgressionDirection"), _filter
     )
 
-    dicom_images = list(dicom_images)
+    dicom_images = list(_filter)
     dicom_images.sort(key=lambda image: parse_series_time(image), reverse=True)
 
     if len(dicom_images) == 0:
         _str = "AC" if corrected else "NoAC"
-        raise ValueError(f"{_str} Reconstruction for protocol {protocol.name} is not available")
+        raise ValueError(
+            f"{_str} Reconstruction for protocol {protocol.name} is not available"
+        )
 
     return dicom_images[0]
 
 
-def get_attenuation_map(dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol) -> pydicom.dataset.FileDataset:
+def get_attenuation_map(
+    dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
+) -> pydicom.dataset.FileDataset:
     """
     Extract an attenuation map from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.
     If there are multiple attenuation maps, they are sorted by acquisition date and the newest is returned.
@@ -157,14 +169,18 @@ def get_attenuation_map(dicom_images: List[pydicom.dataset.FileDataset], protoco
     :param protocol: the protocol for which the projection images should be extracted
     :return: the extracted DICOM image
     """
-    dicom_images = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
-    dicom_images = filter(lambda image: protocol.name in image.SeriesDescription, dicom_images)
-    dicom_images = filter(lambda image: "µ-map" in image.SeriesDescription, dicom_images)
-    dicom_images = list(dicom_images)
+    # Chain the filters: each stage must consume the previous `_filter`,
+    # otherwise only the criterion of the last filter would be applied.
+    _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
+    _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
+    _filter = filter(lambda image: "µ-map" in image.SeriesDescription, _filter)
+    dicom_images = list(_filter)
     dicom_images.sort(key=lambda image: parse_series_time(image), reverse=True)
 
     if len(dicom_images) == 0:
-        raise ValueError(f"Attenuation map for protocol {protocol.name} is not available")
+        raise ValueError(
+            f"Attenuation map for protocol {protocol.name} is not available"
+        )
 
     return dicom_images[0]
 
@@ -180,13 +196,51 @@ if __name__ == "__main__":
         nargs="+",
         help="paths to DICOMDIR files or directories containing one of them",
     )
-    parser.add_argument("--dataset_dir", type=str, required=True, help="")
-    parser.add_argument("--images_dir", type=str, default="images", help="")
-    parser.add_argument("--csv", type=str, default="data.csv", help="")
-    parser.add_argument("--prefix_projection", type=str, default="projection", help="")
-    parser.add_argument("--prefix_mu_map", type=str, default="mu_map", help="")
-    parser.add_argument("--prefix_recon_ac", type=str, default="recon_ac", help="")
-    parser.add_argument("--prefix_recon_no_ac", type=str, default="recon_no_ac", help="")
+    parser.add_argument(
+        "--dataset_dir",
+        type=str,
+        required=True,
+        help="directory where images, meta-information and the logs are stored",
+    )
+    parser.add_argument(
+        "--images_dir",
+        type=str,
+        default="images",
+        help="sub-directory of --dataset_dir where images are stored",
+    )
+    parser.add_argument(
+        "--meta_csv",
+        type=str,
+        default="meta.csv",
+        help="CSV file under --dataset_dir where meta-information is stored",
+    )
+    parser.add_argument(
+        "--prefix_projection",
+        type=str,
+        default="projection",
+        help="prefix used to store DICOM images of projections - format <id>-<protocol>-<prefix>.dcm",
+    )
+    parser.add_argument(
+        "--prefix_mu_map",
+        type=str,
+        default="mu_map",
+        help="prefix used to store DICOM images of attenuation maps - format <id>-<protocol>-<prefix>.dcm",
+    )
+    parser.add_argument(
+        "--prefix_recon_ac",
+        type=str,
+        default="recon_ac",
+        help="prefix used to store DICOM images of reconstructions with attenuation correction - format <id>-<protocol>-<prefix>.dcm",
+    )
+    parser.add_argument(
+        "--prefix_recon_no_ac",
+        type=str,
+        default="recon_no_ac",
+        help="prefix used to store DICOM images of reconstructions without attenuation correction - format <id>-<protocol>-<prefix>.dcm",
+    )
+    add_logging_args(
+        parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
+    )
     args = parser.parse_args()
 
     args.dicom_dirs = [
@@ -194,10 +248,20 @@ if __name__ == "__main__":
         for _file in args.dicom_dirs
     ]
     args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
-    args.csv = os.path.join(args.dataset_dir, args.csv)
+    args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
+    args.logfile = os.path.join(args.dataset_dir, args.logfile)
+
+    if not os.path.exists(args.dataset_dir):
+        os.mkdir(args.dataset_dir)
+
+    if not os.path.exists(args.images_dir):
+        os.mkdir(args.images_dir)
+
+    global logger
+    logger = get_logger_by_args(args)
 
     patients = []
-    dicom_dir_by_patient = {}
+    dicom_dir_by_patient: Dict[str, str] = {}
     for dicom_dir in args.dicom_dirs:
         dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
         for patient in dataset.patient_records:
@@ -207,22 +271,15 @@ if __name__ == "__main__":
             dicom_dir_by_patient[patient.PatientID] = dicom_dir
             patients.append(patient)
 
-    if not os.path.exists(args.dataset_dir):
-        os.mkdir(args.dataset_dir)
-
-    if not os.path.exists(args.images_dir):
-        os.mkdir(args.images_dir)
-
     _id = 1
-    if os.path.exists(args.csv):
-        data = pd.read_csv(args.csv)
+    if os.path.exists(args.meta_csv):
+        data = pd.read_csv(args.meta_csv)
         _id = int(data[headers.id].max())
     else:
         data = pd.DataFrame(dict([(key, []) for key in vars(headers).keys()]))
 
-
     for i, patient in enumerate(patients, start=1):
-        print(f"Process patient {str(i):>3}/{len(patients)}:")
+        logger.debug(f"Process patient {str(i):>3}/{len(patients)}:")
 
         # get all myocardial scintigraphy studies
         studies = list(
@@ -237,12 +294,15 @@ if __name__ == "__main__":
         dicom_images = []
         for study in studies:
             series = list(
-                filter(lambda child: child.DirectoryRecordType == "SERIES", study.children)
+                filter(
+                    lambda child: child.DirectoryRecordType == "SERIES", study.children
+                )
             )
             for _series in series:
                 images = list(
                     filter(
-                        lambda child: child.DirectoryRecordType == "IMAGE", _series.children
+                        lambda child: child.DirectoryRecordType == "IMAGE",
+                        _series.children,
                     )
                 )
 
@@ -254,7 +314,10 @@ if __name__ == "__main__":
                 images = list(
                     map(
                         lambda image: pydicom.dcmread(
-                            os.path.join(dicom_dir_by_patient[patient.PatientID], *image.ReferencedFileID),
+                            os.path.join(
+                                dicom_dir_by_patient[patient.PatientID],
+                                *image.ReferencedFileID,
+                            ),
                             stop_before_pixels=True,
                         ),
                         images,
@@ -266,19 +329,34 @@ if __name__ == "__main__":
 
                 dicom_images.append(images[0])
 
-
             for protocol in MyocardialProtocol:
-                if len(data[(data[headers.patient_id] == patient.PatientID) & (data[headers.protocol] == protocol.name)]) > 0:
-                    print(f"Skip {patient.PatientID}:{protocol.name} since it is already contained in the dataset")
+                if (
+                    len(
+                        data[
+                            (data[headers.patient_id] == patient.PatientID)
+                            & (data[headers.protocol] == protocol.name)
+                        ]
+                    )
+                    > 0
+                ):
+                    logger.info(
+                        f"Skip {patient.PatientID}:{protocol.name} since it is already contained in the dataset"
+                    )
                     continue
 
                 try:
                     projection_image = get_projection(dicom_images, protocol=protocol)
-                    recon_ac = get_reconstruction(dicom_images, protocol=protocol, corrected=True)
-                    recon_noac = get_reconstruction(dicom_images, protocol=protocol, corrected=False)
-                    attenuation_map = get_attenuation_map(dicom_images, protocol=protocol)
+                    recon_ac = get_reconstruction(
+                        dicom_images, protocol=protocol, corrected=True
+                    )
+                    recon_noac = get_reconstruction(
+                        dicom_images, protocol=protocol, corrected=False
+                    )
+                    attenuation_map = get_attenuation_map(
+                        dicom_images, protocol=protocol
+                    )
                 except ValueError as e:
-                    print(f"Skip {patient.PatientID}:{protocol.name} because {e}")
+                    logger.info(f"Skip {patient.PatientID}:{protocol.name} because {e}")
                     continue
 
                 recon_images = [recon_ac, recon_noac, attenuation_map]
@@ -287,25 +365,30 @@ if __name__ == "__main__":
                 datetimes = list(map(parse_series_time, recon_images))
                 _datetimes = sorted(datetimes, reverse=True)
                 _datetimes_delta = list(map(lambda dt: _datetimes[0] - dt, _datetimes))
-                _equal = all(map(lambda dt: dt < timedelta(seconds=300), _datetimes_delta))
+                _equal = all(
+                    map(lambda dt: dt < timedelta(seconds=300), _datetimes_delta)
+                )
                 assert (
                     _equal
                 ), f"Not all dates and times of the reconstructions are equal: {datetimes}"
 
                 # extract pixel spacings and assert that they are equal for all reconstruction images
-                pixel_spacings = map(
-                    lambda image: [*image.PixelSpacing, image.SliceThickness], recon_images
+                _map_lists = map(
+                    lambda image: [*image.PixelSpacing, image.SliceThickness],
+                    recon_images,
                 )
-                pixel_spacings = map(
-                    lambda pixel_spacing: list(map(float, pixel_spacing)), pixel_spacings
+                _map_lists = map(
+                    lambda pixel_spacing: list(map(float, pixel_spacing)), _map_lists
                 )
-                pixel_spacings = map(
-                    lambda pixel_spacing: np.array(pixel_spacing), pixel_spacings
+                _map_ndarrays = map(
+                    lambda pixel_spacing: np.array(pixel_spacing), _map_lists
                 )
-                pixel_spacings = list(pixel_spacings)
+                pixel_spacings = list(_map_ndarrays)
                 _equal = all(
                     map(
-                        lambda pixel_spacing: (pixel_spacing == pixel_spacings[0]).all(),
+                        lambda pixel_spacing: (
+                            pixel_spacing == pixel_spacings[0]
+                        ).all(),
                         pixel_spacings,
                     )
                 )
@@ -315,13 +398,13 @@ if __name__ == "__main__":
                 pixel_spacing = pixel_spacings[0]
 
                 # extract shapes and assert that they are equal for all reconstruction images
-                shapes = map(
+                _map_lists = map(
                     lambda image: [image.Rows, image.Columns, image.NumberOfSlices],
                     recon_images,
                 )
-                shapes = map(lambda shape: list(map(int, shape)), shapes)
-                shapes = map(lambda shape: np.array(shape), shapes)
-                shapes = list(shapes)
+                _map_lists = map(lambda shape: list(map(int, shape)), _map_lists)
+                _map_ndarrays = map(lambda shape: np.array(shape), _map_lists)
+                shapes = list(_map_ndarrays)
                 _equal = all(map(lambda shape: (shape == shapes[0]).all(), shapes))
                 # assert _equal, f"Not all shapes of the reconstructions are equal: {shapes}"
                 # print(shapes)
@@ -348,10 +431,34 @@ if __name__ == "__main__":
                 recon_noac = pydicom.dcmread(recon_noac.filename)
                 attenuation_map = pydicom.dcmread(attenuation_map.filename)
 
-                pydicom.dcmwrite(os.path.join(args.images_dir, f"{_id:04d}-{protocol.name.lower()}-{args.prefix_projection}.dcm"), projection_image)
-                pydicom.dcmwrite(os.path.join(args.images_dir, f"{_id:04d}-{protocol.name.lower()}-{args.prefix_recon_ac}.dcm"), recon_ac)
-                pydicom.dcmwrite(os.path.join(args.images_dir, f"{_id:04d}-{protocol.name.lower()}-{args.prefix_recon_no_ac}.dcm"), recon_noac)
-                pydicom.dcmwrite(os.path.join(args.images_dir, f"{_id:04d}-{protocol.name.lower()}-{args.prefix_mu_map}.dcm"), attenuation_map)
+                pydicom.dcmwrite(
+                    os.path.join(
+                        args.images_dir,
+                        f"{_id:04d}-{protocol.name.lower()}-{args.prefix_projection}.dcm",
+                    ),
+                    projection_image,
+                )
+                pydicom.dcmwrite(
+                    os.path.join(
+                        args.images_dir,
+                        f"{_id:04d}-{protocol.name.lower()}-{args.prefix_recon_ac}.dcm",
+                    ),
+                    recon_ac,
+                )
+                pydicom.dcmwrite(
+                    os.path.join(
+                        args.images_dir,
+                        f"{_id:04d}-{protocol.name.lower()}-{args.prefix_recon_no_ac}.dcm",
+                    ),
+                    recon_noac,
+                )
+                pydicom.dcmwrite(
+                    os.path.join(
+                        args.images_dir,
+                        f"{_id:04d}-{protocol.name.lower()}-{args.prefix_mu_map}.dcm",
+                    ),
+                    attenuation_map,
+                )
 
                 row = {
                     headers.id: _id,
@@ -388,7 +495,9 @@ if __name__ == "__main__":
                     headers.energy_window_peak_upper: energy_windows[0][1],
                     headers.energy_window_scatter_lower: energy_windows[1][0],
                     headers.energy_window_scatter_upper: energy_windows[1][1],
-                    headers.detector_count: len(projection_image.DetectorInformationSequence),
+                    headers.detector_count: len(
+                        projection_image.DetectorInformationSequence
+                    ),
                     headers.collimator_type: projection_image.DetectorInformationSequence[
                         0
                     ].CollimatorType,
@@ -410,5 +519,5 @@ if __name__ == "__main__":
                 row = pd.DataFrame(row, index=[0])
                 data = pd.concat((data, row), ignore_index=True)
 
-data.to_csv(args.csv, index=False)
-
+    # keep the write inside the main guard: `data` and `args` only exist there
+    data.to_csv(args.meta_csv, index=False)
-- 
GitLab