Skip to content
Snippets Groups Projects
Commit 0998d095 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

use new functions in mu map dataset and implement lazy loading

parent 52d01622
No related branches found
No related tags found
No related merge requests found
import os
from typing import Optional
from typing import Optional, Tuple
import cv2 as cv
import pandas as pd
......@@ -9,48 +9,18 @@ import torch
from torch.utils.data import Dataset
from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
from mu_map.data.remove_bed import (
DEFAULT_BED_CONTOURS_FILENAME,
load_contours,
remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images, load_dcm_img
from mu_map.logging import get_logger
"""
Since DICOM images only allow images stored in short integer format,
the Siemens scanner software multiplies values by a factor before storing
so that no precision is lost.
The scale can be found in this private DICOM tag.
"""
# Combined (group, element) tag number: group 0x0033, element 0x1038.
# In this file the tag is looked up on a dataset read with pydicom and
# `pixel_array` is divided by its `.value` to undo the scaling.
DCM_TAG_PIXEL_SCALE_FACTOR = 0x00331038
def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
    """
    Align one image to another on the first axis (z-axis).

    It is assumed that the second image has fewer slices than the first.
    The first image is shortened so that the central slices of both images
    lie on top of each other.

    :param image_1: the image to be aligned
    :param image_2: the image to which image_1 is aligned
    :return: the aligned image_1, with as many slices as image_2
    :raises AssertionError: if image_1 does not have more slices than image_2
    """
    n_slices_1 = image_1.shape[0]
    n_slices_2 = image_2.shape[0]
    # The message previously referenced an undefined name, which turned a
    # failed assertion into a NameError instead of this explanation.
    assert (
        n_slices_1 > n_slices_2
    ), f"Alignment is based on the fact that image 1 has more slices {n_slices_1} than image_2 {n_slices_2}"

    # central slice of image 2
    c_2 = n_slices_2 // 2
    # number of slices to the left and right of that center
    left = c_2
    right = n_slices_2 - c_2

    # central slice of image 1
    c_1 = n_slices_1 // 2

    # select the center and the same amount to the left/right as in image_2
    return image_1[(c_1 - left) : (c_1 + right)]
class MuMapDataset(Dataset):
def __init__(
self,
......@@ -102,48 +72,39 @@ class MuMapDataset(Dataset):
self.reconstructions = {}
self.mu_maps = {}
self.pre_load_images()
def pre_load_images(self):
    """
    Load and pre-process the images of every row in the table.

    For each row, the mu map and the reconstruction are read from DICOM
    files, rescaled by the pixel scale factor, optionally cleaned
    (slice discarding, bed removal) and aligned, converted to float32
    tensors with an added leading dimension, passed through
    `self.transform_normalization` and stored in `self.mu_maps` and
    `self.reconstructions` under the row's id.
    """
    self.logger.debug("Pre-loading images ...")
    for i in range(len(self.table)):
        row = self.table.iloc[i]
        _id = row[headers.id]

        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
        mu_map = pydicom.dcmread(mu_map_file)
        # undo the scaling applied when the image was stored
        mu_map = mu_map.pixel_array / mu_map[DCM_TAG_PIXEL_SCALE_FACTOR].value
        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)
        if self.bed_contours:
            if _id in self.bed_contours:
                bed_contour = self.bed_contours[_id]
                # A dedicated index is used here so that the outer row
                # index `i` is not overwritten.
                for i_slice in range(mu_map.shape[0]):
                    # thickness -1 fills the contour; the fill value is 0.0
                    mu_map[i_slice] = cv.drawContours(
                        mu_map[i_slice], [bed_contour], -1, 0.0, -1
                    )
            else:
                # Use the instance logger: a module-level `logger` only
                # exists when this file is executed as a script.
                self.logger.warning(f"Could not find bed contour for id {_id}")

        recon_file = os.path.join(self.dir_images, row[self.header_recon])
        recon = pydicom.dcmread(recon_file)
        recon = recon.pixel_array / recon[DCM_TAG_PIXEL_SCALE_FACTOR].value
        if self.align:
            recon = align_images(recon, mu_map)

        mu_map = mu_map.astype(np.float32)
        mu_map = torch.from_numpy(mu_map)
        mu_map = mu_map.unsqueeze(dim=0)

        recon = recon.astype(np.float32)
        recon = torch.from_numpy(recon)
        recon = recon.unsqueeze(dim=0)

        recon, mu_map = self.transform_normalization(recon, mu_map)

        self.mu_maps[_id] = mu_map
        self.reconstructions[_id] = recon
    self.logger.debug("Pre-loading images done!")
def load_image(self, _id: int) -> None:
    """
    Load and pre-process the images of a single id and cache them.

    The mu map and the reconstruction are read with `load_dcm_img`,
    optionally cleaned (slice discarding, bed removal) and aligned,
    converted to float32 tensors with an added leading dimension, passed
    through `self.transform_normalization` and stored in `self.mu_maps`
    and `self.reconstructions` under the id.

    :param _id: the id of the table row whose images are loaded
    :raises IndexError: if no row of the table has this id
    """
    row = self.table[self.table[headers.id] == _id].iloc[0]
    # Kept from the original: the cache key is the id as stored in the row.
    _id = row[headers.id]

    mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
    mu_map = load_dcm_img(mu_map_file)
    if self.discard_mu_map_slices:
        mu_map = discard_slices(row, mu_map)
    if self.bed_contours:
        if _id in self.bed_contours:
            bed_contour = self.bed_contours[_id]
            mu_map = remove_bed(mu_map, bed_contour)
        else:
            # Use the instance logger: a module-level `logger` only exists
            # when this file is executed as a script.
            self.logger.warning(f"Could not find bed contour for id {_id}")

    recon_file = os.path.join(self.dir_images, row[self.header_recon])
    recon = load_dcm_img(recon_file)
    if self.align:
        # NOTE(review): relies on `align_images` returning both images as a
        # pair - confirm against mu_map.dataset.util.
        recon, mu_map = align_images(recon, mu_map)

    mu_map = mu_map.astype(np.float32)
    mu_map = torch.from_numpy(mu_map)
    mu_map = mu_map.unsqueeze(dim=0)

    recon = recon.astype(np.float32)
    recon = torch.from_numpy(recon)
    recon = recon.unsqueeze(dim=0)

    recon, mu_map = self.transform_normalization(recon, mu_map)

    self.mu_maps[_id] = mu_map
    self.reconstructions[_id] = recon
def __getitem__(self, index: int):
row = self.table.iloc[index]
......@@ -151,6 +112,9 @@ class MuMapDataset(Dataset):
return self.getitem_by_id(_id)
def getitem_by_id(self, _id: int):
if _id not in self.reconstructions:
self.load_image(_id)
recon = self.reconstructions[_id]
mu_map = self.mu_maps[_id]
......@@ -165,7 +129,7 @@ class MuMapDataset(Dataset):
__all__ = [MuMapDataset.__name__]
def main(dataset):
def main(dataset, ids, paused=False):
from mu_map.util import to_grayscale, COLOR_WHITE
wname = "Dataset"
......@@ -173,7 +137,10 @@ def main(dataset):
cv.resizeWindow(wname, 1600, 900)
space = np.full((1024, 10), 239, np.uint8)
timeout = 100
TIMEOUT_PAUSED = 0
TIMEOUT_RUNNING = 1000 // 15
timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING
def to_display_image(image, _slice):
_image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
......@@ -187,8 +154,16 @@ def main(dataset):
def combine_images(images, slices):
image_1 = to_display_image(images[0], slices[0])
image_2 = to_display_image(images[1], slices[1])
space = np.full((image_1.shape[0], 10), 239, np.uint8)
return np.hstack((image_1, space, image_2))
image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))
image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
image_3_1 = image_2.copy()
image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)
space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
return np.hstack((image_1, space, image_3, space, image_2))
for i in range(len(dataset)):
ir = 0
......@@ -197,6 +172,9 @@ def main(dataset):
row = dataset.table.iloc[i]
_id = row[headers.id]
if ids is not None and _id not in ids:
continue
recon, mu_map = dataset[i]
recon = recon.squeeze().numpy()
mu_map = mu_map.squeeze().numpy()
......@@ -219,11 +197,18 @@ def main(dataset):
elif key == ord("q"):
exit(0)
elif key == ord("p"):
timeout = 0 if timeout > 0 else 100
timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
elif key == 82: # up arrow key
ir = ir - 1
continue
elif key == 83: # right arrow key
im = im - 1
continue
elif key == 81: # left arrow key
im = im - 1
ir = max(ir - 2, 0)
elif key == 84: # down arrow key
ir = ir - 1
im = max(im - 2, 0)
elif key == ord("s"):
cv.imwrite(f"{running:03d}.png", to_show)
......@@ -233,6 +218,7 @@ def main(dataset):
if __name__ == "__main__":
import argparse
from mu_map.dataset.transform import PadCropTranform
from mu_map.logging import add_logging_args, get_logger_by_args
parser = argparse.ArgumentParser(
......@@ -263,6 +249,22 @@ if __name__ == "__main__":
action="store_true",
help="do not remove broken slices of the mu map",
)
parser.add_argument(
"--ids",
type=int,
nargs="*",
help="only display certain ids",
)
parser.add_argument(
"--paused",
action="store_true",
help="start in paused mode",
)
parser.add_argument(
"--pad_crop",
type=int,
help="pad crop images to this size",
)
add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
args = parser.parse_args()
......@@ -271,12 +273,17 @@ if __name__ == "__main__":
bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
logger = get_logger_by_args(args)
transform_normalization = (
PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
)
dataset = MuMapDataset(
args.dataset_dir,
align=align,
discard_mu_map_slices=discard_mu_map_slices,
bed_contours_file=bed_contours_file,
split_name=args.split,
transform_normalization=transform_normalization,
logger=logger,
)
main(dataset)
main(dataset, args.ids, paused=args.paused)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment