# Prepare a µ-map dataset from DICOM directories: find projections,
# reconstructions and attenuation maps per patient and store them with
# meta-information in a CSV file.
import argparse
from datetime import datetime, timedelta
from enum import Enum
import os

from typing import List, Dict, Callable
import numpy as np
import pandas as pd
import pydicom

from mu_map.data.dicom_util import DICOMTime, parse_age
from mu_map.logging import add_logging_args, get_logger_by_args
# Series description used to identify series created for this study.
STUDY_DESCRIPTION = "µ-map_study"

# Column names of the meta-information CSV file. The ``headers`` namespace
# maps every header attribute to an identical column-name string; the order
# of this tuple defines the column order of the CSV.
_COLUMN_NAMES = (
    "id",
    "patient_id",
    "age",
    "weight",
    "size",
    "protocol",
    "datetime_acquisition",
    "datetime_reconstruction",
    "pixel_spacing_x",
    "pixel_spacing_y",
    "pixel_spacing_z",
    "shape_x",
    "shape_y",
    "shape_z",
    "radiopharmaceutical",
    "radionuclide_dose",
    "radionuclide_code",
    "radionuclide_meaning",
    "energy_window_peak_lower",
    "energy_window_peak_upper",
    "energy_window_scatter_lower",
    "energy_window_scatter_upper",
    "detector_count",
    "collimator_type",
    "rotation_start",
    "rotation_step",
    "rotation_scan_arc",
    "file_projection",
    "file_recon_ac_sc",
    "file_recon_nac_sc",
    "file_recon_ac_nsc",
    "file_recon_nac_nsc",
    "file_mu_map",
)
headers = argparse.Namespace(**{name: name for name in _COLUMN_NAMES})
def get_protocol(projection: "pydicom.dataset.FileDataset") -> str:
    """
    Get the protocol (stress, rest) of a projection image by checking if
    it is part of the series description.

    :param projection: pydicom image of the projection
    :returns: the protocol as a string (Stress or Rest)
    :raises ValueError: if neither protocol occurs in the series description
    """
    # case-insensitive match so that e.g. "Stress_2" is recognized as well
    description = projection.SeriesDescription.lower()

    if "stress" in description:
        return "Stress"

    if "rest" in description:
        return "Rest"

    raise ValueError(f"Unknown protocol in projection {projection.SeriesDescription}")
def find_projections(
    dicom_images: "List[pydicom.dataset.FileDataset]",
) -> "List[pydicom.dataset.FileDataset]":
    """
    Find all projections in a list of DICOM images belonging to a study.

    :param dicom_images: list of DICOM images of a study
    :return: the extracted DICOM images of projections
    :raises ValueError: if no projection is available
    """
    _filter = filter(lambda image: "TOMO" in image.ImageType, dicom_images)

    projections = []
    for dicom_image in _filter:
        # filter for allowed series descriptions
        if dicom_image.SeriesDescription not in ["Stress", "Rest", "Stress_2"]:
            logger.warning(f"Skip projection with unkown protocol [{dicom_image.SeriesDescription}]")
            continue
        projections.append(dicom_image)

    if len(projections) == 0:
        raise ValueError("No projections available")

    return projections
def is_recon_type(
    scatter_corrected: bool, attenuation_corrected: bool
) -> "Callable[[pydicom.dataset.FileDataset], bool]":
    """
    Get a filter function for reconstructions that are (non-)scatter and/or
    (non-)attenuation corrected.

    :param scatter_corrected: if the filter should only return true for scatter corrected reconstructions
    :param attenuation_corrected: if the filter should only return true for attenuation corrected reconstructions
    :returns: a filter function
    """
    # the correction type is encoded in the series description, e.g.
    # " SC - AC " for a scatter and attenuation corrected reconstruction
    sc_str = "SC" if scatter_corrected else "NoSC"
    ac_str = "AC" if attenuation_corrected else "NoAC"
    filter_str = f" {sc_str} - {ac_str} "

    return lambda dicom: filter_str in dicom.SeriesDescription
def find_reconstruction(
    dicom_images: "List[pydicom.dataset.FileDataset]",
    projection: "pydicom.dataset.FileDataset",
    scatter_corrected: bool,
    attenuation_corrected: bool,
) -> "pydicom.dataset.FileDataset":
    """
    Find a reconstruction in a list of dicom images of a study belonging to a projection.

    :param dicom_images: the list of dicom images belonging to the study
    :param projection: the dicom image of the projection
    :param scatter_corrected: if it should be searched for a scatter corrected reconstruction
    :param attenuation_corrected: if it should be searched for an attenuation corrected reconstruction
    :returns: the according reconstruction
    :raises ValueError: if no matching reconstruction is available
    """
    protocol = get_protocol(projection)
    _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
    _filter = filter(lambda image: protocol in image.SeriesDescription, _filter)
    _filter = filter(lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter)
    _filter = filter(lambda image: "CT" not in image.SeriesDescription, _filter)  # remove µ-maps
    _filter = filter(is_recon_type(scatter_corrected, attenuation_corrected), _filter)
    # a reconstruction belongs to a projection if both share the acquisition time
    _filter = filter(
        lambda image: DICOMTime.Acquisition.to_datetime(image)
        == DICOMTime.Acquisition.to_datetime(projection),
        _filter,
    )

    # BUG FIX: materialize the filtered iterator; previously the unfiltered
    # input list was length-checked, sorted and indexed, so an arbitrary
    # image of the study could be returned instead of the reconstruction
    dicom_images = list(_filter)
    if len(dicom_images) == 0:
        raise ValueError(f"No reconstruction with SC={scatter_corrected}, AC={attenuation_corrected} available")

    if len(dicom_images) > 1:
        logger.warning(f"Multiple reconstructions ({len(dicom_images)}) with SC={scatter_corrected}, AC={attenuation_corrected} for projection {projection.SeriesDescription} of patient {projection.PatientID}")
    # sort so that the newest reconstruction comes first and is returned
    dicom_images.sort(key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True)
    return dicom_images[0]
# maximum allowed time difference (in seconds) between the creation of a
# reconstruction and the creation of its attenuation map
MAX_TIME_DIFF_S = 30


def find_attenuation_map(
    dicom_images: "List[pydicom.dataset.FileDataset]",
    projection: "pydicom.dataset.FileDataset",
    reconstructions: "List[pydicom.dataset.FileDataset]",
) -> "pydicom.dataset.FileDataset":
    """
    Find an attenuation map in a list of dicom images of a study belonging
    to a projection and its reconstructions.

    :param dicom_images: the list of dicom images belonging to the study
    :param projection: the dicom image of the projection
    :param reconstructions: dicom images of reconstructions belonging to the projection
    :returns: the according attenuation map
    :raises ValueError: if no attenuation map is available
    """
    protocol = get_protocol(projection)
    recon_times = [DICOMTime.Series.to_datetime(recon) for recon in reconstructions]

    _filter = filter(lambda image: "RECON TOMO" in image.ImageType, dicom_images)
    _filter = filter(lambda image: protocol in image.SeriesDescription, _filter)
    _filter = filter(lambda image: STUDY_DESCRIPTION in image.SeriesDescription, _filter)
    _filter = filter(lambda image: " µ-map]" in image.SeriesDescription, _filter)
    # an attenuation map belongs to a reconstruction if it was created within
    # MAX_TIME_DIFF_S of it; BUG FIX: use abs(...).total_seconds() because
    # timedelta.seconds is only the seconds *component* and gives misleading
    # values for negative differences (map created before the reconstruction)
    _filter = filter(
        lambda image: any(
            abs((DICOMTime.Series.to_datetime(image) - recon_time).total_seconds())
            < MAX_TIME_DIFF_S
            for recon_time in recon_times
        ),
        _filter,
    )

    dicom_images = list(_filter)
    if len(dicom_images) == 0:
        raise ValueError("No Attenuation map available")

    if len(dicom_images) > 1:
        logger.warning(f"Multiple attenuation maps ({len(dicom_images)}) for projection {projection.SeriesDescription} of patient {projection.PatientID}")
    # sort so that the newest attenuation map comes first and is returned
    dicom_images.sort(key=lambda image: DICOMTime.Series.to_datetime(image), reverse=True)
    return dicom_images[0]
