From d91fd6a7ca546f4cb11d518e46c843654df3907b Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 6 Jan 2023 14:48:29 +0100
Subject: [PATCH] document and small refactoring of MuMapDataset

---
 mu_map/data/remove_bed.py |   2 +-
 mu_map/dataset/default.py | 124 ++++++++++++++++++++++++++++++++++----
 mu_map/dataset/patches.py |   2 +-
 mu_map/training/lib.py    |   2 +-
 4 files changed, 114 insertions(+), 16 deletions(-)

diff --git a/mu_map/data/remove_bed.py b/mu_map/data/remove_bed.py
index 79910ab..b8232e0 100644
--- a/mu_map/data/remove_bed.py
+++ b/mu_map/data/remove_bed.py
@@ -161,7 +161,7 @@ if __name__ == "__main__":
         ids = args.revise_ids
 
     for _i, _id in enumerate(ids):
-        _, mu_map = dataset.getitem_by_id(_id)
+        _, mu_map = dataset.get_item_by_id(_id)
 
         if str(_id) in bed_contours and not args.revise_ids:
             print(f"Skip {_id} because file already contains these contours")
diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index bb57235..6282f78 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -1,9 +1,8 @@
 import os
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import cv2 as cv
 import pandas as pd
-import pydicom
 import numpy as np
 import torch
 from torch.utils.data import Dataset
@@ -23,12 +22,20 @@ from mu_map.logging import get_logger
 
 
 class MuMapDataset(Dataset):
+    """
+    A dataset to map reconstructions to attenuation maps (mu maps).
+
+    The dataset is lazy. This means that that dataset creation is
+    fast and images are only read into memory when their first accessed.
+    Thus, after the first iteration, accessing images becomes a lot
+    faster.
+    """
     def __init__(
         self,
         dataset_dir: str,
         csv_file: str = "meta.csv",
         split_file: str = "split.csv",
-        split_name: str = None,
+        split_name: Optional[str] = None,
         images_dir: str = "images",
         bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
         discard_mu_map_slices: bool = True,
@@ -38,6 +45,35 @@ class MuMapDataset(Dataset):
         transform_augmentation: Transform = Transform(),
         logger=None,
     ):
+        """
+        Create a new mu map dataset.
+
+        Parameters
+        ----------
+        dataset_dir: str
+            directory of the dataset to be loaded
+        csv_file: str
+            name of the csv file in the dataset directory containing meta information (created by mu_map.data.prepare)
+        split_file: str
+            csv file defining a split of the dataset in train/validation/test (created by mu_map.data.split)
+        split_name: str, optional
+            the name of the split which is loaded
+        images_dir: str
+            directory under `dataset_dir` containing the actual images in DICOM format
+        bed_contours_file: str, optional
+            json file containing contours around the bed for each mu map (see mu_map.data.remove_bed)
+        discard_mu_map_slices: bool
+            remove defective slices from mu maps (have to be labeled by mu_map.data.review_mu_map)
+        align: bool
+            center align reconstructions and mu maps
+        scatter_correction: bool
+            use scatter corrected reconstructions
+        transform_normalization: Transform
+            transform used for image normalization which is applied once when the image is loaded
+        transform_augmentation: Transform
+            transform used for augmentation which is applied every time `__getitem__` is called
+        logger: Logger, optional 
+        """
         super().__init__()
 
         self.dir = dataset_dir
@@ -77,7 +113,15 @@ class MuMapDataset(Dataset):
         self.reconstructions = {}
         self.mu_maps = {}
 
-    def split_copy(self, split_name: str):
+    def copy(self, split_name: str):
+        """
+        Create a copy of the dataset and modify parameters.
+
+        Parameters
+        ----------
+        split_name: str
+            the split which with which the copy is created
+        """
         return MuMapDataset(
             dataset_dir=self.dir,
             csv_file=os.path.basename(self.csv_file),
@@ -94,6 +138,18 @@ class MuMapDataset(Dataset):
         )
 
     def load_image(self, _id: int):
+        """
+        Load an image into memory.
+
+        This function also performs all of the pre-processing (discard slices, remove bed, alignment ...).
+        Afterwards, the reconstruction and mu map are available from the local
+        dicts: `self.reconstructions` and `self.mu_maps.`
+
+        Parameters
+        ----------
+        _id: int
+            the id of the image to be loaded
+        """
         row = self.table[self.table[headers.id] == _id].iloc[0]
         _id = row[headers.id]
 
