Skip to content
Snippets Groups Projects
Commit 455bed65 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

implement initialization of random seed

parent 102f3e46
No related merge requests found
...@@ -108,6 +108,10 @@ class Training: ...@@ -108,6 +108,10 @@ class Training:
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
import random
import sys
import numpy as np
from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import ( from mu_map.dataset.normalization import (
...@@ -155,6 +159,11 @@ if __name__ == "__main__": ...@@ -155,6 +159,11 @@ if __name__ == "__main__":
) )
# Training Args # Training Args
parser.add_argument(
"--seed",
type=int,
help="seed used for random number generation",
)
parser.add_argument( parser.add_argument(
"--batch_size", "--batch_size",
type=int, type=int,
...@@ -232,6 +241,12 @@ if __name__ == "__main__": ...@@ -232,6 +241,12 @@ if __name__ == "__main__":
logger = get_logger_by_args(args) logger = get_logger_by_args(args)
logger.info(args) logger.info(args)
args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {args.seed}")
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
model = UNet(in_channels=1, features=args.features) model = UNet(in_channels=1, features=args.features)
model = model.to(device) model = model.to(device)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment