"""
Prepare a dataset from DICOM directories.

For every patient listed in the given DICOMDIR files, the projections, the
matching reconstructions and the matching attenuation map (µ-map) are
located, copied into an images directory and described by one row of a
meta CSV file.
"""
import argparse
from datetime import datetime, timedelta
from enum import Enum
import os
from typing import List, Dict, Callable

import numpy as np
import pandas as pd
import pydicom

from mu_map.file.dicom import DICOMTime, parse_age
from mu_map.logging import add_logging_args, get_logger_by_args


STUDY_DESCRIPTION = "µ-map_study"

# Column names of the meta CSV file; the attribute order defines the column
# order of a newly created CSV.
headers = argparse.Namespace()
headers.id = "id"
headers.patient_id = "patient_id"
headers.age = "age"
headers.sex = "sex"
headers.weight = "weight"
headers.size = "size"
headers.patient_position = "patient_position"
headers.protocol = "protocol"
headers.datetime_acquisition = "datetime_acquisition"
headers.datetime_reconstruction = "datetime_reconstruction"
headers.pixel_spacing_x = "pixel_spacing_x"
headers.pixel_spacing_y = "pixel_spacing_y"
headers.pixel_spacing_z = "pixel_spacing_z"
headers.shape_x = "shape_x"
headers.shape_y = "shape_y"
headers.shape_z = "shape_z"
headers.radiopharmaceutical = "radiopharmaceutical"
headers.radionuclide_dose = "radionuclide_dose"
headers.radionuclide_code = "radionuclide_code"
headers.radionuclide_meaning = "radionuclide_meaning"
headers.energy_window_peak_lower = "energy_window_peak_lower"
headers.energy_window_peak_upper = "energy_window_peak_upper"
headers.energy_window_scatter_lower = "energy_window_scatter_lower"
headers.energy_window_scatter_upper = "energy_window_scatter_upper"
headers.detector_count = "detector_count"
headers.collimator_type = "collimator_type"
headers.rotation_start = "rotation_start"
headers.rotation_step = "rotation_step"
headers.rotation_scan_arc = "rotation_scan_arc"
headers.file_projection = "file_projection"
headers.file_recon_ac_sc = "file_recon_ac_sc"
headers.file_recon_nac_sc = "file_recon_nac_sc"
headers.file_recon_ac_nsc = "file_recon_ac_nsc"
headers.file_recon_nac_nsc = "file_recon_nac_nsc"
headers.file_mu_map = "file_mu_map"


def get_protocol(projection: pydicom.dataset.FileDataset) -> str:
    """
    Get the protocol (stress, rest) of a projection image by checking if
    it is part of the series description.

    :param projection: pydicom image of the projection
    :returns: the protocol as a string (Stress or Rest)
    :raises ValueError: if the series description contains neither protocol
    """
    series_description = projection.SeriesDescription.lower()
    # "stress" is tested first, so a description containing both words
    # is reported as "Stress"
    if "stress" in series_description:
        return "Stress"
    if "rest" in series_description:
        return "Rest"
    raise ValueError(f"Unkown protocol in projection {projection.SeriesDescription}")


def find_projections(
    dicom_images: List[pydicom.dataset.FileDataset],
) -> List[pydicom.dataset.FileDataset]:
    """
    Find all projections in a list of DICOM images belonging to a study.

    Uses the module-level ``logger`` to warn about skipped projections.

    :param dicom_images: list of DICOM images of a study
    :returns: all projections with an allowed series description
    :raises ValueError: if no projection is found
    """
    allowed_descriptions = ["Stress", "Rest", "Stress_2"]

    projections = []
    for dicom_image in dicom_images:
        if "TOMO" not in dicom_image.ImageType:
            continue

        # filter for allowed series descriptions
        if dicom_image.SeriesDescription not in allowed_descriptions:
            logger.warning(
                f"Skip projection with unkown protocol [{dicom_image.SeriesDescription}]"
            )
            continue
        projections.append(dicom_image)

    if not projections:
        raise ValueError("No projections available")
    return projections


def is_recon_type(
    scatter_corrected: bool, attenuation_corrected: bool
) -> Callable[[pydicom.dataset.FileDataset], bool]:
    """
    Get a filter function for reconstructions that are (non-)scatter and/or
    (non-)attenuation corrected.

    The returned function tests whether the series description contains
    one of " SC - AC ", " NoSC - AC ", " SC - NoAC " or " NoSC - NoAC ".

    :param scatter_corrected: if the filter should only return true for
                              scatter corrected reconstructions
    :param attenuation_corrected: if the filter should only return true for
                                  attenuation corrected reconstructions
    :returns: a filter function
    """
    sc_str = "SC" if scatter_corrected else "NoSC"
    ac_str = "AC" if attenuation_corrected else "NoAC"
    filter_str = f" {sc_str} - {ac_str} "
    return lambda dicom: filter_str in dicom.SeriesDescription


def find_reconstruction(
    dicom_images: List[pydicom.dataset.FileDataset],
    projection: pydicom.dataset.FileDataset,
    scatter_corrected: bool,
    attenuation_corrected: bool,
) -> pydicom.dataset.FileDataset:
    """
    Find a reconstruction in a list of dicom images of a study belonging
    to a projection.

    Uses the module-level ``logger`` to warn if several reconstructions match.

    :param dicom_images: the list of dicom images belonging to the study
    :param projection: the dicom image of the projection
    :param scatter_corrected: if it should be searched for a scatter
                              corrected reconstruction
    :param attenuation_corrected: if it should be searched for an attenuation
                                  corrected reconstruction
    :returns: the according reconstruction; if several match, the one with
              the latest series time
    :raises ValueError: if no matching reconstruction is found
    """
    protocol = get_protocol(projection)
    has_recon_type = is_recon_type(scatter_corrected, attenuation_corrected)

    reconstructions = [
        image
        for image in dicom_images
        if "RECON TOMO" in image.ImageType
        and protocol in image.SeriesDescription
        and STUDY_DESCRIPTION in image.SeriesDescription
        and "CT" not in image.SeriesDescription  # remove µ-maps
        and has_recon_type(image)
        # a reconstruction carries the acquisition time of its projection
        and DICOMTime.Acquisition.to_datetime(image)
        == DICOMTime.Acquisition.to_datetime(projection)
    ]

    if not reconstructions:
        raise ValueError(
            f"No reconstruction with SC={scatter_corrected}, AC={attenuation_corrected} available"
        )

    if len(reconstructions) > 1:
        logger.warning(
            f"Multiple reconstructions ({len(reconstructions)}) with SC={scatter_corrected}, AC={attenuation_corrected} for projection {projection.SeriesDescription} of patient {projection.PatientID}"
        )

    # sort by series time in descending order, so the newest comes first
    reconstructions.sort(
        key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True
    )
    return reconstructions[0]


# maximum difference in seconds between the series time of an attenuation
# map and the series time of one of the reconstructions it belongs to
MAX_TIME_DIFF_S = 30


def find_attenuation_map(
    dicom_images: List[pydicom.dataset.FileDataset],
    projection: pydicom.dataset.FileDataset,
    reconstructions: List[pydicom.dataset.FileDataset],
) -> pydicom.dataset.FileDataset:
    """
    Find an attenuation map in a list of dicom images of a study belonging
    to a projection and reconstructions.

    An attenuation map is accepted if its series time differs by less than
    MAX_TIME_DIFF_S seconds from the series time of at least one of the
    reconstructions. Uses the module-level ``logger`` to warn if several
    attenuation maps match.

    :param dicom_images: the list of dicom images belonging to the study
    :param projection: the dicom image of the projection
    :param reconstructions: dicom images of reconstructions belonging to
                            the projection
    :returns: the according attenuation map; if several match, the one with
              the latest series time
    :raises ValueError: if no matching attenuation map is found
    """
    protocol = get_protocol(projection)
    recon_times = [DICOMTime.Series.to_datetime(recon) for recon in reconstructions]

    def _is_close_to_reconstruction(image: pydicom.dataset.FileDataset) -> bool:
        series_time = DICOMTime.Series.to_datetime(image)
        # timedelta.seconds is only the seconds component of the difference
        # (it ignores whole days and wraps for negative differences), so the
        # absolute total difference has to be used here
        return any(
            abs((series_time - recon_time).total_seconds()) < MAX_TIME_DIFF_S
            for recon_time in recon_times
        )

    mu_maps = [
        image
        for image in dicom_images
        if "RECON TOMO" in image.ImageType
        and protocol in image.SeriesDescription
        and STUDY_DESCRIPTION in image.SeriesDescription
        and " µ-map]" in image.SeriesDescription
        and _is_close_to_reconstruction(image)
    ]

    if not mu_maps:
        raise ValueError("No Attenuation map available")

    if len(mu_maps) > 1:
        logger.warning(
            f"Multiple attenuation maps ({len(mu_maps)}) for projection {projection.SeriesDescription} of patient {projection.PatientID}"
        )

    # sort by series time in descending order, so the newest comes first
    mu_maps.sort(key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True)
    return mu_maps[0]


