From 1e16732e6ac560bf4f8227fc4df560da331575d6 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Wed, 18 Jan 2023 13:42:29 +0100 Subject: [PATCH] introduce function to init random seeds in training lib --- mu_map/training/cgan.py | 24 ++++++------------------ mu_map/training/distance.py | 25 ++++++------------------- mu_map/training/lib.py | 24 ++++++++++++++++++++++++ 3 files changed, 36 insertions(+), 37 deletions(-) diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 989a219..bbec323 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -186,15 +186,12 @@ if __name__ == "__main__": import numpy as np from mu_map.dataset.patches import MuMapPatchDataset - from mu_map.dataset.normalization import ( - MeanNormTransform, - MaxNormTransform, - GaussianNormTransform, - ) + from mu_map.dataset.normalization import norm_choices, norm_by_str from mu_map.dataset.transform import PadCropTranform, SequenceTransform from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.models.unet import UNet from mu_map.models.discriminator import Discriminator, PatchDiscriminator + from mu_map.training.lib import init_random_seed parser = argparse.ArgumentParser( description="Train a UNet model to predict μ-maps from reconstructed images", @@ -220,7 +217,7 @@ if __name__ == "__main__": parser.add_argument( "--input_norm", type=str, - choices=["none", "mean", "max", "gaussian"], + choices=norm_choices, default="mean", help="type of normalization applied to the reconstructions", ) @@ -365,19 +362,10 @@ if __name__ == "__main__": device = torch.device(args.device) - args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + seed = init_random_seed(args.seed) logger.info(f"Seed: {args.seed}") - random.seed(args.seed) - torch.manual_seed(args.seed) - np.random.seed(args.seed) - - transform_normalization = None - if args.input_norm == "mean": - transform_normalization = MeanNormTransform() - elif args.input_norm == "max": - transform_normalization = MaxNormTransform() - elif args.input_norm == "gaussian": - transform_normalization = GaussianNormTransform() + + transform_normalization = norm_by_str(args.input_norm) transform_normalization = SequenceTransform( [transform_normalization, PadCropTranform(dim=3, size=32)] ) diff --git a/mu_map/training/distance.py b/mu_map/training/distance.py index 5e29323..df2dfb9 100644 --- a/mu_map/training/distance.py +++ b/mu_map/training/distance.py @@ -72,13 +72,10 @@ if __name__ == "__main__": import numpy as np from mu_map.dataset.patches import MuMapPatchDataset - from mu_map.dataset.normalization import ( - MeanNormTransform, - MaxNormTransform, - GaussianNormTransform, - ) + from mu_map.dataset.normalization import norm_choices, norm_by_str from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.models.unet import UNet + from mu_map.training.lib import init_random_seed parser = argparse.ArgumentParser( description="Train a UNet model to predict μ-maps from reconstructed scatter images", @@ -104,7 +101,7 @@ if __name__ == "__main__": parser.add_argument( "--input_norm", type=str, - choices=["none", "mean", "max", "gaussian"], + choices=norm_choices, default="mean", help="type of normalization applied to the reconstructions", ) @@ -231,20 +228,10 @@ if __name__ == "__main__": logger = get_logger_by_args(args) logger.info(args) - args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) - logger.info(f"Seed: {args.seed}") - random.seed(args.seed) - torch.manual_seed(args.seed) - np.random.seed(args.seed) - - transform_normalization = None - if args.input_norm == "mean": - transform_normalization = MeanNormTransform() - elif args.input_norm == "max": - transform_normalization = MaxNormTransform() - elif args.input_norm == "gaussian": - transform_normalization = GaussianNormTransform() + seed = init_random_seed(args.seed) + logger.info(f"Seed: {seed}") + transform_normalization = norm_by_str(args.input_norm) dataset = MuMapPatchDataset( args.dataset_dir, patches_per_image=args.number_of_patches, diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py index c7b5cc4..57bb847 100644 --- a/mu_map/training/lib.py +++ b/mu_map/training/lib.py @@ -4,6 +4,7 @@ Module functioning as a library for training related code. from dataclasses import dataclass from logging import Logger import os +import random from typing import Dict, List, Optional import sys @@ -15,6 +16,29 @@ from mu_map.dataset.default import MuMapDataset from mu_map.logging import get_logger +def init_random_seed(seed: Optional[int] = None) -> int: + """ + Set the seed for all RNGs (default python, numpy and torch). + + Parameters + ---------- + seed: int, optional + the seed to be used which is generated if not provided + + Returns + ------- + int + the randoms seed used + """ + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + return seed + + @dataclass class TrainingParams: """ -- GitLab