Newer
Older
from typing import Optional, Tuple
import torch
from mu_map.data.remove_bed import (
DEFAULT_BED_CONTOURS_FILENAME,
load_contours,
remove_bed,
)
from mu_map.data.review_mu_map import discard_slices
from mu_map.data.split import split_csv
from mu_map.dataset.transform import Transform
from mu_map.dataset.util import align_images, load_dcm_img
class MuMapDataset(Dataset):
    """
    Dataset of paired reconstructions and mu maps listed in a CSV table.

    Every row of the table (identified by column ``headers.id``) names a mu map
    file and reconstruction files relative to ``images_dir``; both are read with
    ``load_dcm_img``. Images are loaded lazily on first access (or all at once
    via :meth:`pre_load_images`) and cached *after* the normalization
    transform; the augmentation transform is applied anew on every access.
    """

    def __init__(
        self,
        dataset_dir: str,
        csv_file: str = "meta.csv",
        split_file: str = "split.csv",
        split_name: str = None,
        images_dir: str = "images",
        bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
        discard_mu_map_slices: bool = True,
        scatter_correction: bool = False,
        transform_normalization: Transform = Transform(),
        transform_augmentation: Transform = Transform(),
        logger=None,
        align: bool = True,
    ):
        """
        :param dataset_dir: root directory of the dataset; the file/directory
            names below are joined onto it
        :param csv_file: CSV table describing the samples
        :param split_file: split definition passed to ``split_csv``
        :param split_name: if given, only the rows of this split are kept
        :param images_dir: directory containing the image files referenced by
            the table
        :param bed_contours_file: contours removed from the mu maps via
            ``remove_bed``; ``None`` (or empty) disables bed removal
        :param discard_mu_map_slices: apply ``discard_slices`` to every mu map
        :param scatter_correction: read reconstructions from column
            ``headers.file_recon_nac_sc`` instead of ``headers.file_recon_nac_nsc``
        :param transform_normalization: applied once per sample, before caching
        :param transform_augmentation: applied to the cached images on every
            access
        :param logger: logger to use; a default one is created if ``None``
        :param align: pass reconstruction and mu map through ``align_images``
            (appended last so existing positional calls keep working)
        """
        super().__init__()

        self.dir = dataset_dir
        self.dir_images = os.path.join(dataset_dir, images_dir)
        self.csv_file = os.path.join(dataset_dir, csv_file)
        self.split_file = os.path.join(dataset_dir, split_file)

        self.transform_normalization = transform_normalization
        self.transform_augmentation = transform_augmentation
        self.logger = (
            logger if logger is not None else get_logger(name=MuMapDataset.__name__)
        )

        self.bed_contours_file = (
            os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
        )
        self.bed_contours = (
            load_contours(self.bed_contours_file) if bed_contours_file else None
        )

        # read CSV file and from that access DICOM files
        self.table = pd.read_csv(self.csv_file)
        if split_name:
            self.table = split_csv(self.table, self.split_file)[split_name]
        self.table[headers.id] = self.table[headers.id].apply(int)

        self.discard_mu_map_slices = discard_mu_map_slices
        # read by load_image; previously never stored, so every load failed
        self.align = align
        self.scatter_correction = scatter_correction
        self.header_recon = (
            headers.file_recon_nac_sc
            if self.scatter_correction
            else headers.file_recon_nac_nsc
        )

        # caches filled by load_image, keyed by sample id
        self.reconstructions = {}
        self.mu_maps = {}

    # NOTE(review): part of this class body (original lines 75-107) was lost in
    # the reviewed copy; restore any further methods it defined from version
    # control.

    def load_image(self, _id: int):
        """
        Load, preprocess and cache the reconstruction and mu map of one sample.

        The results are stored in ``self.reconstructions`` and ``self.mu_maps``
        under the sample id; nothing is returned.

        :param _id: value of the ``headers.id`` column identifying the sample
        """
        row = self.table[self.table[headers.id] == _id].iloc[0]
        _id = row[headers.id]

        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
        mu_map = load_dcm_img(mu_map_file)
        if self.discard_mu_map_slices:
            mu_map = discard_slices(row, mu_map)
        if self.bed_contours:
            if _id in self.bed_contours:
                bed_contour = self.bed_contours[_id]
                mu_map = remove_bed(mu_map, bed_contour)
            else:
                # use the instance logger: a bare `logger` only exists as a
                # module global when this file is executed as a script
                self.logger.warning(f"Could not find bed contour for id {_id}")

        recon_file = os.path.join(self.dir_images, row[self.header_recon])
        recon = load_dcm_img(recon_file)
        if self.align:
            recon, mu_map = align_images(recon, mu_map)

        # convert to float32 tensors with a leading channel dimension
        mu_map = torch.from_numpy(mu_map.astype(np.float32)).unsqueeze(dim=0)
        recon = torch.from_numpy(recon.astype(np.float32)).unsqueeze(dim=0)

        recon, mu_map = self.transform_normalization(recon, mu_map)
        self.mu_maps[_id] = mu_map
        self.reconstructions[_id] = recon

    def pre_load_images(self):
        """Load and cache every sample of the table up front."""
        for _id in self.table[headers.id]:
            self.load_image(_id)

    def __len__(self) -> int:
        """Return the number of samples (rows of the possibly split-filtered table)."""
        return len(self.table)

    def __getitem__(self, index: int):
        """Return the (reconstruction, mu map) pair at table position ``index``."""
        row = self.table.iloc[index]
        _id = row[headers.id]
        return self.getitem_by_id(_id)

    def getitem_by_id(self, _id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the (reconstruction, mu map) pair of the sample with id ``_id``.

        The sample is loaded and cached on first access; the augmentation
        transform is applied to the cached pair on every call.
        """
        if _id not in self.reconstructions:
            self.load_image(_id)

        recon = self.reconstructions[_id]
        mu_map = self.mu_maps[_id]
        # previously the augmented pair was computed and discarded (returned None)
        recon, mu_map = self.transform_augmentation(recon, mu_map)
        return recon, mu_map
def main(dataset, ids, paused=False):
# Interactive OpenCV viewer for a MuMapDataset: for every sample (only those
# whose id is in `ids`, if `ids` is not None) it shows a reconstruction slice,
# an overlay of both images and a mu map slice side by side. The cv.waitKey
# delay is TIMEOUT_RUNNING (1000 // 15 ms) while running and TIMEOUT_PAUSED
# (0, i.e. block until a key is pressed) while paused.
#
# NOTE(review): this function cannot run as shown; lines appear to be missing
# from this copy:
#   - `wname` (window name) is used but never assigned;
#   - `to_display_image` builds `_image` but has no `return`, so it returns
#     None and `combine_images` fails on `image_1.repeat`;
#   - `recon` / `mu_map` are read before assignment in the sample loop
#     (presumably a `recon, mu_map = dataset[i]` line is missing - confirm);
#   - inside `while True` the key dispatch starts with `elif` without an `if`,
#     `key` is never re-read, and no `break` is visible;
#   - `running` (screenshot counter) is never initialised.
from mu_map.util import to_grayscale, COLOR_WHITE
cv.namedWindow(wname, cv.WINDOW_NORMAL)
cv.resizeWindow(wname, 1600, 900)
# NOTE(review): unused - combine_images creates its own 3-channel `space`
space = np.full((1024, 10), 239, np.uint8)
TIMEOUT_PAUSED = 0
TIMEOUT_RUNNING = 1000 // 15
timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING
# Render slice `_slice` of a volume as a 1024x1024 grayscale image, using the
# whole volume's min/max for to_grayscale, labelled "<slice>/<num_slices>".
def to_display_image(image, _slice):
_image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
_image = cv.resize(_image, (1024, 1024), cv.INTER_AREA)
_text = f"{str(_slice):>{len(str(image.shape[0]))}}/{str(image.shape[0])}"
_image = cv.putText(
_image, _text, (0, 30), cv.FONT_HERSHEY_SIMPLEX, 1, COLOR_WHITE, 3
)
# Build the side-by-side view: images[0] slice | overlay | images[1] slice,
# separated by 10-px light-grey bars. The overlay blends images[1] (weight 0.8)
# with an INFERNO-coloured images[0] (weight 0.4). Callers pass
# (recon, mu_map).
def combine_images(images, slices):
image_1 = to_display_image(images[0], slices[0])
image_2 = to_display_image(images[1], slices[1])
# replicate the single grayscale channel three times -> H x W x 3
image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))
image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
image_3_1 = image_2.copy()
image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)
space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
return np.hstack((image_1, space, image_3, space, image_2))
# iterate over the samples in table order, skipping ids that were not requested
for i in range(len(dataset)):
ir = 0
im = 0
row = dataset.table.iloc[i]
_id = row[headers.id]
if ids is not None and _id not in ids:
continue
# drop the channel dimension and convert to numpy for display
recon = recon.squeeze().numpy()
mu_map = mu_map.squeeze().numpy()
print(f"{(i+1):>{len(str(len(dataset)))}}/{len(dataset)} - ID: {_id}", end="\r")
cv.imshow(wname, combine_images((recon, mu_map), (ir, im)))
key = cv.waitKey(timeout)
# step through the slices; recon and mu map indices wrap independently
while True:
ir = (ir + 1) % recon.shape[0]
im = (im + 1) % mu_map.shape[0]
to_show = combine_images((recon, mu_map), (ir, im))
cv.imshow(wname, to_show)
# NOTE(review): pause/resume toggle - the key condition guarding it is missing
timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
# NOTE(review): the arrow-key branches below look scrambled in this copy (a
# `continue` between statements, a duplicated `im = im - 1`); their intended
# effect on `ir` / `im` cannot be recovered from here.
elif key == 82: # up arrow key
ir = ir - 1
continue
im = im - 1
im = im - 1
elif key == 84: # down arrow key
ir = ir - 1
elif key == ord("s"):
# save the currently shown frame as 000.png, 001.png, ...
cv.imwrite(f"{running:03d}.png", to_show)
running += 1

Tamino Huxohl
committed
if __name__ == "__main__":
    # Command-line entry point: build a MuMapDataset from the given directory
    # and browse it interactively with `main`.
    import argparse

    # NOTE: `PadCropTranform` is spelled as exported by mu_map.dataset.transform
    from mu_map.dataset.transform import PadCropTranform
    from mu_map.logging import add_logging_args, get_logger_by_args

    parser = argparse.ArgumentParser(
        description="Visualize the images of a MuMapDataset",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "dataset_dir", type=str, help="the directory from which the dataset is loaded"
    )
    parser.add_argument(
        "--split",
        type=str,
        choices=["train", "validation", "test"],
        help="choose the split of the data for the dataset",
    )
    parser.add_argument(
        "--unaligned",
        action="store_true",
        help="do not perform center alignment of reconstruction an mu-map slices",
    )
    parser.add_argument(
        "--show_bed",
        action="store_true",
        help="do not remove the bed contour from the mu map",
    )
    parser.add_argument(
        "--full_mu_map",
        action="store_true",
        help="do not remove broken slices of the mu map",
    )
    parser.add_argument(
        "--ids",
        type=int,
        nargs="*",
        help="only display certain ids",
    )
    parser.add_argument(
        "--paused",
        action="store_true",
        help="start in paused mode",
    )
    parser.add_argument(
        "--pad_crop",
        type=int,
        help="pad crop images to this size",
    )
    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
    args = parser.parse_args()

    align = not args.unaligned
    discard_mu_map_slices = not args.full_mu_map
    bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
    logger = get_logger_by_args(args)

    transform_normalization = (
        PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
    )

    dataset = MuMapDataset(
        args.dataset_dir,
        # --split was parsed but never forwarded, so it had no effect;
        # None (option omitted) keeps the whole table as before
        split_name=args.split,
        align=align,
        discard_mu_map_slices=discard_mu_map_slices,
        bed_contours_file=bed_contours_file,
        transform_normalization=transform_normalization,
        logger=logger,
    )
    main(dataset, args.ids, paused=args.paused)