"""
Prepare a dataset from DICOM directories.

For every patient listed in the given DICOMDIR files, the SPECT projection, the
attenuation corrected and non-attenuation corrected reconstructions and the
attenuation map of each myocardial protocol are written into an images directory.
One row of meta information per patient and protocol is appended to a CSV file.
"""
import argparse
from datetime import datetime, timedelta
from enum import Enum
import os
from typing import List

import numpy as np
import pandas as pd
import pydicom


class MyocardialProtocol(Enum):
    """
    Protocols of a myocardial scintigraphy study.

    The member names are matched against the series descriptions of DICOM images.
    """

    Stress = 1
    Rest = 2


# column names of the dataset CSV file
headers = argparse.Namespace()
headers.id = "id"
headers.patient_id = "patient_id"
headers.age = "age"
headers.weight = "weight"
headers.size = "size"
headers.protocol = "protocol"
headers.datetime_acquisition = "datetime_acquisition"
headers.datetime_reconstruction = "datetime_reconstruction"
headers.pixel_spacing_x = "pixel_spacing_x"
headers.pixel_spacing_y = "pixel_spacing_y"
headers.pixel_spacing_z = "pixel_spacing_z"
headers.shape_x = "shape_x"
headers.shape_y = "shape_y"
headers.shape_z = "shape_z"
headers.radiopharmaceutical = "radiopharmaceutical"
headers.radionuclide_dose = "radionuclide_dose"
headers.radionuclide_code = "radionuclide_code"
headers.radionuclide_meaning = "radionuclide_meaning"
headers.energy_window_peak_lower = "energy_window_peak_lower"
headers.energy_window_peak_upper = "energy_window_peak_upper"
headers.energy_window_scatter_lower = "energy_window_scatter_lower"
headers.energy_window_scatter_upper = "energy_window_scatter_upper"
headers.detector_count = "detector_count"
headers.collimator_type = "collimator_type"
headers.rotation_start = "rotation_start"
headers.rotation_step = "rotation_step"
headers.rotation_scan_arc = "rotation_scan_arc"
headers.file_recon_ac = "file_recon_ac"
headers.file_recon_no_ac = "file_recon_no_ac"
headers.file_mu_map = "file_mu_map"


def parse_series_time(dicom_image: pydicom.dataset.FileDataset) -> datetime:
    """
    Parse the date and time of a DICOM series object into a datetime object.

    The series time is expected in the DICOM time format `HHMMSS.FFFFFF`.
    Minutes, seconds and the fractional part may be missing and are then
    treated as zero.

    :param dicom_image: the dicom file to parse the series date and time from
    :return: an according python datetime object.
    """
    _date = dicom_image.SeriesDate
    _time, _, _fraction = dicom_image.SeriesTime.strip().partition(".")
    # the fraction has between one and six digits and denotes a fraction of a
    # second, so ".5" is half a second: pad it to microseconds before parsing
    _microsecond = int(_fraction[:6].ljust(6, "0")) if _fraction else 0
    return datetime(
        year=int(_date[0:4]),
        month=int(_date[4:6]),
        day=int(_date[6:8]),
        hour=int(_time[0:2]),
        minute=int(_time[2:4] or 0),
        second=int(_time[4:6] or 0),
        microsecond=_microsecond,
    )


def parse_age(patient_age: str) -> int:
    """
    Parse an age string as defined in the DICOM standard into an integer representing the age in years.

    :param patient_age: age string as defined in the DICOM standard
    :return: the age in years as a number
    :raises AssertionError: if the value is not a four character string ending with `Y`
    """
    assert isinstance(
        patient_age, str
    ), f"patient age needs to be a string and not {type(patient_age)}"
    assert (
        len(patient_age) == 4
    ), f"patient age [{patient_age}] has to be four characters long"
    _num, _format = patient_age[:3], patient_age[3]
    assert (
        _format == "Y"
    ), f"currently, only patient ages in years [Y] is supported, not [{_format}]"
    return int(_num)


