diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index 989a219c5fb390aae105b6c0a94e7b21f649daf6..bbec3236882e30797f5269d32092e24b3e2dfdec 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -186,15 +186,12 @@ if __name__ == "__main__":
     import numpy as np
 
     from mu_map.dataset.patches import MuMapPatchDataset
-    from mu_map.dataset.normalization import (
-        MeanNormTransform,
-        MaxNormTransform,
-        GaussianNormTransform,
-    )
+    from mu_map.dataset.normalization import norm_choices, norm_by_str
     from mu_map.dataset.transform import PadCropTranform, SequenceTransform
     from mu_map.logging import add_logging_args, get_logger_by_args
     from mu_map.models.unet import UNet
     from mu_map.models.discriminator import Discriminator, PatchDiscriminator
+    from mu_map.training.lib import init_random_seed
 
     parser = argparse.ArgumentParser(
         description="Train a UNet model to predict μ-maps from reconstructed images",
@@ -220,7 +217,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--input_norm",
         type=str,
-        choices=["none", "mean", "max", "gaussian"],
+        choices=norm_choices,
         default="mean",
         help="type of normalization applied to the reconstructions",
     )
@@ -365,19 +362,10 @@ if __name__ == "__main__":
 
     device = torch.device(args.device)
 
-    args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
+    seed = init_random_seed(args.seed)
     logger.info(f"Seed: {args.seed}")
-    random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    np.random.seed(args.seed)
-
-    transform_normalization = None
-    if args.input_norm == "mean":
-        transform_normalization = MeanNormTransform()
-    elif args.input_norm == "max":
-        transform_normalization = MaxNormTransform()
-    elif args.input_norm == "gaussian":
-        transform_normalization = GaussianNormTransform()
+
+    transform_normalization = norm_by_str(args.input_norm)
     transform_normalization = SequenceTransform(
         [transform_normalization, PadCropTranform(dim=3, size=32)]
     )
diff --git a/mu_map/training/distance.py b/mu_map/training/distance.py
index 5e29323665a93b07e5ce6a274346f225f81ee4b4..df2dfb95d9f5dff506e23673b3dcfea6388c7f84 100644
--- a/mu_map/training/distance.py
+++ b/mu_map/training/distance.py
@@ -72,13 +72,10 @@ if __name__ == "__main__":
     import numpy as np
 
     from mu_map.dataset.patches import MuMapPatchDataset
-    from mu_map.dataset.normalization import (
-        MeanNormTransform,
-        MaxNormTransform,
-        GaussianNormTransform,
-    )
+    from mu_map.dataset.normalization import norm_choices, norm_by_str
     from mu_map.logging import add_logging_args, get_logger_by_args
     from mu_map.models.unet import UNet
+    from mu_map.training.lib import init_random_seed
 
     parser = argparse.ArgumentParser(
         description="Train a UNet model to predict μ-maps from reconstructed scatter images",
@@ -104,7 +101,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--input_norm",
         type=str,
-        choices=["none", "mean", "max", "gaussian"],
+        choices=norm_choices,
         default="mean",
         help="type of normalization applied to the reconstructions",
     )
@@ -231,20 +228,10 @@ if __name__ == "__main__":
     logger = get_logger_by_args(args)
     logger.info(args)
 
-    args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
-    logger.info(f"Seed: {args.seed}")
-    random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    np.random.seed(args.seed)
-
-    transform_normalization = None
-    if args.input_norm == "mean":
-        transform_normalization = MeanNormTransform()
-    elif args.input_norm == "max":
-        transform_normalization = MaxNormTransform()
-    elif args.input_norm == "gaussian":
-        transform_normalization = GaussianNormTransform()
+    seed = init_random_seed(args.seed)
+    logger.info(f"Seed: {seed}")
 
+    transform_normalization = norm_by_str(args.input_norm)
     dataset = MuMapPatchDataset(
         args.dataset_dir,
         patches_per_image=args.number_of_patches,
diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index c7b5cc446106754f9770f093de45311da82f062f..57bb8473a34ba64c2862534bb0dab9bad261a31e 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -4,6 +4,7 @@ Module functioning as a library for training related code.
 from dataclasses import dataclass
 from logging import Logger
 import os
+import random
 from typing import Dict, List, Optional
 import sys
 
@@ -15,6 +16,29 @@ from mu_map.dataset.default import MuMapDataset
 from mu_map.logging import get_logger
 
 
+def init_random_seed(seed: Optional[int] = None) -> int:
+    """
+    Set the seed for all RNGs (default python, numpy and torch).
+
+    Parameters
+    ----------
+    seed: int, optional
+        the seed to be used which is generated if not provided
+
+    Returns
+    -------
+    int
+        the randoms seed used
+    """
+    seed = seed if seed is not None else random.randint(0, 2**32 - 1)
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    return seed
+
+
 @dataclass
 class TrainingParams:
     """