def get_relevant_images(
    patient: "pydicom.dataset.FileDataset", dicom_dir: str
) -> "List[pydicom.dataset.FileDataset]":
    """
    Get all relevant images of a patient.

    :param patient: pydicom dataset of a patient
    :param dicom_dir: the directory of the DICOM files
    :return: all relevant dicom images
    """
    # get all myocardial scintigraphy studies
    # NOTE: a filter on StudyDescription == "Myokardszintigraphie" is disabled
    # because there is a study without this description and only such studies
    # are exported anyway
    studies = list(
        filter(
            lambda child: child.DirectoryRecordType == "STUDY",
            patient.children,
        )
    )

    # extract all dicom images
    dicom_images = []
    for study in studies:
        series = list(
            filter(
                lambda child: child.DirectoryRecordType == "SERIES",
                study.children,
            )
        )
        for _series in series:
            images = list(
                filter(
                    lambda child: child.DirectoryRecordType == "IMAGE",
                    _series.children,
                )
            )

            # all SPECT data is stored as a single 3D array which means that
            # it is a series with a single image; this is not the case for
            # CTs, which are skipped here
            if len(images) != 1:
                continue

            # BUG FIX: resolve the file relative to the dicom_dir parameter
            # instead of the global dicom_dir_by_patient mapping that is only
            # defined when running as __main__
            image = pydicom.dcmread(
                os.path.join(dicom_dir, *images[0].ReferencedFileID),
                stop_before_pixels=True,
            )
            dicom_images.append(image)

    return dicom_images
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare a dataset from DICOM directories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dicom_dirs",
        type=str,
        nargs="+",
        help="paths to DICOMDIR files or directories containing one of them",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="directory where images, meta-information and the logs are stored",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory of --dataset_dir where images are stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        default="meta.csv",
        help="CSV file under --dataset_dir where meta-information is stored",
    )
    add_logging_args(
        parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
    )
    args = parser.parse_args()

    # accept DICOMDIR files as well as the directories containing them
    args.dicom_dirs = [
        (os.path.dirname(_file) if os.path.isfile(_file) else _file)
        for _file in args.dicom_dirs
    ]
    args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
    args.logfile = os.path.join(args.dataset_dir, args.logfile)

    if not os.path.exists(args.dataset_dir):
        os.mkdir(args.dataset_dir)
    if not os.path.exists(args.images_dir):
        os.mkdir(args.images_dir)

    # module-level name used by the find_* helpers above
    logger = get_logger_by_args(args)

    try:
        # collect all patient records and remember which DICOM directory
        # each patient was found in
        patients = []
        dicom_dir_by_patient: Dict[str, str] = {}
        for dicom_dir in args.dicom_dirs:
            dataset = pydicom.dcmread(os.path.join(dicom_dir, "DICOMDIR"))
            for patient in dataset.patient_records:
                assert (
                    patient.PatientID not in dicom_dir_by_patient
                ), f"Patient {patient.PatientID} is contained twice in the given DICOM directories ({dicom_dir} and {dicom_dir_by_patient[patient.PatientID]})"
                dicom_dir_by_patient[patient.PatientID] = dicom_dir
                patients.append(patient)

        # load existing meta information so that an interrupted run can
        # be continued
        _id = 1
        if os.path.exists(args.meta_csv):
            data = pd.read_csv(args.meta_csv)
            # BUG FIX: continue with the next free id instead of re-using
            # (and overwriting the files of) the last one
            _id = int(data[headers.id].max()) + 1
        else:
            data = pd.DataFrame(dict([(key, []) for key in vars(headers).keys()]))

        for i, patient in enumerate(patients, start=1):
            logger.debug(f"Process patient {str(i):>3}/{len(patients)} - {patient.PatientName.given_name}, {patient.PatientName.family_name}:")

            dicom_images = get_relevant_images(patient, dicom_dir_by_patient[patient.PatientID])
            projections = find_projections(dicom_images)
            logger.debug(f"- Found {len(projections)}: projections with protocols {list(map(lambda p: get_protocol(p), projections))}")

            for projection in projections:
                protocol = get_protocol(projection)

                # find the four reconstructions (with/without attenuation
                # and scatter correction) belonging to the projection;
                # order matches recon_headers/recon_postfixes below
                reconstructions = []
                recon_headers = [
                    headers.file_recon_nac_nsc,
                    headers.file_recon_nac_sc,
                    headers.file_recon_ac_nsc,
                    headers.file_recon_ac_sc,
                ]
                recon_postfixes = [
                    "recon_nac_nsc",
                    "recon_nac_sc",
                    "recon_ac_nsc",
                    "recon_ac_sc",
                ]
                for ac, sc in [(False, False), (False, True), (True, False), (True, True)]:
                    recon = find_reconstruction(
                        dicom_images,
                        projection,
                        scatter_corrected=sc,
                        attenuation_corrected=ac,
                    )
                    reconstructions.append(recon)
                mu_map = find_attenuation_map(dicom_images, projection, reconstructions)

                # extract pixel spacings and assert that they are equal for
                # all reconstruction images
                _map_lists = map(
                    lambda image: [*image.PixelSpacing, image.SliceThickness],
                    [*reconstructions, mu_map],
                )
                _map_lists = map(
                    lambda pixel_spacing: list(map(float, pixel_spacing)),
                    _map_lists,
                )
                _map_ndarrays = map(
                    lambda pixel_spacing: np.array(pixel_spacing), _map_lists
                )
                pixel_spacings = list(_map_ndarrays)
                _equal = all(
                    map(
                        lambda pixel_spacing: (
                            pixel_spacing == pixel_spacings[0]
                        ).all(),
                        pixel_spacings,
                    )
                )
                assert (
                    _equal
                ), f"Not all pixel spacings of the reconstructions are equal: {pixel_spacings}"
                pixel_spacing = pixel_spacings[0]

                # use the shape with the fewest slices, all other images
                # will be aligned to that
                _map_lists = map(
                    lambda image: [
                        image.Rows,
                        image.Columns,
                        image.NumberOfSlices,
                    ],
                    [*reconstructions, mu_map],
                )
                _map_lists = map(lambda shape: list(map(int, shape)), _map_lists)
                _map_ndarrays = map(lambda shape: np.array(shape), _map_lists)
                shapes = list(_map_ndarrays)
                shapes.sort(key=lambda shape: shape[2])
                shape = shapes[0]

                # extract and sort energy windows (peak window first)
                energy_windows = projection.EnergyWindowInformationSequence
                energy_windows = map(
                    lambda ew: ew.EnergyWindowRangeSequence[0], energy_windows
                )
                energy_windows = map(
                    lambda ew: (
                        float(ew.EnergyWindowLowerLimit),
                        float(ew.EnergyWindowUpperLimit),
                    ),
                    energy_windows,
                )
                energy_windows = list(energy_windows)
                energy_windows.sort(key=lambda ew: ew[0], reverse=True)

                row = {
                    headers.id: _id,
                    headers.patient_id: projection.PatientID,
                    headers.age: parse_age(projection.PatientAge),
                    headers.weight: float(projection.PatientWeight),
                    headers.size: float(projection.PatientSize),
                    headers.protocol: protocol,
                    headers.datetime_acquisition: DICOMTime.Series.to_datetime(projection),
                    headers.datetime_reconstruction: DICOMTime.Series.to_datetime(reconstructions[0]),
                    headers.pixel_spacing_x: pixel_spacing[0],
                    headers.pixel_spacing_y: pixel_spacing[1],
                    headers.pixel_spacing_z: pixel_spacing[2],
                    headers.shape_x: shape[0],
                    headers.shape_y: shape[1],
                    headers.shape_z: shape[2],
                    headers.radiopharmaceutical: projection.RadiopharmaceuticalInformationSequence[0].Radiopharmaceutical,
                    headers.radionuclide_dose: projection.RadiopharmaceuticalInformationSequence[0].RadionuclideTotalDose,
                    headers.radionuclide_code: projection.RadiopharmaceuticalInformationSequence[0].RadionuclideCodeSequence[0].CodeValue,
                    headers.radionuclide_meaning: projection.RadiopharmaceuticalInformationSequence[0].RadionuclideCodeSequence[0].CodeMeaning,
                    headers.energy_window_peak_lower: energy_windows[0][0],
                    headers.energy_window_peak_upper: energy_windows[0][1],
                    headers.energy_window_scatter_lower: energy_windows[1][0],
                    headers.energy_window_scatter_upper: energy_windows[1][1],
                    headers.detector_count: len(projection.DetectorInformationSequence),
                    headers.collimator_type: projection.DetectorInformationSequence[0].CollimatorType,
                    headers.rotation_start: float(projection.RotationInformationSequence[0].StartAngle),
                    headers.rotation_step: float(projection.RotationInformationSequence[0].AngularStep),
                    headers.rotation_scan_arc: float(projection.RotationInformationSequence[0].ScanArc),
                }

                # copy all DICOM files belonging to this row into the dataset
                _filename_base = f"{_id:04d}-{protocol.lower()}"
                _ext = "dcm"
                img_headers = [headers.file_projection, *recon_headers, headers.file_mu_map]
                img_postfixes = ["projection", *recon_postfixes, "mu_map"]
                images = [projection, *reconstructions, mu_map]
                for header, image, postfix in zip(img_headers, images, img_postfixes):
                    # re-read with pixel data (the images above were loaded
                    # with stop_before_pixels=True)
                    _image = pydicom.dcmread(image.filename)
                    filename = f"{_filename_base}-{postfix}.{_ext}"
                    pydicom.dcmwrite(os.path.join(args.images_dir, filename), _image)
                    row[header] = filename

                # persist the meta information after every projection so that
                # an interrupted run loses as little work as possible
                _id += 1
                row = pd.DataFrame(row, index=[0])
                data = pd.concat((data, row), ignore_index=True)
                data.to_csv(args.meta_csv, index=False)
    except Exception as e:
        logger.error(e)
        raise e