import argparse
from datetime import datetime, timedelta
from enum import Enum
import os
from typing import Dict, List

import numpy as np
import pandas as pd
import pydicom

from mu_map.logging import add_logging_args, get_logger_by_args

# Marker in the SeriesDescription identifying series created by this workflow.
STUDY_DESCRIPTION = "µ-map_study"


class MyocardialProtocol(Enum):
    """Acquisition protocol of a myocardial scintigraphy study."""

    Stress = 1
    Rest = 2


# Namespace of column names used in the meta-information CSV file.
headers = argparse.Namespace()
headers.id = "id"
headers.patient_id = "patient_id"
headers.age = "age"
headers.weight = "weight"
headers.size = "size"
headers.protocol = "protocol"
headers.datetime_acquisition = "datetime_acquisition"
headers.datetime_reconstruction = "datetime_reconstruction"
headers.pixel_spacing_x = "pixel_spacing_x"
headers.pixel_spacing_y = "pixel_spacing_y"
headers.pixel_spacing_z = "pixel_spacing_z"
headers.shape_x = "shape_x"
headers.shape_y = "shape_y"
headers.shape_z = "shape_z"
headers.radiopharmaceutical = "radiopharmaceutical"
headers.radionuclide_dose = "radionuclide_dose"
headers.radionuclide_code = "radionuclide_code"
headers.radionuclide_meaning = "radionuclide_meaning"
headers.energy_window_peak_lower = "energy_window_peak_lower"
headers.energy_window_peak_upper = "energy_window_peak_upper"
headers.energy_window_scatter_lower = "energy_window_scatter_lower"
headers.energy_window_scatter_upper = "energy_window_scatter_upper"
headers.detector_count = "detector_count"
headers.collimator_type = "collimator_type"
headers.rotation_start = "rotation_start"
headers.rotation_step = "rotation_step"
headers.rotation_scan_arc = "rotation_scan_arc"
headers.file_projection = "file_projection"
headers.file_recon_ac_sc = "file_recon_ac_sc"
headers.file_recon_nac_sc = "file_recon_nac_sc"
headers.file_recon_ac_nsc = "file_recon_ac_nsc"
headers.file_recon_nac_nsc = "file_recon_nac_nsc"
headers.file_mu_map = "file_mu_map"


def parse_series_time(dicom_image: pydicom.dataset.FileDataset) -> datetime:
    """
    Parse the date and time of a DICOM series object into a datetime object.

    :param dicom_image: the dicom file to parse the series date and time from
    :return: an according python datetime object.
    """
    _date = dicom_image.SeriesDate
    _time = dicom_image.SeriesTime
    # BUGFIX: the fractional part of a DICOM TM value holds fractions of a
    # second (up to six digits) and may be absent entirely.  It must be padded
    # to six digits to be interpreted as microseconds (".5" -> 500000 µs);
    # the previous code used the raw digits and crashed when no "." was present.
    _fraction = _time.split(".")[1] if "." in _time else ""
    return datetime(
        year=int(_date[0:4]),
        month=int(_date[4:6]),
        day=int(_date[6:8]),
        hour=int(_time[0:2]),
        minute=int(_time[2:4]),
        second=int(_time[4:6]),
        microsecond=int(_fraction.ljust(6, "0")[:6]),
    )


def parse_age(patient_age: str) -> int:
    """
    Parse an age string as defined in the DICOM standard into an integer
    representing the age in years.

    :param patient_age: age string as defined in the DICOM standard
    :return: the age in years as a number
    :raises ValueError: if the input is not a four character string in years
    """
    # Validate with explicit raises instead of assert so the checks are not
    # stripped when running with `python -O`.
    if not isinstance(patient_age, str):
        raise ValueError(
            f"patient age needs to be a string and not {type(patient_age)}"
        )
    if len(patient_age) != 4:
        raise ValueError(
            f"patient age [{patient_age}] has to be four characters long"
        )

    _num, _format = patient_age[:3], patient_age[3]
    if _format != "Y":
        raise ValueError(
            f"currently, only patient ages in years [Y] is supported, not [{_format}]"
        )
    return int(_num)