def get_relevant_images(
    patient: pydicom.dataset.FileDataset, dicom_dir: str
) -> List[pydicom.dataset.FileDataset]:
    """
    Get all relevant images of a patient.

    The images are read without their pixel data.

    :param patient: pydicom dataset of a patient
    :param dicom_dir: the directory of the DICOM files
    :returns: all relevant dicom images
    """
    # Studies are not filtered by their description ("Myokardszintigraphie")
    # because there is a study without this description and only such
    # studies are exported anyway.
    studies = [
        child for child in patient.children if child.DirectoryRecordType == "STUDY"
    ]

    # extract all dicom images
    dicom_images = []
    for study in studies:
        series = [
            child for child in study.children if child.DirectoryRecordType == "SERIES"
        ]
        for _series in series:
            images = [
                child
                for child in _series.children
                if child.DirectoryRecordType == "IMAGE"
            ]

            # all SPECT data is stored as a single 3D array which means that
            # it is a series with a single image
            # this is not the case for CTs, which are skipped here
            if len(images) != 1:
                continue

            # the file is resolved relative to the directory passed by the
            # caller instead of a script-level global
            dicom_images.append(
                pydicom.dcmread(
                    os.path.join(dicom_dir, *images[0].ReferencedFileID),
                    stop_before_pixels=True,
                )
            )
    return dicom_images


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare a dataset from DICOM directories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dicom_dirs",
        type=str,
        nargs="+",
        help="paths to DICOMDIR files or directories containing one of them",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="directory where images, meta-information and the logs are stored",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory of --dataset_dir where images are stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        default="meta.csv",
        help="CSV file under --dataset_dir where meta-information is stored",
    )
    add_logging_args(
        parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
    )
    args = parser.parse_args()

    # a path to a DICOMDIR file is replaced by its containing directory
    args.dicom_dirs = [
        (os.path.dirname(_file) if os.path.isfile(_file) else _file)
        for _file in args.dicom_dirs
    ]
    args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
    args.logfile = os.path.join(args.dataset_dir, args.logfile)

    # makedirs also creates missing parent directories
    os.makedirs(args.dataset_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)

    # module-level logger which is used by the functions defined above
    logger = get_logger_by_args(args)

    try:
        patients = []
        dicom_dir_by_patient: Dict[str, str] = {}
        for dicom_dir in args.dicom_dirs:
            dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
            for patient in dataset.patient_records:
                assert (
                    patient.PatientID not in dicom_dir_by_patient
                ), f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
                dicom_dir_by_patient[patient.PatientID] = dicom_dir
                patients.append(patient)

        _id = 1
        if os.path.exists(args.meta_csv):
            data = pd.read_csv(args.meta_csv)
            # continue after the largest id already in use, otherwise the
            # last row's id is reused and its image files are overwritten;
            # a CSV without rows has no maximum, so counting starts at 1
            if not data.empty:
                _id = int(data[headers.id].max()) + 1
        else:
            data = pd.DataFrame({key: [] for key in vars(headers)})

        # CSV columns and filename postfixes of the four reconstructions,
        # in the same order in which they are searched for below
        recon_headers = [
            headers.file_recon_nac_nsc,
            headers.file_recon_nac_sc,
            headers.file_recon_ac_nsc,
            headers.file_recon_ac_sc,
        ]
        recon_postfixes = [
            "recon_nac_nsc",
            "recon_nac_sc",
            "recon_ac_nsc",
            "recon_ac_sc",
        ]
        recon_types = [
            (False, False),
            (False, True),
            (True, False),
            (True, True),
        ]

        for i, patient in enumerate(patients, start=1):
            logger.debug(
                f"Process patient {str(i):>3}/{len(patients)} - {patient.PatientName.given_name}, {patient.PatientName.family_name}:"
            )

            dicom_images = get_relevant_images(
                patient, dicom_dir_by_patient[patient.PatientID]
            )

            projections = find_projections(dicom_images)
            logger.debug(
                f"- Found {len(projections)}: projections with protocols {[get_protocol(p) for p in projections]}"
            )

            for projection in projections:
                protocol = get_protocol(projection)

                reconstructions = [
                    find_reconstruction(
                        dicom_images,
                        projection,
                        scatter_corrected=sc,
                        attenuation_corrected=ac,
                    )
                    for ac, sc in recon_types
                ]
                mu_map = find_attenuation_map(dicom_images, projection, reconstructions)
                volumes = [*reconstructions, mu_map]

                # extract pixel spacings and assert that they are equal for
                # all reconstruction images
                pixel_spacings = [
                    np.array(
                        [
                            float(value)
                            for value in [*image.PixelSpacing, image.SliceThickness]
                        ]
                    )
                    for image in volumes
                ]
                _equal = all(
                    (pixel_spacing == pixel_spacings[0]).all()
                    for pixel_spacing in pixel_spacings
                )
                assert (
                    _equal
                ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
                pixel_spacing = pixel_spacings[0]

                # use the shape with the fewest slices, all other images
                # will be aligned to that
                shapes = [
                    np.array(
                        [
                            int(image.Rows),
                            int(image.Columns),
                            int(image.NumberOfSlices),
                        ]
                    )
                    for image in volumes
                ]
                shape = min(shapes, key=lambda _shape: _shape[2])

                # extract energy windows and sort them by their lower limit
                # in descending order; the first one is stored as peak window
                # and the second one as scatter window
                energy_windows = [
                    (
                        float(ew_range.EnergyWindowLowerLimit),
                        float(ew_range.EnergyWindowUpperLimit),
                    )
                    for ew_range in (
                        ew.EnergyWindowRangeSequence[0]
                        for ew in projection.EnergyWindowInformationSequence
                    )
                ]
                energy_windows.sort(key=lambda ew: ew[0], reverse=True)

                radiopharmaceutical = projection.RadiopharmaceuticalInformationSequence[0]
                radionuclide = radiopharmaceutical.RadionuclideCodeSequence[0]
                rotation = projection.RotationInformationSequence[0]

                row = {
                    headers.id: _id,
                    headers.patient_id: projection.PatientID,
                    headers.age: parse_age(projection.PatientAge),
                    headers.sex: projection.PatientSex,
                    headers.weight: float(projection.PatientWeight),
                    headers.size: float(projection.PatientSize),
                    headers.patient_position: projection.PatientPosition,
                    headers.protocol: protocol,
                    headers.datetime_acquisition: DICOMTime.Series.to_datetime(
                        projection
                    ),
                    headers.datetime_reconstruction: DICOMTime.Series.to_datetime(
                        reconstructions[0]
                    ),
                    headers.pixel_spacing_x: pixel_spacing[0],
                    headers.pixel_spacing_y: pixel_spacing[1],
                    headers.pixel_spacing_z: pixel_spacing[2],
                    headers.shape_x: shape[0],
                    headers.shape_y: shape[1],
                    headers.shape_z: shape[2],
                    headers.radiopharmaceutical: radiopharmaceutical.Radiopharmaceutical,
                    headers.radionuclide_dose: radiopharmaceutical.RadionuclideTotalDose,
                    headers.radionuclide_code: radionuclide.CodeValue,
                    headers.radionuclide_meaning: radionuclide.CodeMeaning,
                    headers.energy_window_peak_lower: energy_windows[0][0],
                    headers.energy_window_peak_upper: energy_windows[0][1],
                    headers.energy_window_scatter_lower: energy_windows[1][0],
                    headers.energy_window_scatter_upper: energy_windows[1][1],
                    headers.detector_count: len(projection.DetectorInformationSequence),
                    headers.collimator_type: projection.DetectorInformationSequence[
                        0
                    ].CollimatorType,
                    headers.rotation_start: float(rotation.StartAngle),
                    headers.rotation_step: float(rotation.AngularStep),
                    headers.rotation_scan_arc: float(rotation.ScanArc),
                }

                _filename_base = f"{_id:04d}-{protocol.lower()}"
                _ext = "dcm"
                img_headers = [
                    headers.file_projection,
                    *recon_headers,
                    headers.file_mu_map,
                ]
                img_postfixes = ["projection", *recon_postfixes, "mu_map"]
                images = [projection, *reconstructions, mu_map]
                for header, image, postfix in zip(img_headers, images, img_postfixes):
                    # the images were read without pixel data, so each file
                    # is read again completely before it is written
                    _image = pydicom.dcmread(image.filename)
                    filename = f"{_filename_base}-{postfix}.{_ext}"
                    pydicom.dcmwrite(os.path.join(args.images_dir, filename), _image)
                    row[header] = filename

                _id += 1

                row = pd.DataFrame(row, index=[0])
                data = pd.concat((data, row), ignore_index=True)

        data.to_csv(args.meta_csv, index=False)
    except Exception as e:
        logger.error(e)
        raise