@@ -129,15 +185,42 @@ class MuMapDataset(Dataset):
         self.reconstructions[_id] = recon
 
     def pre_load_images(self):
+        """
+        Load all images into memory.
+        """
         for _id in self.table[headers.id]:
             self.load_image(_id)
 
-    def __getitem__(self, index: int):
+    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get a reconstruction mu map pair by index.
+
+        This method retrieves the image id of this index and call `get_item_by_id`.
+
+        Parameters
+        ----------
+        index: int
+        """
         row = self.table.iloc[index]
         _id = row[headers.id]
-        return self.getitem_by_id(_id)
+        return self.get_item_by_id(_id)
+
+    def get_item_by_id(self, _id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Get a reconstruction and mu map pair by their id.
 
-    def getitem_by_id(self, _id: int):
+        This methods loads the images of not yet in memory and applies the
+        augmentation transform before returning them.
+
+        Parameters
+        ----------
+        _id: int
+
+        Returns
+        -------
+        Tuple[torch.Tensor, torch.Tensor]
+            a pair of a reconstruction and the according mu map
+        """
         if _id not in self.reconstructions:
             self.load_image(_id)
 
@@ -148,14 +231,29 @@ class MuMapDataset(Dataset):
 
         return recon, mu_map
 
-    def __len__(self):
+    def __len__(self) -> int:
+        """
+        Get the number of elements in this dataset.
+        """
         return len(self.table)
 
 
 __all__ = [MuMapDataset.__name__]
 
 
-def main(dataset, ids, paused=False):
+def main(dataset: MuMapDataset, ids: Optional[List[int]] = None, paused: bool=False):
+    """
+    Display reconstructions and mu maps in a dataset.
+
+    Parameters
+    ----------
+    dataset: MuMapDataset
+        the dataset of which elements are displayed
+    ids: list of int, optional
+        only display these ids
+    paused: bool
+        start display in paused mode 
+    """
     from mu_map.util import to_grayscale, COLOR_WHITE
 
     wname = "Dataset"
@@ -163,10 +261,10 @@ def main(dataset, ids, paused=False):
     cv.resizeWindow(wname, 1600, 900)
     space = np.full((1024, 10), 239, np.uint8)
 
-    TIMEOUT_PAUSED = 0
-    TIMEOUT_RUNNING = 1000 // 15
+    timeout_paused = 0
+    timeout_running = 1000 // 15
 
-    timeout = TIMEOUT_PAUSED if paused else TIMEOUT_RUNNING
+    timeout = timeout_paused if paused else timeout_running
 
     def to_display_image(image, _slice):
         _image = to_grayscale(image[_slice], min_val=image.min(), max_val=image.max())
@@ -223,7 +321,7 @@ def main(dataset, ids, paused=False):
             elif key == ord("q"):
                 exit(0)
             elif key == ord("p"):
-                timeout = TIMEOUT_PAUSED if timeout > 0 else TIMEOUT_RUNNING
+                timeout = timeout_paused if timeout > 0 else timeout_running
             elif key == 82:  # up arrow key
                 ir = ir - 1
                 continue
diff --git a/mu_map/dataset/patches.py b/mu_map/dataset/patches.py
index 7487269..ab8d81f 100644
--- a/mu_map/dataset/patches.py
+++ b/mu_map/dataset/patches.py
@@ -98,7 +98,7 @@ class MuMapPatchDataset(MuMapDataset):
         ps = self.patch_size
         ps_z = self.patch_size_z
 
-        recon, mu_map = super().getitem_by_id(_id)
+        recon, mu_map = super().get_item_by_id(_id)
 
         recon = torch.nn.functional.pad(recon, padding, mode="constant", value=0)
         mu_map = torch.nn.functional.pad(mu_map, padding, mode="constant", value=0)
diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index 6b9b0d9..6da9aed 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -77,7 +77,7 @@ class AbstractTraining:
                 (
                     split_name,
                     torch.utils.data.DataLoader(
-                        dataset.split_copy(split_name),
+                        dataset.copy(split_name),
                         batch_size=self.batch_size,
                         shuffle=True,
                         pin_memory=True,
-- 
GitLab