From 455bed65eff2f11192f95a01f20f9191b191a138 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 4 Oct 2022 10:00:12 +0200
Subject: [PATCH] implement initialization of random seed

---
 mu_map/training/default.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/mu_map/training/default.py b/mu_map/training/default.py
index e328f45..8a8e4e4 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/default.py
@@ -108,6 +108,10 @@ class Training:
 
 if __name__ == "__main__":
     import argparse
+    import random
+    import sys
+
+    import numpy as np
 
     from mu_map.dataset.patches import MuMapPatchDataset
     from mu_map.dataset.normalization import (
@@ -155,6 +159,11 @@ if __name__ == "__main__":
     )
 
     # Training Args
+    parser.add_argument(
+        "--seed",
+        type=int,
+        help="seed used for random number generation",
+    )
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -232,6 +241,12 @@ if __name__ == "__main__":
     logger = get_logger_by_args(args)
     logger.info(args)
 
+    args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
+    logger.info(f"Seed: {args.seed}")
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+
     model = UNet(in_channels=1, features=args.features)
     model = model.to(device)
 
-- 
GitLab