def get_projection(
    dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
) -> pydicom.dataset.FileDataset:
    """
    Extract the SPECT projection from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.

    :param dicom_images: list of DICOM images of a study
    :param protocol: the protocol for which the projection images should be extracted
    :return: the extracted DICOM image
    :raises ValueError: if there is not exactly one matching projection
    """
    dicom_images = [
        image
        for image in dicom_images
        if "TOMO" in image.ImageType and protocol.name in image.SeriesDescription
    ]

    if len(dicom_images) != 1:
        raise ValueError(
            f"No or multiple projections {len(dicom_images)} for protocol {protocol.name} available"
        )

    return dicom_images[0]


def get_reconstruction(
    dicom_images: List[pydicom.dataset.FileDataset],
    protocol: MyocardialProtocol,
    corrected: bool = True,
) -> pydicom.dataset.FileDataset:
    """
    Extract a SPECT reconstruction from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.
    The corrected flag can be used to either extract an attenuation corrected or a non-attenuation corrected image.
    If there are multiple images, they are sorted by series date and time and the newest is returned.

    :param dicom_images: list of DICOM images of a study
    :param protocol: the protocol for which the reconstruction should be extracted
    :param corrected: extract an attenuation or non-attenuation corrected image
    :return: the extracted DICOM image
    :raises ValueError: if no matching reconstruction is available
    """
    dicom_images = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
    dicom_images = filter(
        lambda image: protocol.name in image.SeriesDescription, dicom_images
    )

    if corrected:
        # "AC" is a substring of "NoAC", which is why the latter is excluded explicitly
        dicom_images = filter(
            lambda image: "AC" in image.SeriesDescription
            and "NoAC" not in image.SeriesDescription,
            dicom_images,
        )
    else:
        dicom_images = filter(
            lambda image: "NoAC" in image.SeriesDescription, dicom_images
        )

    # for SPECT reconstructions created in clinical studies this value exists and is set to 'APEX_TO_BASE'
    # for the reconstructions with attenuation maps it does not exist
    dicom_images = filter(
        lambda image: not hasattr(image, "SliceProgressionDirection"), dicom_images
    )

    dicom_images = sorted(dicom_images, key=parse_series_time, reverse=True)

    if len(dicom_images) == 0:
        _str = "AC" if corrected else "NoAC"
        raise ValueError(
            f"{_str} Reconstruction for protocol {protocol.name} is not available"
        )

    return dicom_images[0]


