From bed5ba58632c847344913c8266440358638b5e14 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 27 Sep 2022 10:25:28 +0200 Subject: [PATCH] add transforms to default dataset --- mu_map/dataset/default.py | 43 ++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/mu_map/dataset/default.py b/mu_map/dataset/default.py index 88b3168..7d35bc4 100644 --- a/mu_map/dataset/default.py +++ b/mu_map/dataset/default.py @@ -11,6 +11,8 @@ from torch.utils.data import Dataset from mu_map.data.prepare import headers from mu_map.data.remove_bed import DEFAULT_BED_CONTOURS_FILENAME, load_contours from mu_map.data.review_mu_map import discard_slices +from mu_map.dataset.transform import Transform +from mu_map.logging import get_logger def align_images(image_1: np.ndarray, image_2: np.ndarray) -> np.ndarray: @@ -48,12 +50,18 @@ class MuMapDataset(Dataset): bed_contours_file: Optional[str] = DEFAULT_BED_CONTOURS_FILENAME, discard_mu_map_slices: bool = True, align: bool = True, + transform_normalization: Transform = Transform(), + transform_augmentation: Transform = Transform(), + logger=None, ): super().__init__() self.dir = dataset_dir self.dir_images = os.path.join(dataset_dir, images_dir) self.csv_file = os.path.join(dataset_dir, csv_file) + self.transform_normalization = transform_normalization + self.transform_augmentation = transform_augmentation + self.logger = logger if logger is not None else get_logger() self.bed_contours_file = ( os.path.join(dataset_dir, bed_contours_file) if bed_contours_file else None @@ -74,7 +82,7 @@ class MuMapDataset(Dataset): self.pre_load_images() def pre_load_images(self): - print("Pre-loading images ...", end="\r") + self.logger.debug("Pre-loading images ...") for i in range(len(self.table)): row = self.table.iloc[i] _id = row["id"] @@ -96,13 +104,16 @@ class MuMapDataset(Dataset): mu_map = mu_map.astype(np.float32) mu_map = torch.from_numpy(mu_map) mu_map = mu_map.unsqueeze(dim=0) - self.mu_maps[_id] = mu_map recon = recon.astype(np.float32) recon = torch.from_numpy(recon) recon = recon.unsqueeze(dim=0) + + recon, mu_map = self.transform_normalization(recon, mu_map) + + self.mu_maps[_id] = mu_map self.reconstructions[_id] = recon - print("Pre-loading images done!") + self.logger.debug("Pre-loading images done!") def __getitem__(self, index: int): row = self.table.iloc[index] @@ -111,22 +122,7 @@ class MuMapDataset(Dataset): recon = self.reconstructions[_id] mu_map = self.mu_maps[_id] - # recon_file = os.path.join(self.dir_images, row[headers.file_recon_nac_nsc]) - # mu_map_file = os.path.join(self.dir_images, row[headers.file_mu_map]) - - # recon = pydicom.dcmread(recon_file).pixel_array - # mu_map = pydicom.dcmread(mu_map_file).pixel_array - - # if self.discard_mu_map_slices: - # mu_map = discard_slices(row, mu_map) - - # if self.bed_contours: - # bed_contour = self.bed_contours[row["id"]] - # for i in range(mu_map.shape[0]): - # mu_map[i] = cv.drawContours(mu_map[i], [bed_contour], -1, 0.0, -1) - - # if self.align: - # recon = align_images(recon, mu_map) + recon, mu_map = self.transform_augmentation(recon, mu_map) return recon, mu_map @@ -139,6 +135,12 @@ __all__ = [MuMapDataset.__name__] if __name__ == "__main__": import argparse + from mu_map.dataset.normalization import ( + GaussianNormTransform, + MaxNormTransform, + MeanNormTransform, + ) + from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.util import to_grayscale, COLOR_WHITE parser = argparse.ArgumentParser( @@ -163,17 +165,20 @@ if __name__ == "__main__": action="store_true", help="do not remove broken slices of the mu map", ) + add_logging_args(parser, defaults={"--loglevel": "DEBUG"}) args = parser.parse_args() align = not args.unaligned discard_mu_map_slices = not args.full_mu_map bed_contours_file = None if args.show_bed else DEFAULT_BED_CONTOURS_FILENAME + logger = get_logger_by_args(args) dataset = MuMapDataset( args.dataset_dir, align=align, discard_mu_map_slices=discard_mu_map_slices, bed_contours_file=bed_contours_file, + logger=logger, ) wname = "Dataset" -- GitLab