import os
from typing import Optional

import cv2 as cv
import numpy as np
import pandas as pd
import pydicom
import torch
from torch.utils.data import Dataset

from mu_map.data.prepare import headers  # CSV column-name constants (import path assumed)
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.logging import get_logger
"""
Since DICOM images only allow images stored in short integer format,
the Siemens scanner software multiplies values by a factor before storing
so that no precision is lost.
The scale can be found in this private DICOM tag.
"""
DCM_TAG_PIXEL_SCALE_FACTOR = 0x00331038
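
# Minimal sketch of how this tag is used when loading a file. The file name
# is a placeholder; the tag access and division mirror pre_load_images below.
def _example_read_scaled_dicom(dcm_file: str = "example.dcm") -> np.ndarray:
    dcm = pydicom.dcmread(dcm_file)
    # divide by the stored factor to recover the original value range
    return dcm.pixel_array / dcm[DCM_TAG_PIXEL_SCALE_FACTOR].value
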
def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another along the first axis (z-axis).
    It is assumed that the second image has fewer slices than the first.
    The first image is then shortened so that the centers of both images
    lie on top of each other.

    :param image_1: the image to be aligned
    :param image_2: the image to which image_1 is aligned
    :return: the aligned image_1
    """
    assert (
        image_1.shape[0] > image_2.shape[0]
    ), f"Alignment assumes that image_1 has more slices ({image_1.shape[0]}) than image_2 ({image_2.shape[0]})"

    # central slice of image 2
    c_2 = image_2.shape[0] // 2
    # number of slices to the left and right of the center
    left = c_2
    right = image_2.shape[0] - c_2

    # central slice of image 1
    c_1 = image_1.shape[0] // 2
    # select the center and the same number of slices to the left/right as image_2
    return image_1[(c_1 - left) : (c_1 + right)]
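
# Usage sketch with hypothetical shapes: aligning a 128-slice volume to a
# 100-slice volume keeps the central 100 slices of the first volume.
def _example_align_images():
    recon = np.zeros((128, 64, 64), np.float32)
    mu_map = np.zeros((100, 64, 64), np.float32)
    assert align_images(recon, mu_map).shape == (100, 64, 64)
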
class MuMapDataset(Dataset):
    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: Optional[str] = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        align: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Transform = Transform(),
        transform_augmentation: Transform = Transform(),
        logger=None,
    ):
        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        self.transform_normalization = transform_normalization
        self.transform_augmentation = transform_augmentation
        self.logger = logger if logger is not None else get_logger()

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read the CSV file and from that access the DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table[headers.id] = self.table[headers.id].apply(int)

        self.align = align
        self.discard_mu_map_slices = discard_mu_map_slices
        self.scatter_correction = scatter_correction
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        self.reconstructions = {}
        self.mu_maps = {}
        self.pre_load_images()
    def pre_load_images(self):
        self.logger.debug("Pre-loading images ...")
        for _, row in self.table.iterrows():
            _id = row[headers.id]

            mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
            mu_map = pydicom.dcmread(mu_map_file)
            mu_map = mu_map.pixel_array / mu_map[DCM_TAG_PIXEL_SCALE_FACTOR].value
            if self.discard_mu_map_slices:
                mu_map = discard_slices(row, mu_map)
            if self.bed_contours:
                if _id in self.bed_contours:
                    bed_contour = self.bed_contours[_id]
                    # zero out the bed region on every slice
                    for i in range(mu_map.shape[0]):
                        mu_map[i] = cv.drawContours(
                            mu_map[i], [bed_contour], -1, 0.0, -1
                        )
                else:
                    self.logger.warning(f"Could not find bed contour for id {_id}")

            recon_file = os.path.join(self.dir_images, row[self.header_recon])
            recon = pydicom.dcmread(recon_file)
            recon = recon.pixel_array / recon[DCM_TAG_PIXEL_SCALE_FACTOR].value
            if self.align:
                recon = align_images(recon, mu_map)

            mu_map = mu_map.astype(np.float32)
            mu_map = torch.from_numpy(mu_map)
            mu_map = mu_map.unsqueeze(dim=0)

            recon = recon.astype(np.float32)
            recon = torch.from_numpy(recon)
            recon = recon.unsqueeze(dim=0)

            recon, mu_map = self.transform_normalization(recon, mu_map)

            self.mu_maps[_id] = mu_map
            self.reconstructions[_id] = recon
        self.logger.debug("Pre-loading images done!")
    def __getitem__(self, index: int):
        row = self.table.iloc[index]
        _id = row[headers.id]
        return self.getitem_by_id(_id)

    def getitem_by_id(self, _id: int):
        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]

        recon, mu_map = self.transform_augmentation(recon, mu_map)
        return recon, mu_map

    def __len__(self):
        return len(self.table)
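
# Usage sketch (hypothetical dataset directory): load the training split and
# fetch the first normalized, augmented reconstruction/mu-map pair.
def _example_dataset_usage(dataset_dir: str = "data"):
    dataset = MuMapDataset(dataset_dir, split_name="train")
    recon, mu_map = dataset[0]  # tensors of shape (1, z, y, x)
    print(recon.shape, mu_map.shape)
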
def main(dataset):
    """
    Display the reconstruction and mu-map pairs of a dataset slice by slice.
    Keys: n = next scan, q = quit, p = pause/resume, s = save current frame.
    """
    from mu_map.util import to_grayscale, COLOR_WHITE

    wname = "Dataset"
    cv.namedWindow(wname, cv.WINDOW_NORMAL)
    cv.resizeWindow(wname, 1600, 900)

    timeout = 100  # time in ms each frame is shown; 0 means paused
    running = 0  # counter used to name saved screenshots

    def to_display_image(image, _slice):
        _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
        _image = cv.resize(_image, (1024, 1024), interpolation=cv.INTER_AREA)
        _text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
        _image = cv.putText(
            _image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
        )
        return _image

    def combine_images(images, slices):
        image_1 = to_display_image(images[0], slices[0])
        image_2 = to_display_image(images[1], slices[1])
        space = np.full((image_1.shape[0], 10), 239, np.uint8)
        return np.hstack((image_1, space, image_2))

    for i in range(len(dataset)):
        ir = 0
        im = 0

        row = dataset.table.iloc[i]
        _id = row[headers.id]

        recon, mu_map = dataset[i]
        recon = recon.squeeze().numpy()
        mu_map = mu_map.squeeze().numpy()

        print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - ID: {_id}", end="\r")

        cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
        key = cv.waitKey(timeout)

        while True:
            ir = (ir + 1) % recon.shape[0]
            im = (im + 1) % mu_map.shape[0]

            to_show = combine_images((recon, mu_map), (ir, im))
            cv.imshow(wname, to_show)

            key = cv.waitKey(timeout)
            if key == ord("n"):  # show the next scan
                break
            elif key == ord("q"):  # quit
                exit(0)
            elif key == ord("p"):
                # toggle between pause and playback; step back so the
                # currently displayed slices are shown again
                timeout = 0 if timeout > 0 else 100
                ir = max(ir - 2, 0)
                im = max(im - 2, 0)
            elif key == ord("s"):  # save the current frame as a PNG
                cv.imwrite(f"{running:03d}.png", to_show)
                running += 1
if __name__ == "__main__":
    import argparse

    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction and mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)

    dataset = MuMapDataset(
        args.dataset_dir,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        split_name=args.split,
        logger=logger,
    )

    main(dataset)