diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py index 9274ddb401467d266eb95043285f5599c91563d3..5f87abadd00bf7237ddcc4b3072b2f7566ff8eab 100644 --- a/mu_map/training/random_search.py +++ b/mu_map/training/random_search.py @@ -8,8 +8,9 @@ import shutil import sys from typing import Any, Callable, Dict, List, Optional -import torch +import numpy as np import pandas as pd +import torch from mu_map.dataset.default import MuMapDataset from mu_map.dataset.patches import MuMapPatchDataset @@ -238,8 +239,7 @@ class RandomSearchCGAN(RandomSearch): self.iterations = iterations self.dir = "cgan_random_search" validate_and_make_directory(self.dir) - # self.device = torch.device("cuda") - self.device = torch.device("cpu") + self.device = torch.device("cuda") self.params = {} self.dir_train = os.path.join(self.dir, "train_data") @@ -289,6 +289,8 @@ class RandomSearchCGAN(RandomSearch): seed = random.randint(0, 2**32 - 1) random.seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) self.logger.info(f"Random seed for iteration {i} is {seed}") self._setup_run(i) @@ -450,5 +452,5 @@ class RandomSearchCGAN(RandomSearch): if __name__ == "__main__": - random_search = RandomSearchCGAN(iterations=4) + random_search = RandomSearchCGAN(iterations=10) random_search.run()