# default.py (8.89 KiB), authored by Tamino Huxohl
import os
from typing import Optional
import cv2 as cv
import pandas as pd
import pydicom
import numpy as np
import torch
from torch.utils.data import Dataset
from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.logging import get_logger
"""
DICOM only allows images to be stored in short integer format.
Therefore, the Siemens scanner software multiplies all values by a scale
factor before storing them so that no precision is lost.
This scale factor can be found in the following private DICOM tag.
"""
DCM_TAG_PIXEL_SCALE_FACTOR = 0x00331038
def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).

    It is assumed that the second image has no more slices than the first.
    The first image is cropped along the first axis so that the centers of
    both images lie on top of each other. If both images have the same
    number of slices, all slices of image_1 are returned.

    :param image_1: the image to be aligned
    :param image_2: the image to which image_1 is aligned
    :return: the central slices of image_1, as many as image_2 has
    :raises AssertionError: if image_1 has fewer slices than image_2
    """
    n_slices_1 = image_1.shape[0]
    n_slices_2 = image_2.shape[0]
    assert n_slices_1 >= n_slices_2, (
        "Alignment is based on the fact that image_1 has at least as many "
        f"slices ({n_slices_1}) as image_2 ({n_slices_2})"
    )

    # number of slices of image_2 before its central slice ...
    left = n_slices_2 // 2
    # ... and from its central slice to its end
    right = n_slices_2 - left

    # central slice of image_1
    center = n_slices_1 // 2

    # select the center and the same amount to the left/right as in image_2
    return image_1[(center - left) : (center + right)]
class MuMapDataset(Dataset):
    """
    Dataset of pairs of a reconstruction and its mu map, both read from DICOM files.

    All images referenced by the CSV table are read, pre-processed and
    normalized once when the dataset is created and are then kept in memory.
    Only the augmentation transform is applied each time an element is accessed.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Optional[Transform] = None,
        transform_augmentation: Optional[Transform] = None,
        logger=None,
    ):
        """
        Create the dataset and pre-load all of its images.

        :param dataset_dir: directory containing the CSV files and the images directory
        :param csv_file: name of the CSV file (inside dataset_dir) listing the images
        :param split_file: name of the file (inside dataset_dir) passed to `split_csv`
        :param split_name: name of the split to select from the table;
                           if not given, the full table is used
        :param images_dir: name of the directory (inside dataset_dir) containing the DICOM files
        :param bed_contours_file: name of the file (inside dataset_dir) from which bed
                                  contours are loaded; if None or empty, no contours are
                                  drawn into the mu maps
        :param discard_mu_map_slices: if True, `discard_slices` is applied to every mu map
        :param align: if True, every reconstruction is center-aligned to its mu map
                      with `align_images`
        :param scatter_correction: if True, reconstructions are read from the column
                                   `headers.file_recon_nac_sc`, otherwise from
                                   `headers.file_recon_nac_nsc`
        :param transform_normalization: transform applied once to every pair while
                                        pre-loading; defaults to a new `Transform()`
        :param transform_augmentation: transform applied to a pair on every access;
                                       defaults to a new `Transform()`
        :param logger: logger used for debug messages; defaults to `get_logger()`
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        # Default transforms are created per dataset instead of once in the
        # signature, so that instances never share a default object.
        self.transform_normalization = (
            transform_normalization
            if transform_normalization is not None
            else Transform()
        )
        self.transform_augmentation = (
            transform_augmentation
            if transform_augmentation is not None
            else Transform()
        )
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table["id"] = self.table["id"].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        self.align = align
        self.scatter_correction = scatter_correction
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        # pre-loaded images, both keyed by the id of their table row
        self.reconstructions = {}
        self.mu_maps = {}
        self.pre_load_images()

    @staticmethod
    def _read_scaled_image(filename: str) -> np.ndarray:
        """Read a DICOM image and divide it by the scale factor stored in its private tag."""
        image = pydicom.dcmread(filename)
        return image.pixel_array / image[DCM_TAG_PIXEL_SCALE_FACTOR].value

    @staticmethod
    def _to_tensor(image: np.ndarray) -> torch.Tensor:
        """Convert an image to a float32 tensor with an added leading dimension."""
        return torch.from_numpy(image.astype(np.float32)).unsqueeze(dim=0)

    def _load_mu_map(self, row) -> np.ndarray:
        """Read the mu map of a table row and apply the configured pre-processing."""
        mu_map = self._read_scaled_image(
            os.path.join(self.dir_images, row[headers.file_mu_map])
        )

        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)

        if self.bed_contours:
            bed_contour = self.bed_contours[row["id"]]
            # fill the area enclosed by the bed contour with 0 in every slice
            for slice_index in range(mu_map.shape[0]):
                mu_map[slice_index] = cv.drawContours(
                    mu_map[slice_index], [bed_contour], -1, 0.0, -1
                )

        return mu_map

    def pre_load_images(self):
        """
        Read, pre-process and normalize all images of the table.

        The results are stored in `self.reconstructions` and `self.mu_maps`
        under the id of the according table row.
        """
        self.logger.debug("Pre-loading images ...")

        # positional row access, the same way rows are accessed in __getitem__
        for row_index in range(len(self.table)):
            row = self.table.iloc[row_index]
            _id = row["id"]

            mu_map = self._load_mu_map(row)

            recon = self._read_scaled_image(
                os.path.join(self.dir_images, row[self.header_recon])
            )
            if self.align:
                recon = align_images(recon, mu_map)

            recon, mu_map = self.transform_normalization(
                self._to_tensor(recon), self._to_tensor(mu_map)
            )

            self.mu_maps[_id] = mu_map
            self.reconstructions[_id] = recon

        self.logger.debug("Pre-loading images done!")

    def __getitem__(self, index: int):
        """
        Get a pair of pre-loaded images with the augmentation transform applied.

        :param index: position of the element in the table
        :return: the reconstruction and the mu map as returned by the augmentation transform
        """
        _id = self.table.iloc[index]["id"]

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon, mu_map = self.transform_augmentation(recon, mu_map)
        return recon, mu_map

    def __len__(self):
        """
        :return: the number of rows in the table
        """
        return len(self.table)
# names exported by `from <module> import *`
__all__ = ["MuMapDataset"]
def main(dataset):
    """
    Display all reconstruction/mu map pairs of a dataset slice by slice.

    Both images of a pair are shown next to each other in an OpenCV window
    and their slices are cycled through automatically.

    Keys:
      * n: go to the next pair of the dataset
      * q: quit the program
      * p: toggle between automatic playback and waiting for a key press
      * right arrow: go to the next slice
      * left arrow: go back one slice
      * s: save the currently displayed image as `<counter>.png` in the
        working directory

    :param dataset: the dataset whose elements are displayed
    """
    import sys

    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    # time in ms to wait for a key press, 0 waits until a key is pressed
    timeout = 100
    # Counter used in the names of saved images. It is shared across all
    # elements so that images saved for a later element do not overwrite
    # those saved for an earlier one.
    running = 0

    def to_display_image(image, _slice):
        """Render a slice of an image as a grayscale picture labeled with its slice number."""
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        # interpolation has to be a keyword argument because the third
        # positional parameter of cv.resize is the destination image
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        """Put the selected slices of two images next to each other, separated by a gap."""
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        space = np.full((image_1.shape[0], 10), 239, np.uint8)
        return np.hstack((image_1, space, image_2))

    for i in range(len(dataset)):
        # currently displayed slice of the reconstruction/mu map
        ir = 0
        im = 0

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()
        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r")

        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        # the pressed key is ignored here, the call gives the window time to render
        cv.waitKey(timeout)

        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)
            key = cv.waitKey(timeout)

            if key == ord("n"):
                break
            elif key == ord("q"):
                sys.exit(0)
            elif key == ord("p"):
                timeout = 0 if timeout > 0 else 100
            elif key == 83:  # right arrow key
                continue
            elif key == 81:  # left arrow key
                # go back two slices because the loop advances by one again
                ir = max(ir - 2, 0)
                im = max(im - 2, 0)
            elif key == ord("s"):
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1
if __name__ == "__main__":
    # Command line entry point: build a dataset from the arguments and show it.
    import argparse

    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Visualize the images of a MuMapDataset",
    )
    parser.add_argument(
        "dataset_dir",
        help="the directory from which the dataset is loaded",
        type=str,
    )
    parser.add_argument(
        "--split",
        help="choose the split of the data for the dataset",
        choices=["train", "validation", "test"],
        type=str,
    )
    # switches which each disable one pre-processing step of the dataset
    switches = (
        (
            "--unaligned",
            "do not perform center alignment of reconstruction an mu-map slices",
        ),
        ("--show_bed", "do not remove the bed contour from the mu map"),
        ("--full_mu_map", "do not remove broken slices of the mu map"),
    )
    for flag, description in switches:
        parser.add_argument(flag, action="store_true", help=description)
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})

    args = parser.parse_args()

    dataset = MuMapDataset(
        args.dataset_dir,
        split_name=args.split,
        bed_contours_file=None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices=not args.full_mu_map,
        align=not args.unaligned,
        logger=get_logger_by_args(args),
    )
    main(dataset)