From 1e16732e6ac560bf4f8227fc4df560da331575d6 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Wed, 18 Jan 2023 13:42:29 +0100
Subject: [PATCH] introduce function to init random seeds in training lib

---
 mu_map/training/cgan.py     | 24 ++++++------------------
 mu_map/training/distance.py | 25 ++++++-------------------
 mu_map/training/lib.py      | 24 ++++++++++++++++++++++++
 3 files changed, 36 insertions(+), 37 deletions(-)

diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index 989a219..bbec323 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -186,15 +186,12 @@ if __name__ == "__main__":
     import numpy as np
 
     from mu_map.dataset.patches import MuMapPatchDataset
-    from mu_map.dataset.normalization import (
-        MeanNormTransform,
-        MaxNormTransform,
-        GaussianNormTransform,
-    )
+    from mu_map.dataset.normalization import norm_choices, norm_by_str
     from mu_map.dataset.transform import PadCropTranform, SequenceTransform
     from mu_map.logging import add_logging_args, get_logger_by_args
     from mu_map.models.unet import UNet
     from mu_map.models.discriminator import Discriminator, PatchDiscriminator
+    from mu_map.training.lib import init_random_seed
 
     parser = argparse.ArgumentParser(
         description="Train a UNet model to predict μ-maps from reconstructed images",
@@ -220,7 +217,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--input_norm",
         type=str,
-        choices=["none", "mean", "max", "gaussian"],
+        choices=norm_choices,
         default="mean",
         help="type of normalization applied to the reconstructions",
     )
@@ -365,19 +362,10 @@ if __name__ == "__main__":
 
     device = torch.device(args.device)
 
-    args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
-    logger.info(f"Seed: {args.seed}")
-    random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    np.random.seed(args.seed)
-
-    transform_normalization = None
-    if args.input_norm == "mean":
-        transform_normalization = MeanNormTransform()
-    elif args.input_norm == "max":
-        transform_normalization = MaxNormTransform()
-    elif args.input_norm == "gaussian":
-        transform_normalization = GaussianNormTransform()
+    seed = init_random_seed(args.seed)
+    logger.info(f"Seed: {seed}")
+
+    transform_normalization = norm_by_str(args.input_norm)
     transform_normalization = SequenceTransform(
         [transform_normalization, PadCropTranform(dim=3, size=32)]
     )
diff --git a/mu_map/training/distance.py b/mu_map/training/distance.py
index 5e29323..df2dfb9 100644
--- a/mu_map/training/distance.py
+++ b/mu_map/training/distance.py
@@ -72,13 +72,10 @@ if __name__ == "__main__":
     import numpy as np
 
     from mu_map.dataset.patches import MuMapPatchDataset
-    from mu_map.dataset.normalization import (
-        MeanNormTransform,
-        MaxNormTransform,
-        GaussianNormTransform,
-    )
+    from mu_map.dataset.normalization import norm_choices, norm_by_str
     from mu_map.logging import add_logging_args, get_logger_by_args
     from mu_map.models.unet import UNet
+    from mu_map.training.lib import init_random_seed
 
     parser = argparse.ArgumentParser(
         description="Train a UNet model to predict μ-maps from reconstructed scatter images",
@@ -104,7 +101,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--input_norm",
         type=str,
-        choices=["none", "mean", "max", "gaussian"],
+        choices=norm_choices,
         default="mean",
         help="type of normalization applied to the reconstructions",
     )
@@ -231,20 +228,10 @@ if __name__ == "__main__":
     logger = get_logger_by_args(args)
     logger.info(args)
 
-    args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
-    logger.info(f"Seed: {args.seed}")
-    random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    np.random.seed(args.seed)
-
-    transform_normalization = None
-    if args.input_norm == "mean":
-        transform_normalization = MeanNormTransform()
-    elif args.input_norm == "max":
-        transform_normalization = MaxNormTransform()
-    elif args.input_norm == "gaussian":
-        transform_normalization = GaussianNormTransform()
+    seed = init_random_seed(args.seed)
+    logger.info(f"Seed: {seed}")
 
+    transform_normalization = norm_by_str(args.input_norm)
     dataset = MuMapPatchDataset(
         args.dataset_dir,
         patches_per_image=args.number_of_patches,
diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index c7b5cc4..57bb847 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -4,6 +4,7 @@ Module functioning as a library for training related code.
 from dataclasses import dataclass
 from logging import Logger
 import os
+import random
 from typing import Dict, List, Optional
 import sys
 
@@ -15,6 +16,29 @@ from mu_map.dataset.default import MuMapDataset
 from mu_map.logging import get_logger
 
 
+def init_random_seed(seed: Optional[int] = None) -> int:
+    """
+    Seed the RNGs of python's random module, numpy and torch.
+
+    Parameters
+    ----------
+    seed: int, optional
+        seed to apply; if None, one is drawn from [0, 2**32 - 1]
+
+    Returns
+    -------
+    int
+        the seed that was applied to all three RNGs
+    """
+    if seed is None:
+        seed = random.randint(0, 2**32 - 1)
+
+    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
+        seed_rng(seed)
+
+    return seed
+
+
 @dataclass
 class TrainingParams:
     """
-- 
GitLab