Skip to content
Snippets Groups Projects
Commit 0998d095 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

use new functions in mu map dataset and implement lazy loading

parent 52d01622
No related branches found
No related tags found
No related merge requests found
import os import os
from typing import Optional from typing import Optional, Tuple
import cv2 as cv import cv2 as cv
import pandas as pd import pandas as pd
...@@ -9,48 +9,18 @@ import torch ...@@ -9,48 +9,18 @@ import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from mu_map.data.prepare import headers from mu_map.data.prepare import headers
from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours from mu_map.data.remove_bed import (
DEFAULT_BED_CONTOURS_FILENAME,
load_contours,
remove_bed,
)
from mu_map.data.review_mu_map import discard_slices from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images, load_dcm_img
from mu_map.logging import get_logger from mu_map.logging import get_logger
"""
Since DICOM images only allow images stored in short integer format,
the Siemens scanner software multiplies values by a factor before storing
so that no precision is lost.
The scale can be found in this private DICOM tag.
"""
DCM_TAG_PIXEL_SCALE_FACTOR = 0x00331038
def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
"""
Align one image to another on the first axis (z-axis).
It is assumed that the second image has less slices than the first.
Then, the first image is shortened in a way that the centers of both images lie on top of each other.
:param image_1: the image to be aligned
:param image_2: the image to which image_1 is aligned
:return: the aligned image_1
"""
assert (
image_1.shape[0] > image_2.shape[0]
), f"Alignment is based on the fact that image 1 has more slices {image_1.shape[0]} than image_2 {image_.shape[0]}"
# central slice of image 2
c_2 = image_2.shape[0] // 2
# image to the left and right of the center
left = c_2
right = image_2.shape[0] - c_2
# central slice of image 1
c_1 = image_1.shape[0] // 2
# select center and same amount to the left/right as image_2
return image_1[(c_1 - left) : (c_1 + right)]
class MuMapDataset(Dataset): class MuMapDataset(Dataset):
def __init__( def __init__(
self, self,
...@@ -102,48 +72,39 @@ class MuMapDataset(Dataset): ...@@ -102,48 +72,39 @@ class MuMapDataset(Dataset):
self.reconstructions = {} self.reconstructions = {}
self.mu_maps = {} self.mu_maps = {}
self.pre_load_images()
def load_image(self, _id: int):
def pre_load_images(self): row = self.table[self.table[headers.id] == _id].iloc[0]
self.logger.debug("Pre-loading images ...") _id = row[headers.id]
for i in range(len(self.table)):
row = self.table.iloc[i] mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
_id = row[headers.id] mu_map = load_dcm_img(mu_map_file)
if self.discard_mu_map_slices:
mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map]) mu_map = discard_slices(row, mu_map)
mu_map = pydicom.dcmread(mu_map_file) if self.bed_contours:
mu_map = mu_map.pixel_array / mu_map[DCM_TAG_PIXEL_SCALE_FACTOR].value if _id in self.bed_contours:
if self.discard_mu_map_slices: bed_contour = self.bed_contours[_id]
mu_map = discard_slices(row, mu_map) mu_map = remove_bed(mu_map, bed_contour)
if self.bed_contours: else:
if _id in self.bed_contours: logger.warning(f"Could not find bed contour for id {_id}")
bed_contour = self.bed_contours[_id]
for i in range(mu_map.shape[0]): recon_file = os.path.join(self.dir_images, row[self.header_recon])
mu_map[i] = cv.drawContours( recon = load_dcm_img(recon_file)
mu_map[i], [bed_contour], -1, 0.0, -1 if self.align:
) recon, mu_map = align_images(recon, mu_map)
else:
logger.warning(f"Could not find bed contour for id {_id}") mu_map = mu_map.astype(np.float32)
mu_map = torch.from_numpy(mu_map)
recon_file = os.path.join(self.dir_images, row[self.header_recon]) mu_map = mu_map.unsqueeze(dim=0)
recon = pydicom.dcmread(recon_file)
recon = recon.pixel_array / recon[DCM_TAG_PIXEL_SCALE_FACTOR].value recon = recon.astype(np.float32)
if self.align: recon = torch.from_numpy(recon)
recon = align_images(recon, mu_map) recon = recon.unsqueeze(dim=0)
mu_map = mu_map.astype(np.float32) recon, mu_map = self.transform_normalization(recon, mu_map)
mu_map = torch.from_numpy(mu_map)
mu_map = mu_map.unsqueeze(dim=0) self.mu_maps[_id] = mu_map
self.reconstructions[_id] = recon
recon = recon.astype(np.float32)
recon = torch.from_numpy(recon)
recon = recon.unsqueeze(dim=0)
recon, mu_map = self.transform_normalization(recon, mu_map)
self.mu_maps[_id] = mu_map
self.reconstructions[_id] = recon
self.logger.debug("Pre-loading images done!")
def __getitem__(self, index: int): def __getitem__(self, index: int):
row = self.table.iloc[index] row = self.table.iloc[index]
...@@ -151,6 +112,9 @@ class MuMapDataset(Dataset): ...@@ -151,6 +112,9 @@ class MuMapDataset(Dataset):
return self.getitem_by_id(_id) return self.getitem_by_id(_id)
def getitem_by_id(self, _id: int): def getitem_by_id(self, _id: int):
if _id not in self.reconstructions:
self.load_image(_id)
recon = self.reconstructions[_id] recon = self.reconstructions[_id]
mu_map = self.mu_maps[_id] mu_map = self.mu_maps[_id]
...@@ -165,7 +129,7 @@ class MuMapDataset(Dataset): ...@@ -165,7 +129,7 @@ class MuMapDataset(Dataset):
__all__ = [MuMapDataset.__name__] __all__ = [MuMapDataset.__name__]
def main(dataset): def main(dataset, ids, paused=False):
from mu_map.util import to_grayscale, COLOR_WHITE from mu_map.util import to_grayscale, COLOR_WHITE
wname = "Dataset" wname = "Dataset"
...@@ -173,7 +137,10 @@ def main(dataset): ...@@ -173,7 +137,10 @@ def main(dataset):
cv.resizeWindow(wname, 1600, 900) cv.resizeWindow(wname, 1600, 900)
space = np.full((1024, 10), 239, np.uint8) space = np.full((1024, 10), 239, np.uint8)
timeout = 100 TIMEOUT_PAUSED = 0
TIMEOUT_RUNNING = 1000 // 15
timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING
def to_display_image(image, _slice): def to_display_image(image, _slice):
_image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max()) _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
...@@ -187,8 +154,16 @@ def main(dataset): ...@@ -187,8 +154,16 @@ def main(dataset):
def combine_images(images, slices): def combine_images(images, slices):
image_1 = to_display_image(images[0], slices[0]) image_1 = to_display_image(images[0], slices[0])
image_2 = to_display_image(images[1], slices[1]) image_2 = to_display_image(images[1], slices[1])
space = np.full((image_1.shape[0], 10), 239, np.uint8)
return np.hstack((image_1, space, image_2)) image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))
image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
image_3_1 = image_2.copy()
image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)
space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
return np.hstack((image_1, space, image_3, space, image_2))
for i in range(len(dataset)): for i in range(len(dataset)):
ir = 0 ir = 0
...@@ -197,6 +172,9 @@ def main(dataset): ...@@ -197,6 +172,9 @@ def main(dataset):
row = dataset.table.iloc[i] row = dataset.table.iloc[i]
_id = row[headers.id] _id = row[headers.id]
if ids is not None and _id not in ids:
continue
recon, mu_map = dataset[i] recon, mu_map = dataset[i]
recon = recon.squeeze().numpy() recon = recon.squeeze().numpy()
mu_map = mu_map.squeeze().numpy() mu_map = mu_map.squeeze().numpy()
...@@ -219,11 +197,18 @@ def main(dataset): ...@@ -219,11 +197,18 @@ def main(dataset):
elif key == ord("q"): elif key == ord("q"):
exit(0) exit(0)
elif key == ord("p"): elif key == ord("p"):
timeout = 0 if timeout > 0 else 100 timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
elif key == 82: # up arrow key
ir = ir - 1
continue
elif key == 83: # right arrow key elif key == 83: # right arrow key
im = im - 1
continue continue
elif key == 81: # left arrow key elif key == 81: # left arrow key
im = im - 1
ir = max(ir - 2, 0) ir = max(ir - 2, 0)
elif key == 84: # down arrow key
ir = ir - 1
im = max(im - 2, 0) im = max(im - 2, 0)
elif key == ord("s"): elif key == ord("s"):
cv.imwrite(f"{running:03d}.png", to_show) cv.imwrite(f"{running:03d}.png", to_show)
...@@ -233,6 +218,7 @@ def main(dataset): ...@@ -233,6 +218,7 @@ def main(dataset):
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
from mu_map.dataset.transform import PadCropTranform
from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.logging import add_logging_args, get_logger_by_args
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -263,6 +249,22 @@ if __name__ == "__main__": ...@@ -263,6 +249,22 @@ if __name__ == "__main__":
action="store_true", action="store_true",
help="do not remove broken slices of the mu map", help="do not remove broken slices of the mu map",
) )
parser.add_argument(
"--ids",
type=int,
nargs="*",
help="only display certain ids",
)
parser.add_argument(
"--paused",
action="store_true",
help="start in paused mode",
)
parser.add_argument(
"--pad_crop",
type=int,
help="pad crop images to this size",
)
add_logging_args(parser, defaults={"--loglevel": "DEBUG"}) add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
args = parser.parse_args() args = parser.parse_args()
...@@ -271,12 +273,17 @@ if __name__ == "__main__": ...@@ -271,12 +273,17 @@ if __name__ == "__main__":
bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
logger = get_logger_by_args(args) logger = get_logger_by_args(args)
transform_normalization = (
PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
)
dataset = MuMapDataset( dataset = MuMapDataset(
args.dataset_dir, args.dataset_dir,
align=align, align=align,
discard_mu_map_slices=discard_mu_map_slices, discard_mu_map_slices=discard_mu_map_slices,
bed_contours_file=bed_contours_file, bed_contours_file=bed_contours_file,
split_name=args.split, split_name=args.split,
transform_normalization=transform_normalization,
logger=logger, logger=logger,
) )
main(dataset) main(dataset, args.ids, paused=args.paused)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment