import os
from typing import List, Optional, Tuple

import cv2 as cv
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

# The import locations of `headers` and `get_logger` are assumed based on the
# references in this file (mu_map.data.prepare and mu_map.logging).
from mu_map.data.prepare import headers
from mu_map.data.remove_bed import (
    DEFAULT_BED_CONTOURS_FILENAME,
    load_contours,
    remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images
from mu_map.file.dicom import load_dcm_img
from mu_map.logging import get_logger
"""
A dataset to map reconstructions to attenuation maps (mu maps).
The dataset is lazy. This means that that dataset creation is
fast and images are only read into memory when their first accessed.
Thus, after the first iteration, accessing images becomes a lot
faster.
"""
    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",  # default filename assumed
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Transform = Transform(),
        transform_augmentation: Transform = Transform(),
        logger=None,
    ):
"""
Create a new mu map dataset.
Parameters
----------
dataset_dir: str
directory of the dataset to be loaded
csv_file: str
name of the csv file in the dataset directory containing meta information (created by mu_map.data.prepare)
split_file: str
csv file defining a split of the dataset in train/validation/test (created by mu_map.data.split)
split_name: str, optional
the name of the split which is loaded
images_dir: str
directory under `dataset_dir` containing the actual images in DICOM format
bed_contours_file: str, optional
json file containing contours around the bed for each mu map (see mu_map.data.remove_bed)
discard_mu_map_slices: bool
            remove defective slices from mu maps (defective slices have to be labeled via mu_map.data.review_mu_map)
align: bool
center align reconstructions and mu maps
scatter_correction: bool
use scatter corrected reconstructions
transform_normalization: Transform
transform used for image normalization which is applied once when the image is loaded
transform_augmentation: Transform
transform used for augmentation which is applied every time `__getitem__` is called
        logger: Logger, optional
            logger for status output; a default logger is created if not provided
        """
self.dir = dataset_dir
self.dir_images = os.path.join(dataset_dir, images_dir)
self.csv_file = os.path.join(dataset_dir, csv_file)
self.split_file = os.path.join(dataset_dir, split_file)
self.transform_normalization = transform_normalization
self.transform_augmentation = transform_augmentation
self.logger = (
logger if logger is not None else get_logger(name=MuMapDataset.__name__)
)
self.bed_contours_file = (
os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
)
self.bed_contours = (
load_contours(self.bed_contours_file) if bed_contours_file else None
)
# read CSV file and from that access DICOM files
self.table = pd.read_csv(self.csv_file)
if split_name:
self.table = split_csv(self.table, self.split_file)[split_name]
self.table[headers.id] = self.table[headers.id].apply(int)
        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
self.header_recon = (
headers.file_recon_nac_sc
if self.scatter_correction
else headers.file_recon_nac_nsc
)
self.reconstructions = {}
self.mu_maps = {}
    def copy(self, split_name: str, **kwargs) -> "MuMapDataset":
"""
Create a copy of the dataset and modify parameters.
Parameters
----------
split_name: str
            the split with which the copy is created
kwargs:
Modify parameters by name.
            Currently, only `bed_contours_file` is supported.
        """
if "bed_contours_file" not in kwargs:
kwargs["bed_contours_file"] = os.path.basename(self.bed_contours_file)
return MuMapDataset(
dataset_dir=self.dir,
csv_file=os.path.basename(self.csv_file),
split_file=os.path.basename(self.split_file),
split_name=split_name,
images_dir=os.path.basename(self.dir_images),
bed_contours_file=kwargs["bed_contours_file"],
discard_mu_map_slices=self.discard_mu_map_slices,
align=self.align,
scatter_correction=self.scatter_correction,
transform_normalization=self.transform_normalization,
transform_augmentation=self.transform_augmentation,
logger=self.logger,
)
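
    # Usage sketch for `copy` (the directory and split names here are
    # assumptions, not part of this module):
    #
    #   train = MuMapDataset("data/", split_name="train")
    #   validation = train.copy("validation")  # same settings, different split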
def load_image(self, _id: int):
"""
Load an image into memory.
This function also performs all of the pre-processing (discard slices, remove bed, alignment ...).
        Afterwards, the reconstruction and mu map are available from the
        local dicts `self.reconstructions` and `self.mu_maps`.
Parameters
----------
_id: int
the id of the image to be loaded
"""
row = self.table[self.table[headers.id] == _id].iloc[0]
_id = row[headers.id]
mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
mu_map = load_dcm_img(mu_map_file, direction=1)
recon_file = os.path.join(self.dir_images, row[self.header_recon])
recon = load_dcm_img(recon_file, direction=1)
if self.discard_mu_map_slices:
mu_map, recon = discard_slices(row, mu_map, recon)
if self.bed_contours:
if _id in self.bed_contours:
bed_contour = self.bed_contours[_id]
mu_map = remove_bed(mu_map, bed_contour)
else:
                self.logger.warning(f"Could not find bed contour for id {_id}")
if self.align:
recon, mu_map = align_images(recon, mu_map)
mu_map = mu_map.astype(np.float32)
mu_map = torch.from_numpy(mu_map)
mu_map = mu_map.unsqueeze(dim=0)
recon = recon.astype(np.float32)
recon = torch.from_numpy(recon)
recon = recon.unsqueeze(dim=0)
recon, mu_map = self.transform_normalization(recon, mu_map)
self.mu_maps[_id] = mu_map
self.reconstructions[_id] = recon
"""
Load all images into memory.
"""
for _id in self.table[headers.id]:
self.load_image(_id)
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get a reconstruction mu map pair by index.
This method retrieves the image id of this index and call `get_item_by_id`.
Parameters
----------
index: int
"""
_id = row[headers.id]
return self.get_item_by_id(_id)
def get_item_by_id(self, _id: int) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get a reconstruction and mu map pair by their id.
        This method loads the images if they are not yet in memory and
        applies the augmentation transform before returning them.
Parameters
----------
        _id: int
            the id of the reconstruction/mu map pair
Returns
-------
Tuple[torch.Tensor, torch.Tensor]
a pair of a reconstruction and the according mu map
"""
if _id not in self.reconstructions:
self.load_image(_id)
recon = self.reconstructions[_id]
mu_map = self.mu_maps[_id]
        recon, mu_map = self.transform_augmentation(recon, mu_map)
        return recon, mu_map
    def __len__(self) -> int:
        """
        Get the number of elements in this dataset.
        """
        return len(self.table)
def main(dataset: MuMapDataset, ids: Optional[List[int]] = None, paused: bool = False):
"""
Display reconstructions and mu maps in a dataset.
Parameters
----------
dataset: MuMapDataset
the dataset of which elements are displayed
ids: list of int, optional
only display these ids
paused: bool
start display in paused mode
"""
    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"  # window name assumed
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)
    space = np.full((1024, 10), 239, np.uint8)

    timeout_paused = 0  # wait indefinitely for a key press
    timeout_running = 1000 // 15  # ~15 FPS playback
    timeout = timeout_paused if paused else timeout_running
    running = 0  # counter used to name saved screenshots
    def to_display_image(image, _slice):
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image
def combine_images(images, slices):
image_1 = to_display_image(images[0], slices[0])
image_2 = to_display_image(images[1], slices[1])
image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))
image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
image_3_1 = image_2.copy()
image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)
space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
return np.hstack((image_1, space, image_3, space, image_2))
for i in range(len(dataset)):
ir = 0
im = 0
row = dataset.table.iloc[i]
_id = row[headers.id]
if ids is not None and _id not in ids:
continue
        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - ID: {_id}", end="\r")
cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        key = cv.waitKey(timeout)

        # The key handling is partially reconstructed: the "n"/"q" bindings and
        # the exact arrow-key semantics are assumptions.
        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)

            key = cv.waitKey(timeout)
            if key == ord("n"):  # continue with the next image pair
                break
            elif key == ord("q"):  # quit the display
                cv.destroyAllWindows()
                return
            elif key == ord("p"):  # toggle between paused and running playback
                timeout = timeout_paused if timeout > 0 else timeout_running
            elif key == 82:  # up arrow key: step one slice back
                ir = ir - 2
                im = im - 2
                continue
            elif key == 84:  # down arrow key: hold the current slices
                ir = ir - 1
                im = im - 1
            elif key == ord("s"):  # save the current frame as a numbered PNG
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1
if __name__ == "__main__":
import argparse
from mu_map.dataset.transform import PadCropTranform
from mu_map.logging import add_logging_args, get_logger_by_args
parser = argparse.ArgumentParser(
description="Visualize the images of a MuMapDataset",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"dataset_dir", type=str, help="the directory from which the dataset is loaded"
)
parser.add_argument(
"--split",
type=str,
choices=["train", "validation", "test"],
help="choose the split of the data for the dataset",
)
parser.add_argument(
"--unaligned",
action="store_true",
help="do not perform center alignment of reconstruction an mu-map slices",
)
parser.add_argument(
"--show_bed",
action="store_true",
help="do not remove the bed contour from the mu map",
)
parser.add_argument(
"--full_mu_map",
action="store_true",
help="do not remove broken slices of the mu map",
)
parser.add_argument(
"--ids",
type=int,
nargs="*",
help="only display certain ids",
)
parser.add_argument(
"--paused",
action="store_true",
help="start in paused mode",
)
parser.add_argument(
"--pad_crop",
type=int,
help="pad crop images to this size",
)
add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
args = parser.parse_args()
align = not args.unaligned
discard_mu_map_slices = not args.full_mu_map
bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
logger = get_logger_by_args(args)
transform_normalization = (
PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
)
dataset = MuMapDataset(
        args.dataset_dir,
        split_name=args.split,
        align=align,
discard_mu_map_slices=discard_mu_map_slices,
bed_contours_file=bed_contours_file,
transform_normalization=transform_normalization,
logger=logger,
)
main(dataset, args.ids, paused=args.paused)
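
# Example invocation (a sketch; the module path is an assumption and depends
# on where this file lives in the package):
#
#   python -m mu_map.dataset <dataset_dir> --split validation --ids 42 43 --paused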