diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 989a219c5fb390aae105b6c0a94e7b21f649daf6..bbec3236882e30797f5269d32092e24b3e2dfdec 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -186,15 +186,12 @@ if __name__ == "__main__": import numpy as np from mu_map.dataset.patches import MuMapPatchDataset - from mu_map.dataset.normalization import ( - MeanNormTransform, - MaxNormTransform, - GaussianNormTransform, - ) + from mu_map.dataset.normalization import norm_choices, norm_by_str from mu_map.dataset.transform import PadCropTranform, SequenceTransform from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.models.unet import UNet from mu_map.models.discriminator import Discriminator, PatchDiscriminator + from mu_map.training.lib import init_random_seed parser = argparse.ArgumentParser( description="Train a UNet model to predict μ-maps from reconstructed images", @@ -220,7 +217,7 @@ if __name__ == "__main__": parser.add_argument( "--input_norm", type=str, - choices=["none", "mean", "max", "gaussian"], + choices=norm_choices, default="mean", help="type of normalization applied to the reconstructions", ) @@ -365,19 +362,10 @@ if __name__ == "__main__": device = torch.device(args.device) - args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) + seed = init_random_seed(args.seed) logger.info(f"Seed: {args.seed}") - random.seed(args.seed) - torch.manual_seed(args.seed) - np.random.seed(args.seed) - - transform_normalization = None - if args.input_norm == "mean": - transform_normalization = MeanNormTransform() - elif args.input_norm == "max": - transform_normalization = MaxNormTransform() - elif args.input_norm == "gaussian": - transform_normalization = GaussianNormTransform() + + transform_normalization = norm_by_str(args.input_norm) transform_normalization = SequenceTransform( [transform_normalization, PadCropTranform(dim=3, size=32)] ) diff --git a/mu_map/training/distance.py b/mu_map/training/distance.py index 5e29323665a93b07e5ce6a274346f225f81ee4b4..df2dfb95d9f5dff506e23673b3dcfea6388c7f84 100644 --- a/mu_map/training/distance.py +++ b/mu_map/training/distance.py @@ -72,13 +72,10 @@ if __name__ == "__main__": import numpy as np from mu_map.dataset.patches import MuMapPatchDataset - from mu_map.dataset.normalization import ( - MeanNormTransform, - MaxNormTransform, - GaussianNormTransform, - ) + from mu_map.dataset.normalization import norm_choices, norm_by_str from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.models.unet import UNet + from mu_map.training.lib import init_random_seed parser = argparse.ArgumentParser( description="Train a UNet model to predict μ-maps from reconstructed scatter images", @@ -104,7 +101,7 @@ if __name__ == "__main__": parser.add_argument( "--input_norm", type=str, - choices=["none", "mean", "max", "gaussian"], + choices=norm_choices, default="mean", help="type of normalization applied to the reconstructions", ) @@ -231,20 +228,10 @@ if __name__ == "__main__": logger = get_logger_by_args(args) logger.info(args) - args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) - logger.info(f"Seed: {args.seed}") - random.seed(args.seed) - torch.manual_seed(args.seed) - np.random.seed(args.seed) - - transform_normalization = None - if args.input_norm == "mean": - transform_normalization = MeanNormTransform() - elif args.input_norm == "max": - transform_normalization = MaxNormTransform() - elif args.input_norm == "gaussian": - transform_normalization = GaussianNormTransform() + seed = init_random_seed(args.seed) + logger.info(f"Seed: {seed}") + transform_normalization = norm_by_str(args.input_norm) dataset = MuMapPatchDataset( args.dataset_dir, patches_per_image=args.number_of_patches, diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py index c7b5cc446106754f9770f093de45311da82f062f..57bb8473a34ba64c2862534bb0dab9bad261a31e 100644 --- a/mu_map/training/lib.py +++ b/mu_map/training/lib.py @@ -4,6 +4,7 @@ Module functioning as a library for training related code. from dataclasses import dataclass from logging import Logger import os +import random from typing import Dict, List, Optional import sys @@ -15,6 +16,29 @@ from mu_map.dataset.default import MuMapDataset from mu_map.logging import get_logger +def init_random_seed(seed: Optional[int] = None) -> int: + """ + Set the seed for all RNGs (default python, numpy and torch). + + Parameters + ---------- + seed: int, optional + the seed to be used which is generated if not provided + + Returns + ------- + int + the randoms seed used + """ + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + return seed + + @dataclass class TrainingParams: """