diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py index 9274ddb401467d266eb95043285f5599c91563d3..f2afd1046ebb316fce2c06093402b489e7fb2dca 100644 --- a/mu_map/training/random_search.py +++ b/mu_map/training/random_search.py @@ -8,8 +8,9 @@ import shutil import sys from typing import Any, Callable, Dict, List, Optional -import torch +import numpy as np import pandas as pd +import torch from mu_map.dataset.default import MuMapDataset from mu_map.dataset.patches import MuMapPatchDataset @@ -289,6 +290,8 @@ class RandomSearchCGAN(RandomSearch): seed = random.randint(0, 2**32 - 1) random.seed(seed) + torch.manual_seed(args.seed) + np.random.seed(args.seed) self.logger.info(f"Random seed for iteration {i} is {seed}") self._setup_run(i)