From bed5ba58632c847344913c8266440358638b5e14 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 27 Sep 2022 10:25:28 +0200
Subject: [PATCH] add transforms to default dataset

---
 mu_map/dataset/default.py | 43 ++++++++++++++++++++++-----------------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py
index 88b3168..7d35bc4 100644
--- a/mu_map/dataset/default.py
+++ b/mu_map/dataset/default.py
@@ -11,6 +11,8 @@ from torch.utils.data import Dataset
 from mu_map.data.prepare import headers
 from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours
 from mu_map.data.review_mu_map import discard_slices
+from mu_map.dataset.transform import Transform
+from mu_map.logging import get_logger
 
 
 def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray:
@@ -48,12 +50,18 @@ class MuMapDataset(Dataset):
         bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME,
         discard_mu_map_slices: bool = True,
         align: bool = True,
+        transform_normalization: Transform = Transform(),
+        transform_augmentation: Transform = Transform(),
+        logger=None,
     ):
         super().__init__()
 
         self.dir = dataset_dir
         self.dir_images = os.path.join(dataset_dir, images_dir)
         self.csv_file = os.path.join(dataset_dir, csv_file)
+        self.transform_normalization = transform_normalization
+        self.transform_augmentation = transform_augmentation
+        self.logger = logger if logger is not None else get_logger()
 
         self.bed_contours_file = (
             os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None
@@ -74,7 +82,7 @@ class MuMapDataset(Dataset):
         self.pre_load_images()
 
     def pre_load_images(self):
-        print("Pre-loading images ...", end="\r")
+        self.logger.debug("Pre-loading images ...")
         for i in range(len(self.table)):
             row = self.table.iloc[i]
             _id = row["id"]
@@ -96,13 +104,16 @@ class MuMapDataset(Dataset):
             mu_map = mu_map.astype(np.float32)
             mu_map = torch.from_numpy(mu_map)
             mu_map = mu_map.unsqueeze(dim=0)
-            self.mu_maps[_id] = mu_map
 
             recon = recon.astype(np.float32)
             recon = torch.from_numpy(recon)
             recon = recon.unsqueeze(dim=0)
+
+            recon, mu_map = self.transform_normalization(recon, mu_map)
+
+            self.mu_maps[_id] = mu_map
             self.reconstructions[_id] = recon
-        print("Pre-loading images done!")
+        self.logger.debug("Pre-loading images done!")
 
     def __getitem__(self, index: int):
         row = self.table.iloc[index]
@@ -111,22 +122,7 @@ class MuMapDataset(Dataset):
         recon = self.reconstructions[_id]
         mu_map = self.mu_maps[_id]
 
-        # recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc])
-        # mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map])
-
-        # recon = pydicom.dcmread(recon_file).pixel_array
-        # mu_map = pydicom.dcmread(mu_map_file).pixel_array
-
-        # if self.discard_mu_map_slices:
-        # mu_map = discard_slices(row, mu_map)
-
-        # if self.bed_contours:
-        # bed_contour = self.bed_contours[row["id"]]
-        # for i in range(mu_map.shape[0]):
-        # mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1)
-
-        # if self.align:
-        # recon = align_images(recon, mu_map)
+        recon, mu_map = self.transform_augmentation(recon, mu_map)
 
         return recon, mu_map
 
@@ -139,6 +135,12 @@ __all__ = [MuMapDataset.__name__]
 if __name__ == "__main__":
     import argparse
 
+    from mu_map.dataset.normalization import (
+        GaussianNormTransform,
+        MaxNormTransform,
+        MeanNormTransform,
+    )
+    from mu_map.logging import add_logging_args, get_logger_by_args
     from mu_map.util import to_grayscale, COLOR_WHITE
 
     parser = argparse.ArgumentParser(
@@ -163,17 +165,20 @@ if __name__ == "__main__":
         action="store_true",
         help="do not remove broken slices of the mu map",
     )
+    add_logging_args(parser, defaults={"--loglevel": "DEBUG"})
     args = parser.parse_args()
 
     align = not args.unaligned
     discard_mu_map_slices = not args.full_mu_map
     bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME
+    logger = get_logger_by_args(args)
 
     dataset = MuMapDataset(
         args.dataset_dir,
         align=align,
         discard_mu_map_slices=discard_mu_map_slices,
         bed_contours_file=bed_contours_file,
+        logger=logger,
     )
 
     wname = "Dataset"
-- 
GitLab