From d91fd6a7ca546f4cb11d518e46c843654df3907b Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Fri, 6 Jan 2023 14:48:29 +0100 Subject: [PATCH] document and small refactoring of MuMapDataset --- mu_map/data/remove_bed.py | 2 +- mu_map/dataset/default.py | 124 ++++++++++++++++++++++++++++++++++---- mu_map/dataset/patches.py | 2 +- mu_map/training/lib.py | 2 +- 4 files changed, 114 insertions(+), 16 deletions(-) diff --git a/mu_map/data/remove_bed.py b/mu_map/data/remove_bed.py index 79910ab..b8232e0 100644 --- a/mu_map/data/remove_bed.py +++ b/mu_map/data/remove_bed.py @@ -161,7 +161,7 @@ if __name__ == "__main__": ids = args.revise_ids for _i, _id in enumerate(ids): - _, mu_map = dataset.getitem_by_id(_id) + _, mu_map = dataset.get_item_by_id(_id) if str(_id) in bed_contours and not args.revise_ids: print(f"Skip {_id} because file already contains these contours") diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py index bb57235..6282f78 100644 --- a/mu_map/dataset/default.py +++ b/mu_map/dataset/default.py @@ -1,9 +1,8 @@ import os -from typing import Optional, Tuple +from typing import List, Optional, Tuple import cv2 as cv import pandas as pd -import pydicom import numpy as np import torch from torch.utils.data import Dataset @@ -23,12 +22,20 @@ from mu_map.logging import get_logger class MuMapDataset(Dataset): + """ + A dataset to map reconstructions to attenuation maps (mu maps). + + The dataset is lazy. This means that that dataset creation is + fast and images are only read into memory when their first accessed. + Thus, after the first iteration, accessing images becomes a lot + faster. + """ def __init__( self, dataset_dir: str, csv_file: str = "meta.csv", split_file: str = "split.csv", - split_name: str = None, + split_name: Optional[str] = None, images_dir: str = "images", bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME, discard_mu_map_slices: bool = True, @@ -38,6 +45,35 @@ class MuMapDataset(Dataset): transform_augmentation: Transform = Transform(), logger=None, ): + """ + Create a new mu map dataset. + + Parameters + ---------- + dataset_dir: str + directory of the dataset to be loaded + csv_file: str + name of the csv file in the dataset directory containing meta information (created by mu_map.data.prepare) + split_file: str + csv file defining a split of the dataset in train/validation/test (created by mu_map.data.split) + split_name: str, optional + the name of the split which is loaded + images_dir: str + directory under `dataset_dir` containing the actual images in DICOM format + bed_contours_file: str, optional + json file containing contours around the bed for each mu map (see mu_map.data.remove_bed) + discard_mu_map_slices: bool + remove defective slices from mu maps (have to be labeled by mu_map.data.review_mu_map) + align: bool + center align reconstructions and mu maps + scatter_correction: bool + use scatter corrected reconstructions + transform_normalization: Transform + transform used for image normalization which is applied once when the image is loaded + transform_augmentation: Transform + transform used for augmentation which is applied every time `__getitem__` is called + logger: Logger, optional + """ super().__init__() self.dir = dataset_dir @@ -77,7 +113,15 @@ class MuMapDataset(Dataset): self.reconstructions = {} self.mu_maps = {} - def split_copy(self, split_name: str): + def copy(self, split_name: str): + """ + Create a copy of the dataset and modify parameters. + + Parameters + ---------- + split_name: str + the split which with which the copy is created + """ return MuMapDataset( dataset_dir=self.dir, csv_file=os.path.basename(self.csv_file), @@ -94,6 +138,18 @@ class MuMapDataset(Dataset): ) def load_image(self, _id: int): + """ + Load an image into memory. + + This function also performs all of the pre-processing (discard slices, remove bed, alignment ...). + Afterwards, the reconstruction and mu map are available from the local + dicts: `self.reconstructions` and `self.mu_maps.` + + Parameters + ---------- + _id: int + the id of the image to be loaded + """ row = self.table[self.table[headers.id] == _id].iloc[0] _id = row[headers.id] @@ -129,15 +185,42 @@ class MuMapDataset(Dataset): self.reconstructions[_id] = recon def pre_load_images(self): + """ + Load all images into memory. + """ for _id in self.table[headers.id]: self.load_image(_id) - def __getitem__(self, index: int): + def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get a reconstruction mu map pair by index. + + This method retrieves the image id of this index and call `get_item_by_id`. + + Parameters + ---------- + index: int + """ row = self.table.iloc[index] _id = row[headers.id] - return self.getitem_by_id(_id) + return self.get_item_by_id(_id) + + def get_item_by_id(self, _id: int) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get a reconstruction and mu map pair by their id. - def getitem_by_id(self, _id: int): + This methods loads the images of not yet in memory and applies the + augmentation transform before returning them. + + Parameters + ---------- + _id: int + + Returns + ------- + Tuple[torch.Tensor, torch.Tensor] + a pair of a reconstruction and the according mu map + """ if _id not in self.reconstructions: self.load_image(_id) @@ -148,14 +231,29 @@ class MuMapDataset(Dataset): return recon, mu_map - def __len__(self): + def __len__(self) -> int: + """ + Get the number of elements in this dataset. + """ return len(self.table) __all__ = [MuMapDataset.__name__] -def main(dataset, ids, paused=False): +def main(dataset: MuMapDataset, ids: Optional[List[int]] = None, paused: bool=False): + """ + Display reconstructions and mu maps in a dataset. + + Parameters + ---------- + dataset: MuMapDataset + the dataset of which elements are displayed + ids: list of int, optional + only display these ids + paused: bool + start display in paused mode + """ from mu_map.util import to_grayscale, COLOR_WHITE wname = "Dataset" @@ -163,10 +261,10 @@ def main(dataset, ids, paused=False): cv.resizeWindow(wname, 1600, 900) space = np.full((1024, 10), 239, np.uint8) - TIMEOUT_PAUSED = 0 - TIMEOUT_RUNNING = 1000 // 15 + timeout_paused = 0 + timeout_running = 1000 // 15 - timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING + timeout = timeout_paused if paused else timeout_running def to_display_image(image, _slice): _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max()) @@ -223,7 +321,7 @@ def main(dataset, ids, paused=False): elif key == ord("q"): exit(0) elif key == ord("p"): - timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING + timeout = timeout_paused if timeout > 0 else timeout_running elif key == 82: # up arrow key ir = ir - 1 continue diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py index 7487269..ab8d81f 100644 --- a/mu_map/dataset/patches.py +++ b/mu_map/dataset/patches.py @@ -98,7 +98,7 @@ class MuMapPatchDataset(MuMapDataset): ps = self.patch_size ps_z = self.patch_size_z - recon, mu_map = super().getitem_by_id(_id) + recon, mu_map = super().get_item_by_id(_id) recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0) mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0) diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py index 6b9b0d9..6da9aed 100644 --- a/mu_map/training/lib.py +++ b/mu_map/training/lib.py @@ -77,7 +77,7 @@ class AbstractTraining: ( split_name, torch.utils.data.DataLoader( - dataset.split_copy(split_name), + dataset.copy(split_name), batch_size=self.batch_size, shuffle=True, pin_memory=True, -- GitLab