def get_projections(
    dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
) -> List[pydicom.dataset.FileDataset]:
    """
    Extract the SPECT projections from a list of DICOM images belonging to a
    myocardial scintigraphy study given a study protocol.

    :param dicom_images: list of DICOM images of a study
    :param protocol: the protocol for which the projection images should be extracted
    :return: the extracted DICOM images
    :raises ValueError: if no projection is available for the protocol
    """
    _filter = filter(lambda image: "TOMO" in image.ImageType, dicom_images)
    _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
    dicom_images = list(_filter)

    if len(dicom_images) == 0:
        raise ValueError(f"No projection for protocol {protocol.name} available")
    return dicom_images


def get_reconstructions(
    dicom_images: List[pydicom.dataset.FileDataset],
    protocol: MyocardialProtocol,
    scatter_corrected: bool,
    attenuation_corrected: bool,
) -> List[pydicom.dataset.FileDataset]:
    """
    Extract SPECT reconstructions from a list of DICOM images belonging to a
    myocardial scintigraphy study given a study protocol.  The corrected flags
    select the attenuation/scatter correction variant of the reconstruction.

    :param dicom_images: list of DICOM images of a study
    :param protocol: the protocol for which the reconstructions should be extracted
    :param scatter_corrected: extract the image to which scatter correction was applied
    :param attenuation_corrected: extract an attenuation or non-attenuation corrected image
    :return: the extracted DICOM images
    :raises ValueError: if no matching reconstruction is available
    """
    _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
    _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
    _filter = filter(
        lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter
    )

    # The series description encodes the correction variant,
    # e.g. " SC - AC " or " NoSC - NoAC ".
    _sc = "SC" if scatter_corrected else "NoSC"
    _ac = "AC" if attenuation_corrected else "NoAC"
    filter_str = f" {_sc} - {_ac} "
    _filter = filter(lambda image: filter_str in image.SeriesDescription, _filter)

    # for SPECT reconstructions created in clinical studies this value exists
    # and is set to 'APEX_TO_BASE'; for the reconstructions with attenuation
    # maps it does not exist
    # _filter = filter(
    #     lambda image: not hasattr(image, "SliceProgressionDirection"), _filter
    # )

    dicom_images = list(_filter)
    if len(dicom_images) == 0:
        raise ValueError(
            f"'{filter_str}' Reconstruction for protocol {protocol.name} is not available"
        )
    return dicom_images