def get_attenuation_map(
    dicom_images: List[pydicom.dataset.FileDataset], protocol: MyocardialProtocol
) -> pydicom.dataset.FileDataset:
    """
    Extract an attenuation map from a list of DICOM images belonging to a myocardial scintigraphy study given a study protocol.
    If there are multiple attenuation maps, they are sorted by series date and time and the newest is returned.

    :param dicom_images: list of DICOM images of a study
    :param protocol: the protocol for which the attenuation map should be extracted
    :return: the extracted DICOM image
    :raises ValueError: if no matching attenuation map is available
    """
    dicom_images = [
        image
        for image in dicom_images
        if "RECON TOMO" in image.ImageType
        and protocol.name in image.SeriesDescription
        and "µ-map" in image.SeriesDescription
    ]
    dicom_images.sort(key=parse_series_time, reverse=True)

    if len(dicom_images) == 0:
        raise ValueError(
            f"Attenuation map for protocol {protocol.name} is not available"
        )

    return dicom_images[0]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare a dataset from DICOM directories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dicom_dirs",
        type=str,
        nargs="+",
        help="paths to DICOMDIR files or directories containing one of them",
    )
    parser.add_argument("--dataset_dir", type=str, required=True, help="")
    parser.add_argument("--images_dir", type=str, default="images", help="")
    parser.add_argument("--csv", type=str, default="data.csv", help="")
    parser.add_argument("--prefix_projection", type=str, default="projection", help="")
    parser.add_argument("--prefix_mu_map", type=str, default="mu_map", help="")
    parser.add_argument("--prefix_recon_ac", type=str, default="recon_ac", help="")
    parser.add_argument("--prefix_recon_no_ac", type=str, default="recon_no_ac", help="")
    args = parser.parse_args()

    # a path to a DICOMDIR file is replaced by the directory containing it
    args.dicom_dirs = [
        (os.path.dirname(_file) if os.path.isfile(_file) else _file)
        for _file in args.dicom_dirs
    ]
    args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    args.csv = os.path.join(args.dataset_dir, args.csv)

    # collect all patients and remember the DICOM directory each of them belongs to
    patients = []
    dicom_dir_by_patient = {}
    for dicom_dir in args.dicom_dirs:
        dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
        for patient in dataset.patient_records:
            assert (
                patient.PatientID not in dicom_dir_by_patient
            ), f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
            dicom_dir_by_patient[patient.PatientID] = dicom_dir
            patients.append(patient)

    # makedirs also creates missing parent directories of a nested dataset directory
    os.makedirs(args.dataset_dir, exist_ok=True)
    os.makedirs(args.images_dir, exist_ok=True)

    _id = 1
    if os.path.exists(args.csv):
        # patient ids are read as strings, otherwise numeric ids are parsed as
        # integers and never compare equal to the PatientID of a DICOM record
        data = pd.read_csv(args.csv, dtype={headers.patient_id: str})
        if len(data) > 0:
            # continue after the largest id in use so that existing entries
            # and their image files are not overwritten
            _id = int(data[headers.id].max()) + 1
    else:
        data = pd.DataFrame({key: [] for key in vars(headers)})

    for i, patient in enumerate(patients, start=1):
        print(f"Process patient {str(i):>3}/{len(patients)}:")

        # get all myocardial scintigraphy studies
        studies = [
            child
            for child in patient.children
            if child.DirectoryRecordType == "STUDY"
            and child.StudyDescription == "Myokardszintigraphie"
        ]

        # extract all dicom images
        dicom_images = []
        for study in studies:
            series = [
                child
                for child in study.children
                if child.DirectoryRecordType == "SERIES"
            ]
            for _series in series:
                images = [
                    child
                    for child in _series.children
                    if child.DirectoryRecordType == "IMAGE"
                ]

                # all SPECT data is stored as a single 3D array which means that it is a series with a single image
                # this is not the case for CTs, which are skipped here
                if len(images) != 1:
                    continue

                # only the meta information is needed to select images, pixel data is read later
                dicom_images.append(
                    pydicom.dcmread(
                        os.path.join(
                            dicom_dir_by_patient[patient.PatientID],
                            *images[0].ReferencedFileID,
                        ),
                        stop_before_pixels=True,
                    )
                )

        for protocol in MyocardialProtocol:
            _contained = (data[headers.patient_id] == patient.PatientID) & (
                data[headers.protocol] == protocol.name
            )
            if len(data[_contained]) > 0:
                print(
                    f"Skip {patient.PatientID}:{protocol.name} since it is already contained in the dataset"
                )
                continue

            try:
                projection_image = get_projection(dicom_images, protocol=protocol)
                recon_ac = get_reconstruction(
                    dicom_images, protocol=protocol, corrected=True
                )
                recon_noac = get_reconstruction(
                    dicom_images, protocol=protocol, corrected=False
                )
                attenuation_map = get_attenuation_map(dicom_images, protocol=protocol)
            except ValueError as e:
                print(f"Skip {patient.PatientID}:{protocol.name} because {e}")
                continue

            recon_images = [recon_ac, recon_noac, attenuation_map]

            # extract date times and assert that all reconstruction images were
            # created within five minutes of the newest one
            datetimes = list(map(parse_series_time, recon_images))
            _newest = max(datetimes)
            _equal = all(_newest - dt < timedelta(seconds=300) for dt in datetimes)
            assert (
                _equal
            ), f"Not all dates and times of the reconstructions are equal: {datetimes}"

            # extract pixel spacings and assert that they are equal for all reconstruction images
            pixel_spacings = [
                np.array(
                    [float(v) for v in [*image.PixelSpacing, image.SliceThickness]]
                )
                for image in recon_images
            ]
            _equal = all(
                (pixel_spacing == pixel_spacings[0]).all()
                for pixel_spacing in pixel_spacings
            )
            assert (
                _equal
            ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
            pixel_spacing = pixel_spacings[0]

            # extract shapes of all reconstruction images
            # note: equality of the shapes is deliberately not enforced, the shape
            # of the attenuation corrected reconstruction is stored
            shapes = [
                np.array(
                    [int(image.Rows), int(image.Columns), int(image.NumberOfSlices)]
                )
                for image in recon_images
            ]
            shape = shapes[0]

            # extract and sort energy windows
            # the window with the highest lower limit is stored as the peak
            # window and the next one as the scatter window
            energy_windows = [
                ew.EnergyWindowRangeSequence[0]
                for ew in projection_image.EnergyWindowInformationSequence
            ]
            energy_windows = [
                (float(ew.EnergyWindowLowerLimit), float(ew.EnergyWindowUpperLimit))
                for ew in energy_windows
            ]
            energy_windows.sort(key=lambda ew: ew[0], reverse=True)

            # re-read images with pixel-level data and save accordingly
            projection_image = pydicom.dcmread(projection_image.filename)
            recon_ac = pydicom.dcmread(recon_ac.filename)
            recon_noac = pydicom.dcmread(recon_noac.filename)
            attenuation_map = pydicom.dcmread(attenuation_map.filename)

            _prefix = f"{_id:04d}-{protocol.name.lower()}"
            filename_projection = f"{_prefix}-{args.prefix_projection}.dcm"
            filename_recon_ac = f"{_prefix}-{args.prefix_recon_ac}.dcm"
            filename_recon_no_ac = f"{_prefix}-{args.prefix_recon_no_ac}.dcm"
            filename_mu_map = f"{_prefix}-{args.prefix_mu_map}.dcm"

            pydicom.dcmwrite(
                os.path.join(args.images_dir, filename_projection), projection_image
            )
            pydicom.dcmwrite(os.path.join(args.images_dir, filename_recon_ac), recon_ac)
            pydicom.dcmwrite(
                os.path.join(args.images_dir, filename_recon_no_ac), recon_noac
            )
            pydicom.dcmwrite(
                os.path.join(args.images_dir, filename_mu_map), attenuation_map
            )

            radiopharmaceutical = (
                projection_image.RadiopharmaceuticalInformationSequence[0]
            )
            radionuclide = radiopharmaceutical.RadionuclideCodeSequence[0]
            rotation = projection_image.RotationInformationSequence[0]

            row = {
                headers.id: _id,
                headers.patient_id: projection_image.PatientID,
                headers.age: parse_age(projection_image.PatientAge),
                headers.weight: float(projection_image.PatientWeight),
                headers.size: float(projection_image.PatientSize),
                headers.protocol: protocol.name,
                headers.datetime_acquisition: parse_series_time(projection_image),
                headers.datetime_reconstruction: datetimes[0],
                headers.pixel_spacing_x: pixel_spacing[0],
                headers.pixel_spacing_y: pixel_spacing[1],
                headers.pixel_spacing_z: pixel_spacing[2],
                headers.shape_x: shape[0],
                headers.shape_y: shape[1],
                headers.shape_z: shape[2],
                headers.radiopharmaceutical: radiopharmaceutical.Radiopharmaceutical,
                headers.radionuclide_dose: radiopharmaceutical.RadionuclideTotalDose,
                headers.radionuclide_code: radionuclide.CodeValue,
                headers.radionuclide_meaning: radionuclide.CodeMeaning,
                headers.energy_window_peak_lower: energy_windows[0][0],
                headers.energy_window_peak_upper: energy_windows[0][1],
                headers.energy_window_scatter_lower: energy_windows[1][0],
                headers.energy_window_scatter_upper: energy_windows[1][1],
                headers.detector_count: len(
                    projection_image.DetectorInformationSequence
                ),
                headers.collimator_type: projection_image.DetectorInformationSequence[
                    0
                ].CollimatorType,
                headers.rotation_start: float(rotation.StartAngle),
                headers.rotation_step: float(rotation.AngularStep),
                headers.rotation_scan_arc: float(rotation.ScanArc),
                # names of the files written to the images directory above
                headers.file_recon_ac: filename_recon_ac,
                headers.file_recon_no_ac: filename_recon_no_ac,
                headers.file_mu_map: filename_mu_map,
            }
            _id += 1

            row = pd.DataFrame(row, index=[0])
            data = pd.concat((data, row), ignore_index=True)

    # kept inside the main guard so that importing this module has no side effects
    data.to_csv(args.csv, index=False)