diff --git a/mu_map/data/prepare.py b/mu_map/data/prepare.py index e1735e2704b2b899c4373b1fd4fb9a39df29ba00..a5b1b87cd73152ab79f449acb09e2aab0f8ce6fe 100644 --- a/mu_map/data/prepare.py +++ b/mu_map/data/prepare.py @@ -51,7 +51,6 @@ headers.file_recon_nac_nsc = "file_recon_nac_nsc" headers.file_mu_map = "file_mu_map" - def get_protocol(projection: pydicom.dataset.FileDataset) -> str: """ Get the protocol (stress, rest) of a projection image by checking if @@ -69,7 +68,9 @@ def get_protocol(projection: pydicom.dataset.FileDataset) -> str: raise ValueError(f"Unkown protocol in projection {projection.SeriesDescription}") -def find_projections(dicom_images: List[pydicom.dataset.FileDataset]) -> pydicom.dataset.FileDataset: +def find_projections( + dicom_images: List[pydicom.dataset.FileDataset], +) -> pydicom.dataset.FileDataset: """ Find all projections in a list of DICOM images belonging to a study. @@ -81,7 +82,9 @@ def find_projections(dicom_images: List[pydicom.dataset.FileDataset]) -> pydicom for dicom_image in _filter: # filter for allows series descriptions if dicom_image.SeriesDescription not in ["Stress", "Rest", "Stress_2"]: - logger.warning(f"Skip projection with unkown protocol [{dicom_image.SeriesDescription}]") + logger.warning( + f"Skip projection with unkown protocol [{dicom_image.SeriesDescription}]" + ) # print(f" - {DICOMTime.Study.to_datetime(di)}, {DICOMTime.Series.to_datetime(di)}, {DICOMTime.Content.to_datetime(di)}, {DICOMTime.Acquisition.to_datetime(di)}") continue dicom_images.append(dicom_image) @@ -91,7 +94,10 @@ def find_projections(dicom_images: List[pydicom.dataset.FileDataset]) -> pydicom return dicom_images -def is_recon_type(scatter_corrected: bool, attenuation_corrected: bool) -> Callable[pydicom.dataset.FileDataset, bool]: + +def is_recon_type( + scatter_corrected: bool, attenuation_corrected: bool +) -> Callable[[pydicom.dataset.FileDataset], bool]: """ Get a filter function for reconstructions that are (non-)scatter and/or (non-)attenuation corrected. @@ -111,7 +117,12 @@ def is_recon_type(scatter_corrected: bool, attenuation_corrected: bool) -> Calla return lambda dicom: filter_str in dicom.SeriesDescription -def find_reconstruction(dicom_images: List[pydicom.dataset.FileDataset], projection: pydicom.dataset.FileDataset, scatter_corrected: bool, attenuation_corrected: bool) -> List[pydicom.dataset.FileDataset]: +def find_reconstruction( + dicom_images: List[pydicom.dataset.FileDataset], + projection: pydicom.dataset.FileDataset, + scatter_corrected: bool, + attenuation_corrected: bool, +) -> List[pydicom.dataset.FileDataset]: """ Find a reconstruction in a list of dicom images of a study belonging to a projection. @@ -125,35 +136,52 @@ def find_reconstruction(dicom_images: List[pydicom.dataset.FileDataset], project _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images) _filter = filter(lambda image: protocol in image.SeriesDescription, _filter) - _filter = filter(lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter) - _filter = filter(lambda image: "CT" not in image.SeriesDescription, _filter) # remove µ-maps + _filter = filter( + lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter + ) + _filter = filter( + lambda image: "CT" not in image.SeriesDescription, _filter + ) # remove µ-maps _filter = filter(is_recon_type(scatter_corrected, attenuation_corrected), _filter) # _filter = list(_filter) # print("DEBUG Reconstructions: ") # for r in _filter: - # try: - # print(f" - {r.SeriesDescription:>50} at {DICOMTime.Study.to_datetime(r)}, {DICOMTime.Series.to_datetime(r)}, {DICOMTime.Content.to_datetime(r)}, {DICOMTime.Acquisition.to_datetime(r)}") - # except Exception as e: - # print(f"Error {e}") - - _filter = filter(lambda image: DICOMTime.Acquisition.to_datetime(image) == DICOMTime.Acquisition.to_datetime(projection), _filter) + # try: + # print(f" - {r.SeriesDescription:>50} at {DICOMTime.Study.to_datetime(r)}, {DICOMTime.Series.to_datetime(r)}, {DICOMTime.Content.to_datetime(r)}, {DICOMTime.Acquisition.to_datetime(r)}") + # except Exception as e: + # print(f"Error {e}") + + _filter = filter( + lambda image: DICOMTime.Acquisition.to_datetime(image) + == DICOMTime.Acquisition.to_datetime(projection), + _filter, + ) dicom_images = list(_filter) if len(dicom_images) == 0: - raise ValueError(f"No reconstruction with SC={scatter_corrected}, AC={attenuation_corrected} available") + raise ValueError( + f"No reconstruction with SC={scatter_corrected}, AC={attenuation_corrected} available" + ) # sort oldest to be first if len(dicom_images) > 1: - logger.warning(f"Multiple reconstructions ({len(dicom_images)}) with SC={scatter_corrected}, AC={attenuation_corrected} for projection {projection.SeriesDescription} of patient {projection.PatientID}") - dicom_images.sort(key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True) + logger.warning( + f"Multiple reconstructions ({len(dicom_images)}) with SC={scatter_corrected}, AC={attenuation_corrected} for projection {projection.SeriesDescription} of patient {projection.PatientID}" + ) + dicom_images.sort( + key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True + ) return dicom_images[0] - MAX_TIME_DIFF_S = 30 + + def find_attenuation_map( - dicom_images: List[pydicom.dataset.FileDataset], projection: pydicom.dataset.FileDataset, reconstructions: List[pydicom.dataset.FileDataset], + dicom_images: List[pydicom.dataset.FileDataset], + projection: pydicom.dataset.FileDataset, + reconstructions: List[pydicom.dataset.FileDataset], ) -> pydicom.dataset.FileDataset: """ Find a reconstruction in a list of dicom images of a study belonging to a projection and reconstructions. @@ -164,13 +192,28 @@ def find_attenuation_map( :returns: the according attenuation map """ protocol = get_protocol(projection) - recon_times = list(map(lambda recon: DICOMTime.Series.to_datetime(recon), reconstructions)) + recon_times = list( + map(lambda recon: DICOMTime.Series.to_datetime(recon), reconstructions) + ) _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images) _filter = filter(lambda image: protocol in image.SeriesDescription, _filter) - _filter = filter(lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter) + _filter = filter( + lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter + ) _filter = filter(lambda image: " µ-map]" in image.SeriesDescription, _filter) - _filter = filter(lambda image: any(map(lambda recon_time: (DICOMTime.Series.to_datetime(image) - recon_time).seconds < MAX_TIME_DIFF_S, recon_times)), _filter) + _filter = filter( + lambda image: any( + map( + lambda recon_time: ( + DICOMTime.Series.to_datetime(image) - recon_time + ).seconds + < MAX_TIME_DIFF_S, + recon_times, + ) + ), + _filter, + ) dicom_images = list(_filter) if len(dicom_images) == 0: @@ -178,13 +221,19 @@ def find_attenuation_map( # sort oldest to be first if len(dicom_images) > 1: - logger.warning(f"Multiple attenuation maps ({len(dicom_images)}) for projection {projection.SeriesDescription} of patient {projection.PatientID}") - dicom_images.sort(key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True) + logger.warning( + f"Multiple attenuation maps ({len(dicom_images)}) for projection {projection.SeriesDescription} of patient {projection.PatientID}" + ) + dicom_images.sort( + key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True + ) return dicom_images[0] -def get_relevant_images(patient: pydicom.dataset.FileDataset, dicom_dir: str) -> List[pydicom.dataset.FileDataset]: +def get_relevant_images( + patient: pydicom.dataset.FileDataset, dicom_dir: str +) -> List[pydicom.dataset.FileDataset]: """ Get all relevant images of a patient. @@ -243,7 +292,6 @@ def get_relevant_images(patient: pydicom.dataset.FileDataset, dicom_dir: str) -> return dicom_images - if __name__ == "__main__": parser = argparse.ArgumentParser( description="Prepare a dataset from DICOM directories", @@ -315,38 +363,62 @@ if __name__ == "__main__": data = pd.DataFrame(dict([(key, []) for key in vars(headers).keys()])) for i, patient in enumerate(patients, start=1): - logger.debug(f"Process patient {str(i):>3}/{len(patients)} - {patient.PatientName.given_name}, {patient.PatientName.family_name}:") + logger.debug( + f"Process patient {str(i):>3}/{len(patients)} - {patient.PatientName.given_name}, {patient.PatientName.family_name}:" + ) - dicom_images = get_relevant_images(patient, dicom_dir_by_patient[patient.PatientID]) + dicom_images = get_relevant_images( + patient, dicom_dir_by_patient[patient.PatientID] + ) projections = find_projections(dicom_images) - logger.debug(f"- Found {len(projections)}: projections with protocols {list(map(lambda p: get_protocol(p), projections))}") + logger.debug( + f"- Found {len(projections)}: projections with protocols {list(map(lambda p: get_protocol(p), projections))}" + ) for projection in projections: protocol = get_protocol(projection) reconstructions = [] - recon_headers = [headers.file_recon_nac_nsc, headers.file_recon_nac_sc, headers.file_recon_ac_nsc, headers.file_recon_ac_sc] - recon_postfixes = ["recon_nac_nsc", "recon_nac_sc", "recon_ac_nsc", "recon_ac_sc"] - for ac, sc in [(False, False), (False, True), (True, False), (True, True)]: - recon = find_reconstruction(dicom_images, projection, scatter_corrected=sc, attenuation_corrected=ac) + recon_headers = [ + headers.file_recon_nac_nsc, + headers.file_recon_nac_sc, + headers.file_recon_ac_nsc, + headers.file_recon_ac_sc, + ] + recon_postfixes = [ + "recon_nac_nsc", + "recon_nac_sc", + "recon_ac_nsc", + "recon_ac_sc", + ] + for ac, sc in [ + (False, False), + (False, True), + (True, False), + (True, True), + ]: + recon = find_reconstruction( + dicom_images, + projection, + scatter_corrected=sc, + attenuation_corrected=ac, + ) reconstructions.append(recon) mu_map = find_attenuation_map(dicom_images, projection, reconstructions) - - # for i, projection in enumerate(projections): + # for i, projection in enumerate(projections): # _time = DICOMTime.Series.to_datetime(projection) # print(f" - Projection: {projection.SeriesDescription:>10} at {DICOMTime.Study.to_datetime(projection)}, {DICOMTime.Series.to_datetime(projection)}, {DICOMTime.Content.to_datetime(projection)}, {DICOMTime.Acquisition.to_datetime(projection)}") # recons = [] # for sc, ac in [(False, False), (False, True), (True, False), (True, True)]: - # r = find_reconstruction(dicom_images, projection, scatter_corrected=sc, attenuation_corrected=ac) - # print(f" - {r.SeriesDescription:>50} at {DICOMTime.Study.to_datetime(r)}, {DICOMTime.Series.to_datetime(r)}, {DICOMTime.Content.to_datetime(r)}, {DICOMTime.Acquisition.to_datetime(r)}") - # recons.append(r) + # r = find_reconstruction(dicom_images, projection, scatter_corrected=sc, attenuation_corrected=ac) + # print(f" - {r.SeriesDescription:>50} at {DICOMTime.Study.to_datetime(r)}, {DICOMTime.Series.to_datetime(r)}, {DICOMTime.Content.to_datetime(r)}, {DICOMTime.Acquisition.to_datetime(r)}") + # recons.append(r) # mu_map = find_attenuation_map(dicom_images, projection, recons) # print(f" - {mu_map.SeriesDescription:>50} at {DICOMTime.Study.to_datetime(mu_map)}, {DICOMTime.Series.to_datetime(mu_map)}, {DICOMTime.Content.to_datetime(mu_map)}") # print(f" -") - # extract pixel spacings and assert that they are equal for all reconstruction images _map_lists = map( lambda image: [*image.PixelSpacing, image.SliceThickness], @@ -382,18 +454,14 @@ if __name__ == "__main__": ], [*reconstructions, mu_map], ) - _map_lists = map( - lambda shape: list(map(int, shape)), _map_lists - ) + _map_lists = map(lambda shape: list(map(int, shape)), _map_lists) _map_ndarrays = map(lambda shape: np.array(shape), _map_lists) shapes = list(_map_ndarrays) shapes.sort(key=lambda shape: shape[2]) shape = shapes[0] # extract and sort energy windows - energy_windows = ( - projection.EnergyWindowInformationSequence - ) + energy_windows = projection.EnergyWindowInformationSequence energy_windows = map( lambda ew: ew.EnergyWindowRangeSequence[0], energy_windows ) @@ -414,40 +482,66 @@ if __name__ == "__main__": headers.weight: float(projection.PatientWeight), headers.size: float(projection.PatientSize), headers.protocol: protocol, - headers.datetime_acquisition: DICOMTime.Series.to_datetime(projection), - headers.datetime_reconstruction: DICOMTime.Series.to_datetime(reconstructions[0]), + headers.datetime_acquisition: DICOMTime.Series.to_datetime( + projection + ), + headers.datetime_reconstruction: DICOMTime.Series.to_datetime( + reconstructions[0] + ), headers.pixel_spacing_x: pixel_spacing[0], headers.pixel_spacing_y: pixel_spacing[1], headers.pixel_spacing_z: pixel_spacing[2], headers.shape_x: shape[0], headers.shape_y: shape[1], headers.shape_z: shape[2], - headers.radiopharmaceutical: projection.RadiopharmaceuticalInformationSequence[0].Radiopharmaceutical, - headers.radionuclide_dose: projection.RadiopharmaceuticalInformationSequence[0].RadionuclideTotalDose, - headers.radionuclide_code: projection.RadiopharmaceuticalInformationSequence[0].RadionuclideCodeSequence[0].CodeValue, - headers.radionuclide_meaning: projection.RadiopharmaceuticalInformationSequence[0].RadionuclideCodeSequence[0].CodeMeaning, + headers.radiopharmaceutical: projection.RadiopharmaceuticalInformationSequence[ + 0 + ].Radiopharmaceutical, + headers.radionuclide_dose: projection.RadiopharmaceuticalInformationSequence[ + 0 + ].RadionuclideTotalDose, + headers.radionuclide_code: projection.RadiopharmaceuticalInformationSequence[ + 0 + ] + .RadionuclideCodeSequence[0] + .CodeValue, + headers.radionuclide_meaning: projection.RadiopharmaceuticalInformationSequence[ + 0 + ] + .RadionuclideCodeSequence[0] + .CodeMeaning, headers.energy_window_peak_lower: energy_windows[0][0], headers.energy_window_peak_upper: energy_windows[0][1], headers.energy_window_scatter_lower: energy_windows[1][0], headers.energy_window_scatter_upper: energy_windows[1][1], headers.detector_count: len(projection.DetectorInformationSequence), - headers.collimator_type: projection.DetectorInformationSequence[0].CollimatorType, - headers.rotation_start: float(projection.RotationInformationSequence[0].StartAngle), - headers.rotation_step: float(projection.RotationInformationSequence[0].AngularStep), - headers.rotation_scan_arc: float(projection.RotationInformationSequence[0].ScanArc), + headers.collimator_type: projection.DetectorInformationSequence[ + 0 + ].CollimatorType, + headers.rotation_start: float( + projection.RotationInformationSequence[0].StartAngle + ), + headers.rotation_step: float( + projection.RotationInformationSequence[0].AngularStep + ), + headers.rotation_scan_arc: float( + projection.RotationInformationSequence[0].ScanArc + ), } _filename_base = f"{_id:04d}-{protocol.lower()}" _ext = "dcm" - img_headers = [headers.file_projection, *recon_headers, headers.file_mu_map] + img_headers = [ + headers.file_projection, + *recon_headers, + headers.file_mu_map, + ] img_postfixes = ["projection", *recon_postfixes, "mu_map"] images = [projection, *reconstructions, mu_map] for header, image, postfix in zip(img_headers, images, img_postfixes): _image = pydicom.dcmread(image.filename) filename = f"{_filename_base}-{postfix}.{_ext}" - pydicom.dcmwrite( - os.path.join(args.images_dir, filename), _image - ) + pydicom.dcmwrite(os.path.join(args.images_dir, filename), _image) row[header] = filename _id += 1