def get_attenuation_maps(
    dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
) -> List[pydicom.dataset.FileDataset]:
    """
    Extract attenuation maps from a list of DICOM images belonging to a
    myocardial scintigraphy study given a study protocol.

    :param dicom_images: list of DICOM images of a study
    :param protocol: the protocol for which the attenuation maps should be extracted
    :return: the extracted DICOM images
    :raises ValueError: if no attenuation map is available for the protocol
    """
    _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
    _filter = filter(lambda image: protocol.name in image.SeriesDescription, _filter)
    _filter = filter(
        lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter
    )
    _filter = filter(lambda image: " µ-map]" in image.SeriesDescription, _filter)

    dicom_images = list(_filter)
    if len(dicom_images) == 0:
        raise ValueError(
            f"Attenuation map for protocol {protocol.name} is not available"
        )
    return dicom_images


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare a dataset from DICOM directories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dicom_dirs",
        type=str,
        nargs="+",
        help="paths to DICOMDIR files or directories containing one of them",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="directory where images, meta-information and the logs are stored",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory of --dataset_dir where images are stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        default="meta.csv",
        help="CSV file under --dataset_dir where meta-information is stored",
    )
    add_logging_args(
        parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
    )
    args = parser.parse_args()

    # Normalize paths: DICOMDIR files are replaced by their containing
    # directory and all outputs are placed under the dataset directory.
    args.dicom_dirs = [
        (os.path.dirname(_file) if os.path.isfile(_file) else _file)
        for _file in args.dicom_dirs
    ]
    args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
    args.logfile = os.path.join(args.dataset_dir, args.logfile)

    if not os.path.exists(args.dataset_dir):
        os.mkdir(args.dataset_dir)
    if not os.path.exists(args.images_dir):
        os.mkdir(args.images_dir)

    logger = get_logger_by_args(args)

    try:
        # Collect all patient records and remember in which DICOM directory
        # each patient was found.
        patients = []
        dicom_dir_by_patient: Dict[str, str] = {}
        for dicom_dir in args.dicom_dirs:
            dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
            for patient in dataset.patient_records:
                assert (
                    patient.PatientID not in dicom_dir_by_patient
                ), f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
                dicom_dir_by_patient[patient.PatientID] = dicom_dir
                patients.append(patient)

        # Resume id numbering from an existing CSV file if available.
        _id = 1
        if os.path.exists(args.meta_csv):
            data = pd.read_csv(args.meta_csv)
            # BUGFIX: continue with the next free id; the previous code
            # re-used the highest id already present in the CSV file, which
            # produced duplicate ids and overwrote the last image files.
            _id = int(data[headers.id].max()) + 1
        else:
            data = pd.DataFrame(dict([(key, []) for key in vars(headers).keys()]))

        for i, patient in enumerate(patients, start=1):
            logger.debug(f"Process patient {str(i):>3}/{len(patients)}:")

            # get all myocardial scintigraphy studies
            studies = list(
                filter(
                    lambda child: child.DirectoryRecordType == "STUDY",
                    # and child.StudyDescription == "Myokardszintigraphie",
                    # filter is disabled because there is a study without this
                    # description and only such studies are exported anyway
                    patient.children,
                )
            )

            # extract all dicom images
            dicom_images = []
            for study in studies:
                series = list(
                    filter(
                        lambda child: child.DirectoryRecordType == "SERIES",
                        study.children,
                    )
                )
                for _series in series:
                    images = list(
                        filter(
                            lambda child: child.DirectoryRecordType == "IMAGE",
                            _series.children,
                        )
                    )

                    # all SPECT data is stored as a single 3D array which
                    # means that it is a series with a single image; this is
                    # not the case for CTs, which are skipped here
                    if len(images) != 1:
                        continue

                    # read headers only; pixel data is loaded again when the
                    # selected files are copied into the dataset
                    images = list(
                        map(
                            lambda image: pydicom.dcmread(
                                os.path.join(
                                    dicom_dir_by_patient[patient.PatientID],
                                    *image.ReferencedFileID,
                                ),
                                stop_before_pixels=True,
                            ),
                            images,
                        )
                    )
                    if len(images) == 0:
                        continue
                    dicom_images.append(images[0])

            for protocol in MyocardialProtocol:
                # skip patient/protocol combinations already in the dataset
                if (
                    len(
                        data[
                            (data[headers.patient_id] == int(patient.PatientID))
                            & (data[headers.protocol] == protocol.name)
                        ]
                    )
                    > 0
                ):
                    logger.info(
                        f"Skip {patient.PatientID}:{protocol.name} since it is already contained in the dataset"
                    )
                    continue

                # every image type that has to be present for a row of the
                # dataset, together with its extraction function and the CSV
                # column that stores the resulting filename
                extractions = [
                    {
                        "function": get_projections,
                        "kwargs": {},
                        "reconstruction": False,
                        "prefix": "projection",
                        "header": headers.file_projection,
                    },
                    {
                        "function": get_reconstructions,
                        "kwargs": {
                            "attenuation_corrected": True,
                            "scatter_corrected": True,
                        },
                        "reconstruction": True,
                        "prefix": "recon_ac_sc",
                        "header": headers.file_recon_ac_sc,
                    },
                    {
                        "function": get_reconstructions,
                        "kwargs": {
                            "attenuation_corrected": True,
                            "scatter_corrected": False,
                        },
                        "reconstruction": True,
                        "prefix": "recon_ac_nsc",
                        "header": headers.file_recon_ac_nsc,
                    },
                    {
                        "function": get_reconstructions,
                        "kwargs": {
                            "attenuation_corrected": False,
                            "scatter_corrected": True,
                        },
                        "reconstruction": True,
                        "prefix": "recon_nac_sc",
                        "header": headers.file_recon_nac_sc,
                    },
                    {
                        "function": get_reconstructions,
                        "kwargs": {
                            "attenuation_corrected": False,
                            "scatter_corrected": False,
                        },
                        "reconstruction": True,
                        "prefix": "recon_nac_nsc",
                        "header": headers.file_recon_nac_nsc,
                    },
                    {
                        "function": get_attenuation_maps,
                        "kwargs": {},
                        "reconstruction": True,
                        "prefix": "mu_map",
                        "header": headers.file_mu_map,
                    },
                ]

                try:
                    for extrac in extractions:
                        _images = extrac["function"](
                            dicom_images, protocol=protocol, **extrac["kwargs"]
                        )
                        # sort from oldest to newest acquisition
                        _images.sort(key=parse_series_time)
                        extrac["images"] = _images
                except ValueError as e:
                    logger.info(
                        f"Skip {patient.PatientID}:{protocol.name} because {e}"
                    )
                    continue

                num_images = min(
                    list(map(lambda extrac: len(extrac["images"]), extractions))
                )
                # ATTENTION: this is a special filter for mu maps which could
                # have been saved from previous test runs of the workflow;
                # this filter only keeps the most recent ones
                if num_images < len(extractions[-1]["images"]):
                    _len = len(extractions[-1]["images"])
                    extractions[-1]["images"] = extractions[-1]["images"][
                        (_len - num_images) :
                    ]

                for j in range(num_images):
                    _recon_images = filter(
                        lambda extrac: extrac["reconstruction"], extractions
                    )
                    recon_images = list(
                        map(lambda extrac: extrac["images"][j], _recon_images)
                    )

                    # extract date times and assert that they are equal for
                    # all reconstruction images
                    datetimes = list(map(parse_series_time, recon_images))
                    _datetimes = sorted(datetimes, reverse=True)
                    _datetimes_delta = list(
                        map(lambda dt: _datetimes[0] - dt, _datetimes)
                    )
                    # note: somehow the images receive slightly different
                    # timestamps, maybe this depends on time to save and
                    # computation time; thus, a 10 minute interval is allowed
                    _equal = all(
                        map(lambda dt: dt < timedelta(minutes=10), _datetimes_delta)
                    )
                    assert (
                        _equal
                    ), f"Not all dates and times of the reconstructions are equal: {datetimes}"

                    # extract pixel spacings and assert that they are equal
                    # for all reconstruction images
                    _map_lists = map(
                        lambda image: [*image.PixelSpacing, image.SliceThickness],
                        recon_images,
                    )
                    _map_lists = map(
                        lambda pixel_spacing: list(map(float, pixel_spacing)),
                        _map_lists,
                    )
                    _map_ndarrays = map(
                        lambda pixel_spacing: np.array(pixel_spacing), _map_lists
                    )
                    pixel_spacings = list(_map_ndarrays)
                    _equal = all(
                        map(
                            lambda pixel_spacing: (
                                pixel_spacing == pixel_spacings[0]
                            ).all(),
                            pixel_spacings,
                        )
                    )
                    assert (
                        _equal
                    ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
                    pixel_spacing = pixel_spacings[0]

                    # use the shape with the fewest slices, all other images
                    # will be aligned to that
                    _map_lists = map(
                        lambda image: [
                            image.Rows,
                            image.Columns,
                            image.NumberOfSlices,
                        ],
                        recon_images,
                    )
                    _map_lists = map(
                        lambda shape: list(map(int, shape)), _map_lists
                    )
                    _map_ndarrays = map(lambda shape: np.array(shape), _map_lists)
                    shapes = list(_map_ndarrays)
                    shapes.sort(key=lambda shape: shape[2])
                    shape = shapes[0]

                    projection_image = extractions[0]["images"][j]
                    # extract and sort energy windows (descending lower limit:
                    # peak window first, scatter window second)
                    energy_windows = (
                        projection_image.EnergyWindowInformationSequence
                    )
                    energy_windows = map(
                        lambda ew: ew.EnergyWindowRangeSequence[0], energy_windows
                    )
                    energy_windows = map(
                        lambda ew: (
                            float(ew.EnergyWindowLowerLimit),
                            float(ew.EnergyWindowUpperLimit),
                        ),
                        energy_windows,
                    )
                    energy_windows = list(energy_windows)
                    energy_windows.sort(key=lambda ew: ew[0], reverse=True)

                    row = {
                        headers.id: _id,
                        headers.patient_id: projection_image.PatientID,
                        headers.age: parse_age(projection_image.PatientAge),
                        headers.weight: float(projection_image.PatientWeight),
                        headers.size: float(projection_image.PatientSize),
                        headers.protocol: protocol.name,
                        headers.datetime_acquisition: parse_series_time(
                            projection_image
                        ),
                        headers.datetime_reconstruction: datetimes[0],
                        headers.pixel_spacing_x: pixel_spacing[0],
                        headers.pixel_spacing_y: pixel_spacing[1],
                        headers.pixel_spacing_z: pixel_spacing[2],
                        headers.shape_x: shape[0],
                        headers.shape_y: shape[1],
                        headers.shape_z: shape[2],
                        headers.radiopharmaceutical: projection_image.RadiopharmaceuticalInformationSequence[
                            0
                        ].Radiopharmaceutical,
                        headers.radionuclide_dose: projection_image.RadiopharmaceuticalInformationSequence[
                            0
                        ].RadionuclideTotalDose,
                        headers.radionuclide_code: projection_image.RadiopharmaceuticalInformationSequence[
                            0
                        ]
                        .RadionuclideCodeSequence[0]
                        .CodeValue,
                        headers.radionuclide_meaning: projection_image.RadiopharmaceuticalInformationSequence[
                            0
                        ]
                        .RadionuclideCodeSequence[0]
                        .CodeMeaning,
                        headers.energy_window_peak_lower: energy_windows[0][0],
                        headers.energy_window_peak_upper: energy_windows[0][1],
                        headers.energy_window_scatter_lower: energy_windows[1][0],
                        headers.energy_window_scatter_upper: energy_windows[1][1],
                        headers.detector_count: len(
                            projection_image.DetectorInformationSequence
                        ),
                        headers.collimator_type: projection_image.DetectorInformationSequence[
                            0
                        ].CollimatorType,
                        headers.rotation_start: float(
                            projection_image.RotationInformationSequence[
                                0
                            ].StartAngle
                        ),
                        headers.rotation_step: float(
                            projection_image.RotationInformationSequence[
                                0
                            ].AngularStep
                        ),
                        headers.rotation_scan_arc: float(
                            projection_image.RotationInformationSequence[0].ScanArc
                        ),
                    }

                    _filename_base = f"{_id:04d}-{protocol.name.lower()}"
                    _ext = "dcm"
                    _images = list(
                        map(lambda extrac: extrac["images"][j], extractions)
                    )
                    for _image, extrac in zip(_images, extractions):
                        # re-read the file including pixel data (the cached
                        # dataset was loaded with stop_before_pixels=True)
                        image = pydicom.dcmread(_image.filename)
                        filename = f"{_filename_base}-{extrac['prefix']}.{_ext}"
                        pydicom.dcmwrite(
                            os.path.join(args.images_dir, filename), image
                        )
                        row[extrac["header"]] = filename

                    _id += 1
                    row = pd.DataFrame(row, index=[0])
                    data = pd.concat((data, row), ignore_index=True)
                    # persist after every row so interrupted runs can resume
                    data.to_csv(args.meta_csv, index=False)
    except Exception as e:
        logger.error(e)
        raise e