Commit 86e64fbe authored by Tamino Huxohl

Merge branch 'feature/hydra_config' of https://gitlab.ub.uni-bielefeld.de/thuxohl/mu-map into feature/hydra_config
parents 546a8561 2635dfc3
@@ -20,3 +20,6 @@ train_data/
build/
*.egg-info/
cgan_random_search*/
*.sbatch
"""
New dataset preparation script because the way DICOM data is exported has changed.
"""
import argparse
from enum import Enum
import os
from typing import Any, Dict, List
import pandas as pd
import pydicom
from mu_map.data.prepare import headers
from mu_map.file.dicom import DICOMTime, EXTENSION_DICOM, parse_age
from mu_map.logging import add_logging_args, get_logger_by_args
class Protocol(Enum):
    """
    Class defining the two protocols __Rest__ & __Stress__ of MPI
    SPECT studies.
    """

    Rest = 1
    Stress = 2

    @classmethod
    def get(cls, dcm: pydicom.dataset.Dataset):
        """
        Get the protocol from a DICOM dataset.
        This function is based on the series description.

        Parameters
        ----------
        dcm: pydicom.dataset.Dataset
            the DICOM dataset

        Returns
        -------
        Protocol
        """
        if Protocol.Rest.name.lower() in dcm.SeriesDescription.lower():
            return Protocol.Rest
        elif Protocol.Stress.name.lower() in dcm.SeriesDescription.lower():
            return Protocol.Stress
        else:
            raise ValueError(
                f"Cannot extract protocol from SeriesDescription[{dcm.SeriesDescription}]"
            )

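# Editor's note (not part of the commit): matching is case-insensitive, so a
# hypothetical SeriesDescription like "Recon Tomo STRESS AC" yields
# Protocol.Stress, while a description containing neither "rest" nor "stress"
# raises a ValueError.
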
def is_scatter_corrected(dcm: pydicom.dataset.Dataset) -> bool:
    """
    Test if a DICOM dataset is a scatter-corrected reconstruction based
    on the __CorrectedImage__ tag.

    Parameters
    ----------
    dcm: pydicom.dataset.Dataset
        the DICOM dataset

    Returns
    -------
    bool
    """
    try:
        return "SCAT" in dcm.CorrectedImage
    except AttributeError:
        return False

def is_attenuation_corrected(dcm: pydicom.dataset.Dataset) -> bool:
    """
    Test if a DICOM dataset is an attenuation-corrected reconstruction based
    on the __CorrectedImage__ tag.

    Parameters
    ----------
    dcm: pydicom.dataset.Dataset
        the DICOM dataset

    Returns
    -------
    bool
    """
    try:
        return "ATTN" in dcm.CorrectedImage
    except AttributeError:
        return False

def is_mu_map(dcm: pydicom.dataset.Dataset) -> bool:
    """
    Test if a DICOM dataset is a mu map based on the
    series description.

    Parameters
    ----------
    dcm: pydicom.dataset.Dataset
        the DICOM dataset

    Returns
    -------
    bool
    """
    return "µ-map" in dcm.SeriesDescription

def closest_in_time(
    ref: pydicom.dataset.Dataset,
    dcms: List[pydicom.dataset.Dataset],
    time_type: DICOMTime = DICOMTime.Series,
) -> pydicom.dataset.Dataset:
    """
    Find the DICOM file closest in time to a reference file.

    Parameters
    ----------
    ref: pydicom.dataset.Dataset
        the reference DICOM file
    dcms: List[pydicom.dataset.Dataset]
        the list which is searched
    time_type: DICOMTime
        the time type used for the search

    Returns
    -------
    pydicom.dataset.Dataset
    """
    ref_time = time_type.to_datetime(ref)
    dcms.sort(key=lambda dcm: abs(time_type.to_datetime(dcm) - ref_time))
    return dcms[0]

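# Editor's note (not part of the commit): with hypothetical series times of
# 09:00 and 10:30 and a reference time of 10:00, the 10:30 series is returned
# because |10:30 - 10:00| < |09:00 - 10:00|. Note that the list is sorted in
# place and an IndexError is raised if `dcms` is empty.
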
def get_recon_filename(dcm: pydicom.dataset.Dataset) -> str:
    """
    Create a filename for a reconstruction depending on whether it is
    scatter- and/or attenuation-corrected.

    Parameters
    ----------
    dcm: pydicom.dataset.Dataset
        the dataset for which a filename is created

    Returns
    -------
    str
    """
    sc_str = "sc" if is_scatter_corrected(dcm) else "nsc"
    ac_str = "ac" if is_attenuation_corrected(dcm) else "nac"
    return f"recon_{ac_str}_{sc_str}"

def create_row(
    _id: int, dcm: pydicom.dataset.Dataset, protocol: Protocol
) -> Dict[str, Any]:
    """
    Create a row for the meta data table.

    Parameters
    ----------
    _id: int
        the id of the row
    dcm: pydicom.dataset.Dataset
        the DICOM file from which meta information is extracted
    protocol: Protocol
        the protocol of the study

    Returns
    -------
    Dict[str, Any]
    """
    return {
        headers.id: _id,
        headers.patient_id: dcm.PatientID,
        headers.age: parse_age(dcm.PatientAge),
        headers.sex: dcm.PatientSex,
        headers.weight: float(dcm.PatientWeight),
        headers.size: float(dcm.PatientSize),
        headers.protocol: protocol.name.lower(),
        headers.datetime_acquisition: DICOMTime.Acquisition.to_datetime(dcm),
    }

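# Editor's note (not part of the commit, values hypothetical): a resulting row
# maps the header columns to scalars, e.g. id=1, age=67, sex="M", weight=81.0,
# size=1.78 and protocol="stress", and is later appended to the meta CSV table.
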
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Prepare a dataset from DICOM directories",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--dicom_dir",
        type=str,
        required=True,
        help="directory containing the raw DICOM files to be processed - it is expected to contain directories and each directory can contain multiple studies of the same patient",
    )
    parser.add_argument(
        "--dataset_dir",
        type=str,
        required=True,
        help="directory where images, meta-information and the logs are stored",
    )
    parser.add_argument(
        "--images_dir",
        type=str,
        default="images",
        help="sub-directory of --dataset_dir where images are stored",
    )
    parser.add_argument(
        "--meta_csv",
        type=str,
        default="meta.csv",
        help="CSV file under --dataset_dir where meta-information is stored",
    )
    add_logging_args(
        parser, defaults={"--logfile": "prepare.log", "--loglevel": "DEBUG"}
    )
    args = parser.parse_args()

    args.images_dir = os.path.join(args.dataset_dir, args.images_dir)
    args.meta_csv = os.path.join(args.dataset_dir, args.meta_csv)
    args.logfile = os.path.join(args.dataset_dir, args.logfile)

    if not os.path.exists(args.dataset_dir):
        os.mkdir(args.dataset_dir)
    if not os.path.exists(args.images_dir):
        os.mkdir(args.images_dir)

    logger = get_logger_by_args(args)

    patient_dirs = sorted(os.listdir(args.dicom_dir))

    data = pd.DataFrame({})
    current_id = 1
    for patient_dir in patient_dirs:
        logger.info(f"Process directory: {patient_dir}")
        patient_dir = os.path.join(args.dicom_dir, patient_dir)

        files = sorted([os.path.join(patient_dir, f) for f in os.listdir(patient_dir)])
        dcms = list(map(lambda f: pydicom.dcmread(f), files))
        files = dict(zip(map(lambda dcm: dcm.SeriesInstanceUID, dcms), files))

        for protocol in Protocol:
            _dcms = list(filter(lambda dcm: Protocol.get(dcm) is protocol, dcms))
            dcms_sc_ac = list(
                filter(
                    lambda dcm: is_scatter_corrected(dcm)
                    and is_attenuation_corrected(dcm),
                    _dcms,
                )
            )
            if len(dcms_sc_ac) == 0:
                logger.info(f" - Protocol {protocol.name} not available!")
                continue

            logger.info(f" - Process protocol: {protocol.name}")
            for dcm_sc_ac in dcms_sc_ac:
                logger.info(f" - Create row with id: {current_id:04d}")
                dcm_sc_nac = closest_in_time(
                    dcm_sc_ac,
                    list(
                        filter(
                            lambda dcm: is_scatter_corrected(dcm)
                            and not is_attenuation_corrected(dcm),
                            _dcms,
                        )
                    ),
                )
                dcm_nsc_ac = closest_in_time(
                    dcm_sc_ac,
                    list(
                        filter(
                            lambda dcm: not is_scatter_corrected(dcm)
                            and is_attenuation_corrected(dcm),
                            _dcms,
                        )
                    ),
                )
                dcm_nsc_nac = closest_in_time(
                    dcm_sc_ac,
                    list(
                        filter(
                            lambda dcm: not is_scatter_corrected(dcm)
                            and not is_attenuation_corrected(dcm),
                            _dcms,
                        )
                    ),
                )
                mu_map = closest_in_time(
                    dcm_sc_ac, list(filter(lambda dcm: is_mu_map(dcm), _dcms))
                )

                row = create_row(current_id, dcm_sc_ac, protocol)

                recons = [dcm_sc_ac, dcm_sc_nac, dcm_nsc_ac, dcm_nsc_nac]
                _files = list(map(lambda recon: get_recon_filename(recon), recons))
                for recon, f in zip(recons, _files):
                    _f = f"{current_id:04d}-{f}{EXTENSION_DICOM}"
                    row[f"file_{f}"] = _f
                    _f = os.path.join(args.images_dir, _f)
                    logger.info(f" - Store {files[recon.SeriesInstanceUID]} at {_f}")
                    pydicom.dcmwrite(_f, recon)

                _f = f"{current_id:04d}-mu_map{EXTENSION_DICOM}"
                row[headers.file_mu_map] = _f
                _f = os.path.join(args.images_dir, _f)
                logger.info(f" - Store {files[mu_map.SeriesInstanceUID]} at {_f}")
                pydicom.dcmwrite(_f, mu_map)

                row = pd.DataFrame(row, index=[0])
                data = pd.concat((data, row), ignore_index=True)
                current_id = current_id + 1

    data.to_csv(args.meta_csv, index=False)
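
To make the filename scheme concrete, here is a minimal sketch (an editor's addition, not part of the commit) of how get_recon_filename maps the CorrectedImage tag to the four reconstruction variants; the Dataset values are made up. Stored files additionally get the zero-padded row id as a prefix and the DICOM file extension.

import pydicom

dcm = pydicom.dataset.Dataset()
dcm.CorrectedImage = ["ATTN", "SCAT"]  # attenuation- and scatter-corrected
print(get_recon_filename(dcm))  # -> "recon_ac_sc"

dcm.CorrectedImage = ["SCAT"]  # scatter-corrected only
print(get_recon_filename(dcm))  # -> "recon_nac_sc"

dcm.CorrectedImage = []  # neither correction applied
print(get_recon_filename(dcm))  # -> "recon_nac_nsc"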
@@ -121,8 +121,12 @@ class DICOMTime(Enum):
         """
         Get the datetime according to this DICOMTime type.
         """
-        _date = dcm_header[self.date_field()].value
-        _time = dcm_header[self.time_field()].value
+        try:
+            _date = dcm_header[self.date_field()].value
+            _time = dcm_header[self.time_field()].value
+        except KeyError:
+            # date/time tags may be missing; " None" is used as a display value
+            return " None"
         return datetime(
             year=int(_date[0:4]),
             month=int(_date[4:6]),
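
For context, the slicing in to_datetime assumes the standard DICOM value representations DA ("YYYYMMDD") and TM ("HHMMSS.FFFFFF"). A minimal sketch (editor's addition, values made up):

from datetime import datetime

_date, _time = "20230115", "093045.000000"
print(datetime(
    year=int(_date[0:4]),
    month=int(_date[4:6]),
    day=int(_date[6:8]),
    hour=int(_time[0:2]),
    minute=int(_time[2:4]),
    second=int(_time[4:6]),
))  # -> 2023-01-15 09:30:45
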
"""
Print stats about outliers labeled with `label_outliers`.
"""
import argparse
import pandas as pd
from mu_map.random_search.eval.params import parameter_groups
from mu_map.random_search.eval.util import filter_by_params, load_data, remove_outliers
data = load_data("cgan_random_search/")
parser = argparse.ArgumentParser(
description="Evaluate the amount of outliers (per parameter) of a random search run",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"--random_search_dir",
type=str,
default="cgan_random_search",
help="the directory containing the random search data",
)
parser.add_argument(
"--outliers_file",
type=str,
default="outliers.csv",
help="file unter <random_search_dir> containing the labeling of outlier runs",
)
args = parser.parse_args()
data = load_data(args.random_search_dir)
n_total = len(data)
outlier_runs = pd.read_csv("cgan_random_search/outliers.csv")
@@ -30,7 +50,6 @@ for param_label, param_groups in parameter_groups.items():
     for label, value in param_groups.groups.items():
         n_outlier = len(filter_by_params(data_outlier, value, param_groups.keys))
         n_total = len(filter_by_params(data, value, param_groups.keys))
-        print(
-            f" - {label:>12}: {str(n_outlier):>2}/{str(n_total):>2} = {100 * n_outlier / n_total:.2f}%"
-        )
+        percent = f"{100 * n_outlier / n_total:.2f}%" if n_total > 0 else " NaN"
+        print(f" - {label:>12}: {str(n_outlier):>3}/{str(n_total):>3} = {percent}")
     print()
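
The added guard avoids a ZeroDivisionError when a parameter value never occurs in the data. A minimal sketch (editor's addition, counts made up):

n_outlier, n_total = 0, 0
percent = f"{100 * n_outlier / n_total:.2f}%" if n_total > 0 else " NaN"
print(f" - {'example':>12}: {str(n_outlier):>3}/{str(n_total):>3} = {percent}")
# prints " -      example:   0/  0 =  NaN"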
@@ -12,11 +12,12 @@ import pandas as pd
 from mu_map.random_search.cgan import load_params

-SIZE_DEFAULT = 12
+SIZE_DEFAULT = 16
 plt.rc("font", family="Roboto")  # controls default font
 plt.rc("font", weight="normal")  # controls default font
 plt.rc("font", size=SIZE_DEFAULT)  # controls default text sizes
-plt.rc("axes", titlesize=18)  # fontsize of the axes title
+# plt.rc("axes", titlesize=18)  # fontsize of the axes title
+plt.rc("axes", titlesize=22)  # fontsize of the axes title
class ColorList:
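
For reference, plt.rc writes to matplotlib's global rcParams, so the raised title size applies to every figure created afterwards in the process. A quick sketch (editor's addition):

import matplotlib.pyplot as plt

plt.rc("axes", titlesize=22)
assert plt.rcParams["axes.titlesize"] == 22  # all subsequent figures use it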
@@ -112,9 +113,9 @@ def load_data(
         a dict mapping the iteration number to "measures", "params", and "dir"
     """
     dirs_run = sorted(os.listdir(dir_random_search))
-    dirs_run = filter(lambda f: f.isdigit(), dirs_run)
     dirs_run = map(lambda f: os.path.join(dir_random_search, f), dirs_run)
+    dirs_run = filter(lambda f: os.path.isdir(f), dirs_run)
+    dirs_run = filter(lambda f: not os.path.islink(f), dirs_run)
+    dirs_run = map(lambda f: os.path.basename(f), dirs_run)

     data = {}
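
The changed listing no longer requires run directories to have purely numeric names; it now keeps any real, non-symlinked sub-directory. A sketch of the same chain on a hypothetical directory layout (editor's addition):

import os

dir_random_search = "cgan_random_search"
# hypothetical contents: 001/  002/  best -> 002 (symlink)  params.json
runs = sorted(os.listdir(dir_random_search))
runs = map(lambda f: os.path.join(dir_random_search, f), runs)
runs = filter(lambda f: os.path.isdir(f), runs)        # drops params.json
runs = filter(lambda f: not os.path.islink(f), runs)   # drops the symlink
print(list(map(lambda f: os.path.basename(f), runs)))  # -> ['001', '002']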
@@ -239,7 +240,7 @@ class TablePrinter:
                 for header, column in table.items()
             ]
         )
-        lenghtes = dict(
+        lengthes = dict(
             [
                 (header, max(len(header), *map(len, column)))
                 for header, column in table.items()
@@ -247,19 +248,19 @@ class TablePrinter:
         )

         line_top = f"{self.vert}{self.t_down}{self.vert}".join(
-            map(lambda header: self.vert * lenghtes[header], headers)
+            map(lambda header: self.vert * lengthes[header], headers)
         )
         line_top = self.top_left + self.vert + line_top + self.vert + self.top_right
         print(line_top)

         line_headers = f" {self.hori} ".join(
-            map(lambda header: f"{header:>{lenghtes[header]}}", table.keys())
+            map(lambda header: f"{header:>{lengthes[header]}}", table.keys())
         )
         line_headers = self.hori + " " + line_headers + " " + self.hori
         print(line_headers)

         line_mid = f"{self.vert}{self.cross}{self.vert}".join(
-            map(lambda header: self.vert * lenghtes[header], headers)
+            map(lambda header: self.vert * lengthes[header], headers)
         )
         line_mid = self.t_left + self.vert + line_mid + self.vert + self.t_right
         print(line_mid)
@@ -267,7 +268,7 @@ class TablePrinter:
         for i in range(len(table[headers[0]])):
             values = map(
                 lambda header: self.color(
-                    f"{table[header][i]:>{lenghtes[header]}}", header
+                    f"{table[header][i]:>{lengthes[header]}}", header
                 ),
                 headers,
             )
@@ -276,7 +277,7 @@ class TablePrinter:
             print(line)

         line_bot = f"{self.vert}{self.t_up}{self.vert}".join(
-            map(lambda header: self.vert * lenghtes[header], headers)
+            map(lambda header: self.vert * lengthes[header], headers)
         )
         line_bot = (
             self.bottom_left + self.vert + line_bot + self.vert + self.bottom_right
         )
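
The renamed lengthes dict stores one column width per header: the maximum of the header text and its longest value. A minimal sketch (editor's addition, data made up):

table = {"run": ["1", "23"], "NMAE": ["0.012", "0.009"]}
lengthes = {h: max(len(h), *map(len, col)) for h, col in table.items()}
print(lengthes)  # -> {'run': 3, 'NMAE': 5}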