Skip to content
Snippets Groups Projects
Commit 1e16732e authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

introduce function to init random seeds in training lib

parent 93f4da3e
No related branches found
No related tags found
No related merge requests found
...@@ -186,15 +186,12 @@ if __name__ == "__main__": ...@@ -186,15 +186,12 @@ if __name__ == "__main__":
import numpy as np import numpy as np
from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import ( from mu_map.dataset.normalization import norm_choices, norm_by_str
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, SequenceTransform from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.models.unet import UNet from mu_map.models.unet import UNet
from mu_map.models.discriminator import Discriminator, PatchDiscriminator from mu_map.models.discriminator import Discriminator, PatchDiscriminator
from mu_map.training.lib import init_random_seed
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Train a UNet model to predict μ-maps from reconstructed images", description="Train a UNet model to predict μ-maps from reconstructed images",
...@@ -220,7 +217,7 @@ if __name__ == "__main__": ...@@ -220,7 +217,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--input_norm", "--input_norm",
type=str, type=str,
choices=["none", "mean", "max", "gaussian"], choices=norm_choices,
default="mean", default="mean",
help="type of normalization applied to the reconstructions", help="type of normalization applied to the reconstructions",
) )
...@@ -365,19 +362,10 @@ if __name__ == "__main__": ...@@ -365,19 +362,10 @@ if __name__ == "__main__":
device = torch.device(args.device) device = torch.device(args.device)
args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) seed = init_random_seed(args.seed)
logger.info(f"Seed: {args.seed}") logger.info(f"Seed: {args.seed}")
random.seed(args.seed)
torch.manual_seed(args.seed) transform_normalization = norm_by_str(args.input_norm)
np.random.seed(args.seed)
transform_normalization = None
if args.input_norm == "mean":
transform_normalization = MeanNormTransform()
elif args.input_norm == "max":
transform_normalization = MaxNormTransform()
elif args.input_norm == "gaussian":
transform_normalization = GaussianNormTransform()
transform_normalization = SequenceTransform( transform_normalization = SequenceTransform(
[transform_normalization, PadCropTranform(dim=3, size=32)] [transform_normalization, PadCropTranform(dim=3, size=32)]
) )
......
...@@ -72,13 +72,10 @@ if __name__ == "__main__": ...@@ -72,13 +72,10 @@ if __name__ == "__main__":
import numpy as np import numpy as np
from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import ( from mu_map.dataset.normalization import norm_choices, norm_by_str
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.models.unet import UNet from mu_map.models.unet import UNet
from mu_map.training.lib import init_random_seed
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Train a UNet model to predict μ-maps from reconstructed scatter images", description="Train a UNet model to predict μ-maps from reconstructed scatter images",
...@@ -104,7 +101,7 @@ if __name__ == "__main__": ...@@ -104,7 +101,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--input_norm", "--input_norm",
type=str, type=str,
choices=["none", "mean", "max", "gaussian"], choices=norm_choices,
default="mean", default="mean",
help="type of normalization applied to the reconstructions", help="type of normalization applied to the reconstructions",
) )
...@@ -231,20 +228,10 @@ if __name__ == "__main__": ...@@ -231,20 +228,10 @@ if __name__ == "__main__":
logger = get_logger_by_args(args) logger = get_logger_by_args(args)
logger.info(args) logger.info(args)
args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) seed = init_random_seed(args.seed)
logger.info(f"Seed: {args.seed}") logger.info(f"Seed: {seed}")
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
transform_normalization = None
if args.input_norm == "mean":
transform_normalization = MeanNormTransform()
elif args.input_norm == "max":
transform_normalization = MaxNormTransform()
elif args.input_norm == "gaussian":
transform_normalization = GaussianNormTransform()
transform_normalization = norm_by_str(args.input_norm)
dataset = MuMapPatchDataset( dataset = MuMapPatchDataset(
args.dataset_dir, args.dataset_dir,
patches_per_image=args.number_of_patches, patches_per_image=args.number_of_patches,
......
...@@ -4,6 +4,7 @@ Module functioning as a library for training related code. ...@@ -4,6 +4,7 @@ Module functioning as a library for training related code.
from dataclasses import dataclass from dataclasses import dataclass
from logging import Logger from logging import Logger
import os import os
import random
from typing import Dict, List, Optional from typing import Dict, List, Optional
import sys import sys
...@@ -15,6 +16,29 @@ from mu_map.dataset.default import MuMapDataset ...@@ -15,6 +16,29 @@ from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger from mu_map.logging import get_logger
def init_random_seed(seed: Optional[int] = None) -> int:
"""
Set the seed for all RNGs (default python, numpy and torch).
Parameters
----------
seed: int, optional
the seed to be used which is generated if not provided
Returns
-------
int
the randoms seed used
"""
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
return seed
@dataclass @dataclass
class TrainingParams: class TrainingParams:
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment