"""
Utility script to prepare raw data for further processing.
It is highly dependent on the way the raw DICOM data were exported.

The script creates a folder containing the file `meta.csv` and another
folder containing the DICOM images sorted by an id for each study.
"""
import argparse
from datetime import datetime, timedelta
from enum import Enum
import os
from typing import List, Dict, Callable

import numpy as np
import pandas as pd
import pydicom

from mu_map.file.dicom import DICOM, DICOMTime, parse_age
from mu_map.logging import add_logging_args, get_logger_by_args

STUDY_DESCRIPTION = "µ-map_study"

# Column names of the meta CSV file (attribute names equal their values).
headers = argparse.Namespace()
headers.id = "id"
headers.patient_id = "patient_id"
headers.age = "age"
headers.sex = "sex"
headers.weight = "weight"
headers.size = "size"
headers.patient_position = "patient_position"
headers.protocol = "protocol"
headers.datetime_acquisition = "datetime_acquisition"
headers.datetime_reconstruction = "datetime_reconstruction"
headers.pixel_spacing_x = "pixel_spacing_x"
headers.pixel_spacing_y = "pixel_spacing_y"
headers.pixel_spacing_z = "pixel_spacing_z"
headers.shape_x = "shape_x"
headers.shape_y = "shape_y"
headers.shape_z = "shape_z"
headers.radiopharmaceutical = "radiopharmaceutical"
headers.radionuclide_dose = "radionuclide_dose"
headers.radionuclide_code = "radionuclide_code"
headers.radionuclide_meaning = "radionuclide_meaning"
headers.energy_window_peak_lower = "energy_window_peak_lower"
headers.energy_window_peak_upper = "energy_window_peak_upper"
headers.energy_window_scatter_lower = "energy_window_scatter_lower"
headers.energy_window_scatter_upper = "energy_window_scatter_upper"
headers.detector_count = "detector_count"
headers.collimator_type = "collimator_type"
headers.rotation_start = "rotation_start"
headers.rotation_step = "rotation_step"
headers.rotation_scan_arc = "rotation_scan_arc"
headers.file_projection = "file_projection"
headers.file_recon_ac_sc = "file_recon_ac_sc"
headers.file_recon_nac_sc = "file_recon_nac_sc"
headers.file_recon_ac_nsc = "file_recon_ac_nsc"
headers.file_recon_nac_nsc = "file_recon_nac_nsc"
headers.file_mu_map = "file_mu_map"


def get_protocol(projection: DICOM) -> str:
    """
    Get the protocol (stress, rest) of a projection image by checking
    if it is part of the series description (case-insensitive).

    Parameters
    ----------
    projection: DICOM
        DICOM image of the projection

    Returns
    -------
    str
        the protocol as a string (Stress or Rest)

    Raises
    ------
    ValueError
        if neither "stress" nor "rest" occurs in the series description
    """
    description = projection.SeriesDescription.lower()
    # "stress" is checked first, so a description containing both maps to Stress
    if "stress" in description:
        return "Stress"
    if "rest" in description:
        return "Rest"
    raise ValueError(f"Unknown protocol in projection {projection.SeriesDescription}")


def find_projections(
    dicom_images: List[DICOM],
) -> List[DICOM]:
    """
    Find all projections in a list of DICOM images belonging to a study.

    An image is a projection if "TOMO" is part of its image type. Only
    projections with the series description "Stress", "Rest" or "Stress_2"
    are kept; all others are skipped with a warning.

    Parameters
    ----------
    dicom_images: list of DICOM
        DICOM images of a study

    Returns
    -------
    list of DICOM
        all projection images in the input list

    Raises
    ------
    ValueError
        if no projection is found
    """
    projections = []
    for dicom_image in dicom_images:
        if "TOMO" not in dicom_image.ImageType:
            continue
        # filter for allowed series descriptions
        if dicom_image.SeriesDescription not in ["Stress", "Rest", "Stress_2"]:
            logger.warning(
                f"Skip projection with unknown protocol [{dicom_image.SeriesDescription}]"
            )
            continue
        projections.append(dicom_image)

    if not projections:
        raise ValueError("No projections available")
    return projections


def is_recon_type(
    scatter_corrected: bool, attenuation_corrected: bool
) -> Callable[[DICOM], bool]:
    """
    Get a filter function for reconstructions that are (non-)scatter and/or
    (non-)attenuation corrected.

    Parameters
    ----------
    scatter_corrected: bool
        if the filter should only return true for scatter corrected reconstructions
    attenuation_corrected: bool
        if the filter should only return true for attenuation corrected reconstructions

    Returns
    -------
    Callable[[DICOM], bool]
        a filter function that returns true if the series description of a
        DICOM image contains the marker for the specified corrections
    """
    sc_marker = "SC" if scatter_corrected else "NoSC"
    ac_marker = "AC" if attenuation_corrected else "NoAC"
    # e.g. " SC - AC " or " NoSC - NoAC "; the surrounding spaces prevent
    # " SC" from matching inside "NoSC"
    filter_str = f" {sc_marker} - {ac_marker} "
    return lambda dicom: filter_str in dicom.SeriesDescription


def find_reconstruction(
    dicom_images: List[DICOM],
    projection: DICOM,
    scatter_corrected: bool,
    attenuation_corrected: bool,
) -> DICOM:
    """
    Find the reconstruction belonging to a projection in a list of DICOM
    images of a study.

    A reconstruction matches if it is a "RECON TOMO" image of the same
    protocol and study, is not a CT-based µ-map, has the requested
    corrections and shares the acquisition time of the projection.

    Parameters
    ----------
    dicom_images: list of DICOM
        DICOM images belonging to the study
    projection: DICOM
        the DICOM image of the projection the reconstructions belong to
    scatter_corrected: bool
        if it should be searched for a scatter corrected reconstruction
    attenuation_corrected: bool
        if it should be searched for an attenuation corrected reconstruction

    Returns
    -------
    DICOM
        the according reconstruction; if there are several, the one with
        the most recent series time

    Raises
    ------
    ValueError
        if no matching reconstruction is found
    """
    protocol = get_protocol(projection)
    matches_type = is_recon_type(scatter_corrected, attenuation_corrected)

    candidates = [
        image
        for image in dicom_images
        if "RECON TOMO" in image.ImageType
        and protocol in image.SeriesDescription
        and STUDY_DESCRIPTION in image.SeriesDescription
        and "CT" not in image.SeriesDescription  # remove µ-maps
        and matches_type(image)
        and DICOMTime.Acquisition.to_datetime(image)
        == DICOMTime.Acquisition.to_datetime(projection)
    ]

    if not candidates:
        raise ValueError(
            f"No reconstruction with SC={scatter_corrected}, AC={attenuation_corrected} available"
        )

    # sort newest to be first
    if len(candidates) > 1:
        logger.warning(
            f"Multiple reconstructions ({len(candidates)}) with SC={scatter_corrected}, AC={attenuation_corrected} for projection {projection.SeriesDescription} of patient {projection.PatientID}"
        )
        candidates.sort(
            key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True
        )
    return candidates[0]


def find_attenuation_map(
    dicom_images: List[DICOM],
    projection: DICOM,
    reconstructions: List[DICOM],
    max_time_diff: int = 30,
) -> DICOM:
    """
    Find an attenuation map in a list of DICOM images of a study belonging
    to a projection and reconstructions.

    Parameters
    ----------
    dicom_images: list of DICOM
        the list of DICOM images belonging to the study
    projection: DICOM
        the DICOM image of the projection
    reconstructions: list of DICOM
        DICOM images of reconstructions belonging to the projection
    max_time_diff: int, optional
        maximum absolute difference in seconds between the series time of an
        attenuation map and the series time of at least one reconstruction;
        DICOM files that differ more are filtered out

    Returns
    -------
    DICOM
        the according attenuation map; if there are several, the one with
        the most recent series time

    Raises
    ------
    ValueError
        if no matching attenuation map is found
    """
    protocol = get_protocol(projection)
    recon_times = [DICOMTime.Series.to_datetime(recon) for recon in reconstructions]

    def _is_close_to_a_recon(image: DICOM) -> bool:
        # Compare absolute differences in seconds. `timedelta.seconds` is only
        # the day-normalized seconds component: a negative difference of a few
        # seconds would become ~86400 and differences of whole days would vanish.
        image_time = DICOMTime.Series.to_datetime(image)
        return any(
            abs((image_time - recon_time).total_seconds()) < max_time_diff
            for recon_time in recon_times
        )

    candidates = [
        image
        for image in dicom_images
        if "RECON TOMO" in image.ImageType
        and protocol in image.SeriesDescription
        and STUDY_DESCRIPTION in image.SeriesDescription
        and " µ-map]" in image.SeriesDescription
        and _is_close_to_a_recon(image)
    ]

    if not candidates:
        raise ValueError("No Attenuation map available")

    # sort newest to be first
    if len(candidates) > 1:
        logger.warning(
            f"Multiple attenuation maps ({len(candidates)}) for projection {projection.SeriesDescription} of patient {projection.PatientID}"
        )
        candidates.sort(
            key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True
        )
    return candidates[0]


def get_relevant_images(patient: DICOM, dicom_dir: str) -> List[DICOM]:
    """
    Get all relevant images of a patient.

    Only series consisting of exactly one image are considered. The images
    are read without pixel data.

    Parameters
    ----------
    patient: DICOM
        DICOMDIR patient record of a patient
    dicom_dir: str
        the directory containing the DICOMDIR file the patient record stems from;
        referenced file ids are resolved relative to it

    Returns
    -------
    list of DICOM
        all relevant DICOM images
    """
    dicom_images = []
    studies = [
        child for child in patient.children if child.DirectoryRecordType == "STUDY"
    ]
    for study in studies:
        series_records = [
            child for child in study.children if child.DirectoryRecordType == "SERIES"
        ]
        for series in series_records:
            images = [
                child
                for child in series.children
                if child.DirectoryRecordType == "IMAGE"
            ]
            # all SPECT data is stored as a single 3D array which means that it is a series with a single image
            # this is not the case for CTs, which are skipped here
            if len(images) != 1:
                continue

            # ReferencedFileID is multi-valued (path components), but pydicom
            # returns a plain str if it has only one component
            file_id = images[0].ReferencedFileID
            path_parts = [file_id] if isinstance(file_id, str) else list(file_id)
            dicom_images.append(
                pydicom.dcmread(
                    os.path.join(dicom_dir, *path_parts),
                    stop_before_pixels=True,
                )
            )
    return dicom_images


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare a dataset from DICOM directories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dicom_dirs",
        type=str,
        nargs="+",
        help="paths to DICOMDIR files or directories containing one of them",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="directory where images, meta-information and the logs are stored",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory of --dataset_dir where images are stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        default="meta.csv",
        help="CSV file under --dataset_dir where meta-information is stored",
    )
    add_logging_args(
        parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
    )
    args = parser.parse_args()

    # accept paths to DICOMDIR files as well as the directories containing them
    args.dicom_dirs = [
        (os.path.dirname(_file) if os.path.isfile(_file) else _file)
        for _file in args.dicom_dirs
    ]
    args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
    args.logfile = os.path.join(args.dataset_dir, args.logfile)

    os.makedirs(args.dataset_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)

    # module-level logger read by the helper functions above
    logger = get_logger_by_args(args)

    try:
        patients = []
        dicom_dir_by_patient: Dict[str, str] = {}
        for dicom_dir in args.dicom_dirs:
            dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
            for patient in dataset.patient_records:
                assert (
                    patient.PatientID not in dicom_dir_by_patient
                ), f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
                dicom_dir_by_patient[patient.PatientID] = dicom_dir
                patients.append(patient)

        if os.path.exists(args.meta_csv):
            data = pd.read_csv(args.meta_csv)
            # continue numbering after the largest existing id so that neither
            # ids nor image filenames of a previous run are reused
            _id = 1 if data.empty else int(data[headers.id].max()) + 1
        else:
            data = pd.DataFrame({column: [] for column in vars(headers).values()})
            _id = 1

        for i, patient in enumerate(patients, start=1):
            logger.debug(
                f"Process patient {str(i):>3}/{len(patients)} - {patient.PatientName.given_name}, {patient.PatientName.family_name}:"
            )

            dicom_images = get_relevant_images(
                patient, dicom_dir_by_patient[patient.PatientID]
            )

            projections = find_projections(dicom_images)
            logger.debug(
                f"- Found {len(projections)} projections with protocols {[get_protocol(p) for p in projections]}"
            )

            for projection in projections:
                protocol = get_protocol(projection)

                # headers, postfixes and reconstructions share the same order
                recon_headers = [
                    headers.file_recon_nac_nsc,
                    headers.file_recon_nac_sc,
                    headers.file_recon_ac_nsc,
                    headers.file_recon_ac_sc,
                ]
                recon_postfixes = [
                    "recon_nac_nsc",
                    "recon_nac_sc",
                    "recon_ac_nsc",
                    "recon_ac_sc",
                ]
                reconstructions = [
                    find_reconstruction(
                        dicom_images,
                        projection,
                        scatter_corrected=sc,
                        attenuation_corrected=ac,
                    )
                    for ac, sc in [
                        (False, False),
                        (False, True),
                        (True, False),
                        (True, True),
                    ]
                ]
                mu_map = find_attenuation_map(dicom_images, projection, reconstructions)
                volumes = [*reconstructions, mu_map]

                # extract pixel spacings and assert that they are equal for all reconstruction images
                pixel_spacings = [
                    np.array(
                        [
                            float(value)
                            for value in (*image.PixelSpacing, image.SliceThickness)
                        ]
                    )
                    for image in volumes
                ]
                assert all(
                    (pixel_spacing == pixel_spacings[0]).all()
                    for pixel_spacing in pixel_spacings
                ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
                pixel_spacing = pixel_spacings[0]

                # use the shape with the fewest slices, all other images will be aligned to that
                shapes = [
                    np.array(
                        [int(image.Rows), int(image.Columns), int(image.NumberOfSlices)]
                    )
                    for image in volumes
                ]
                shape = min(shapes, key=lambda _shape: _shape[2])

                # extract energy windows as (lower, upper) limits, sorted by
                # descending lower limit: index 0 is stored as the peak window,
                # index 1 as the scatter window
                energy_windows = []
                for window in projection.EnergyWindowInformationSequence:
                    window_range = window.EnergyWindowRangeSequence[0]
                    energy_windows.append(
                        (
                            float(window_range.EnergyWindowLowerLimit),
                            float(window_range.EnergyWindowUpperLimit),
                        )
                    )
                energy_windows.sort(key=lambda ew: ew[0], reverse=True)

                radiopharmaceutical = projection.RadiopharmaceuticalInformationSequence[0]
                radionuclide = radiopharmaceutical.RadionuclideCodeSequence[0]
                rotation = projection.RotationInformationSequence[0]

                row = {
                    headers.id: _id,
                    headers.patient_id: projection.PatientID,
                    headers.age: parse_age(projection.PatientAge),
                    headers.sex: projection.PatientSex,
                    headers.weight: float(projection.PatientWeight),
                    headers.size: float(projection.PatientSize),
                    headers.patient_position: projection.PatientPosition,
                    headers.protocol: protocol,
                    headers.datetime_acquisition: DICOMTime.Series.to_datetime(
                        projection
                    ),
                    headers.datetime_reconstruction: DICOMTime.Series.to_datetime(
                        reconstructions[0]
                    ),
                    headers.pixel_spacing_x: pixel_spacing[0],
                    headers.pixel_spacing_y: pixel_spacing[1],
                    headers.pixel_spacing_z: pixel_spacing[2],
                    headers.shape_x: shape[0],
                    headers.shape_y: shape[1],
                    headers.shape_z: shape[2],
                    headers.radiopharmaceutical: radiopharmaceutical.Radiopharmaceutical,
                    headers.radionuclide_dose: radiopharmaceutical.RadionuclideTotalDose,
                    headers.radionuclide_code: radionuclide.CodeValue,
                    headers.radionuclide_meaning: radionuclide.CodeMeaning,
                    headers.energy_window_peak_lower: energy_windows[0][0],
                    headers.energy_window_peak_upper: energy_windows[0][1],
                    headers.energy_window_scatter_lower: energy_windows[1][0],
                    headers.energy_window_scatter_upper: energy_windows[1][1],
                    headers.detector_count: len(projection.DetectorInformationSequence),
                    headers.collimator_type: projection.DetectorInformationSequence[
                        0
                    ].CollimatorType,
                    headers.rotation_start: float(rotation.StartAngle),
                    headers.rotation_step: float(rotation.AngularStep),
                    headers.rotation_scan_arc: float(rotation.ScanArc),
                }

                _filename_base = f"{_id:04d}-{protocol.lower()}"
                _ext = "dcm"
                img_headers = [
                    headers.file_projection,
                    *recon_headers,
                    headers.file_mu_map,
                ]
                img_postfixes = ["projection", *recon_postfixes, "mu_map"]
                images = [projection, *reconstructions, mu_map]
                for header, image, postfix in zip(img_headers, images, img_postfixes):
                    # the images were read without pixel data, so re-read the
                    # complete file before writing the copy
                    _image = pydicom.dcmread(image.filename)
                    filename = f"{_filename_base}-{postfix}.{_ext}"
                    pydicom.dcmwrite(os.path.join(args.images_dir, filename), _image)
                    row[header] = filename

                _id += 1
                row = pd.DataFrame(row, index=[0])
                data = pd.concat((data, row), ignore_index=True)
                # rewrite the CSV after every projection so rows processed so
                # far are kept if a later one fails
                data.to_csv(args.meta_csv, index=False)
    except Exception as e:
        logger.error(e)
        raise