diff --git a/mu_map/training/default.py b/mu_map/training/default.py index e328f453b8f406e37d462bf27c4b77f212dd95d2..8a8e4e42683a484cd471f7b279d178b4ea999883 100644 --- a/mu_map/training/default.py +++ b/mu_map/training/default.py @@ -108,6 +108,10 @@ class Training: if __name__ == "__main__": import argparse + import random + import sys + + import numpy as np from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.normalization import ( @@ -155,6 +159,11 @@ if __name__ == "__main__": ) # Training Args + parser.add_argument( + "--seed", + type=int, + help="seed used for random number generation", + ) parser.add_argument( "--batch_size", type=int, @@ -232,6 +241,12 @@ if __name__ == "__main__": logger = get_logger_by_args(args) logger.info(args) + args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + logger.info(f"Seed: {args.seed}") + random.seed(args.seed) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + model = UNet(in_channels=1, features=args.features) model = model.to(device)