diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index 94fa52aaca078dd6fbd61736d1c626d3bd2d036d..1568919a73905899e55062221701388ec0b71eb1 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Optional, Tuple
 
 import cv2 as cv
 import pandas as pd
@@ -9,48 +9,18 @@
 import torch
 from torch.utils.data import Dataset
 
 from mu_map.data.prepare import headers
-from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
+from mu_map.data.remove_bed import (
+    DEFAULT_BED_CONTOURS_FILENAME,
+    load_contours,
+    remove_bed,
+)
 from mu_map.data.review_mu_map import discard_slices
 from mu_map.data.split import split_csv
 from mu_map.dataset.transform import Transform
+from mu_map.dataset.util import align_images, load_dcm_img
 from mu_map.logging import get_logger
 
-"""
-Since DICOM images only allow images stored in short integer format,
-the Siemens scanner software multiplies values by a factor before storing
-so that no precision is lost.
-The scale can be found in this private DICOM tag.
-"""
-DCM_TAG_PIXEL_SCALE_FACTOR = 0x00331038
-
-
-def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
-    """
-    Align one image to another on the first axis (z-axis).
-    It is assumed that the second image has less slices than the first.
-    Then, the first image is shortened in a way that the centers of both images lie on top of each other.
-
-    :param image_1: the image to be aligned
-    :param image_2: the image to which image_1 is aligned
-    :return: the aligned image_1
-    """
-    assert (
-        image_1.shape[0] > image_2.shape[0]
-    ), f"Alignment is based on the fact that image 1 has more slices {image_1.shape[0]} than image_2 {image_.shape[0]}"
-
-    # central slice of image 2
-    c_2 = image_2.shape[0] // 2
-    # image to the left and right of the center
-    left = c_2
-    right = image_2.shape[0] - c_2
-
-    # central slice of image 1
-    c_1 = image_1.shape[0] // 2
-    # select center and same amount to the left/right as image_2
-    return image_1[(c_1 - left) : (c_1 + right)]
-
-
 class MuMapDataset(Dataset):
     def __init__(
         self,
@@ -102,48 +72,39 @@
         self.reconstructions = {}
         self.mu_maps = {}
 
-        self.pre_load_images()
-
-    def pre_load_images(self):
-        self.logger.debug("Pre-loading images ...")
-        for i in range(len(self.table)):
-            row = self.table.iloc[i]
-            _id = row[headers.id]
-
-            mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
-            mu_map = pydicom.dcmread(mu_map_file)
-            mu_map = mu_map.pixel_array / mu_map[DCM_TAG_PIXEL_SCALE_FACTOR].value
-            if self.discard_mu_map_slices:
-                mu_map = discard_slices(row, mu_map)
-            if self.bed_contours:
-                if _id in self.bed_contours:
-                    bed_contour = self.bed_contours[_id]
-                    for i in range(mu_map.shape[0]):
-                        mu_map[i] = cv.drawContours(
-                            mu_map[i], [bed_contour], -1, 0.0, -1
-                        )
-                else:
-                    logger.warning(f"Could not find bed contour for id {_id}")
-
-            recon_file = os.path.join(self.dir_images, row[self.header_recon])
-            recon = pydicom.dcmread(recon_file)
-            recon = recon.pixel_array / recon[DCM_TAG_PIXEL_SCALE_FACTOR].value
-            if self.align:
-                recon = align_images(recon, mu_map)
-
-            mu_map = mu_map.astype(np.float32)
-            mu_map = torch.from_numpy(mu_map)
-            mu_map = mu_map.unsqueeze(dim=0)
-
-            recon = recon.astype(np.float32)
-            recon = torch.from_numpy(recon)
-            recon = recon.unsqueeze(dim=0)
-
-            recon, mu_map = self.transform_normalization(recon, mu_map)
-
-            self.mu_maps[_id] = mu_map
-            self.reconstructions[_id] = recon
-        self.logger.debug("Pre-loading images done!")
+
+    def load_image(self, _id: int):
+        row = self.table[self.table[headers.id] == _id].iloc[0]
+        _id = row[headers.id]
+
+        mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
+        mu_map = load_dcm_img(mu_map_file)
+        if self.discard_mu_map_slices:
+            mu_map = discard_slices(row, mu_map)
+        if self.bed_contours:
+            if _id in self.bed_contours:
+                bed_contour = self.bed_contours[_id]
+                mu_map = remove_bed(mu_map, bed_contour)
+            else:
+                self.logger.warning(f"Could not find bed contour for id {_id}")
+
+        recon_file = os.path.join(self.dir_images, row[self.header_recon])
+        recon = load_dcm_img(recon_file)
+        if self.align:
+            recon, mu_map = align_images(recon, mu_map)
+
+        mu_map = mu_map.astype(np.float32)
+        mu_map = torch.from_numpy(mu_map)
+        mu_map = mu_map.unsqueeze(dim=0)
+
+        recon = recon.astype(np.float32)
+        recon = torch.from_numpy(recon)
+        recon = recon.unsqueeze(dim=0)
+
+        recon, mu_map = self.transform_normalization(recon, mu_map)
+
+        self.mu_maps[_id] = mu_map
+        self.reconstructions[_id] = recon
 
     def __getitem__(self, index: int):
         row = self.table.iloc[index]
@@ -151,6 +112,9 @@
         return self.getitem_by_id(_id)
 
     def getitem_by_id(self, _id: int):
+        if _id not in self.reconstructions:
+            self.load_image(_id)
+
         recon = self.reconstructions[_id]
         mu_map = self.mu_maps[_id]
 
@@ -165,7 +129,7 @@
 __all__ = [MuMapDataset.__name__]
 
 
-def main(dataset):
+def main(dataset, ids, paused=False):
     from mu_map.util import to_grayscale, COLOR_WHITE
 
     wname = "Dataset"
@@ -173,7 +137,10 @@
     cv.resizeWindow(wname, 1600, 900)
     space = np.full((1024, 10), 239, np.uint8)
 
-    timeout = 100
+    TIMEOUT_PAUSED = 0
+    TIMEOUT_RUNNING = 1000 // 15
+
+    timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING
 
     def to_display_image(image, _slice):
         _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
@@ -187,8 +154,16 @@
     def combine_images(images, slices):
         image_1 = to_display_image(images[0], slices[0])
         image_2 = to_display_image(images[1], slices[1])
-        space = np.full((image_1.shape[0], 10), 239, np.uint8)
-        return np.hstack((image_1, space, image_2))
+
+        image_1 = image_1.repeat(3).reshape((*image_1.shape, 3))
+        image_2 = image_2.repeat(3).reshape((*image_2.shape, 3))
+
+        image_3_2 = cv.applyColorMap(image_1, cv.COLORMAP_INFERNO)
+        image_3_1 = image_2.copy()
+        image_3 = cv.addWeighted(image_3_1, 0.8, image_3_2, 0.4, 0.0)
+
+        space = np.full((image_1.shape[0], 10, 3), 239, np.uint8)
+        return np.hstack((image_1, space, image_3, space, image_2))
 
     for i in range(len(dataset)):
         ir = 0
@@ -197,6 +172,9 @@
         row = dataset.table.iloc[i]
         _id = row[headers.id]
 
+        if ids is not None and _id not in ids:
+            continue
+
         recon, mu_map = dataset[i]
         recon = recon.squeeze().numpy()
         mu_map = mu_map.squeeze().numpy()
@@ -219,11 +197,18 @@
             elif key == ord("q"):
                 exit(0)
             elif key == ord("p"):
-                timeout = 0 if timeout > 0 else 100
+                timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
+            elif key == 82:  # up arrow key
+                ir = ir - 1
+                continue
             elif key == 83:  # right arrow key
+                im = im - 1
                 continue
             elif key == 81:  # left arrow key
+                im = im - 1
                 ir = max(ir - 2, 0)
+            elif key == 84:  # down arrow key
+                ir = ir - 1
                 im = max(im - 2, 0)
             elif key == ord("s"):
                 cv.imwrite(f"{running:03d}.png", to_show)
@@ -233,6 +218,7 @@
 
 if __name__ == "__main__":
     import argparse
+    from mu_map.dataset.transform import PadCropTranform
     from mu_map.logging import add_logging_args, get_logger_by_args
 
     parser = argparse.ArgumentParser(
@@ -263,6 +249,22 @@
         action="store_true",
         help="do not remove broken slices of the mu map",
     )
+    parser.add_argument(
+        "--ids",
+        type=int,
+        nargs="*",
+        help="only display certain ids",
+    )
+    parser.add_argument(
+        "--paused",
+        action="store_true",
+        help="start in paused mode",
+    )
+    parser.add_argument(
+        "--pad_crop",
+        type=int,
+        help="pad crop images to this size",
+    )
     add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
     args = parser.parse_args()
 
@@ -271,12 +273,17 @@
     bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
     logger = get_logger_by_args(args)
 
+    transform_normalization = (
+        PadCropTranform(dim=3, size=args.pad_crop) if args.pad_crop else Transform()
+    )
+
     dataset = MuMapDataset(
         args.dataset_dir,
         align=align,
         discard_mu_map_slices=discard_mu_map_slices,
         bed_contours_file=bed_contours_file,
         split_name=args.split,
+        transform_normalization=transform_normalization,
         logger=logger,
     )
-    main(dataset)
+    main(dataset, args.ids, paused=args.paused)
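
Usage note (a minimal sketch, not part of the diff): with this change, constructing `MuMapDataset` no longer pre-loads every DICOM pair; `load_image` runs on first access per id and caches the tensors. The dataset directory `data/initial/` below is a hypothetical path.

```python
from mu_map.dataset.default import MuMapDataset

# Construction is now cheap: no images are read here (lazy loading).
dataset = MuMapDataset("data/initial/")  # hypothetical dataset directory

# First access triggers load_image() for this id only; results are cached
# in dataset.reconstructions / dataset.mu_maps for subsequent accesses.
recon, mu_map = dataset[0]
print(recon.shape, mu_map.shape)  # (1, z, y, x) tensors after unsqueeze(dim=0)
```

The viewer can likewise be restricted to a subset, e.g. `python -m mu_map.dataset.default <dataset_dir> --ids 42 --paused --pad_crop 32`; the exact name of the positional dataset-directory argument is defined in parser setup elided from this diff.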