From bf09641ffcaacc5f7c8a0ffa4bb202c5e21fca8d Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 5 Jan 2023 10:35:56 +0100
Subject: [PATCH] update cgan random search

---
 mu_map/training/random_search.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py
index 5f87aba..fcba379 100644
--- a/mu_map/training/random_search.py
+++ b/mu_map/training/random_search.py
@@ -237,11 +237,13 @@ class RandomSearchCGAN(RandomSearch):
 
         self.dataset_dir = "data/second"
         self.iterations = iterations
-        self.dir = "cgan_random_search"
-        validate_and_make_directory(self.dir)
         self.device = torch.device("cuda")
+        self.n_slices = 32
         self.params = {}
 
+        self.dir = "cgan_random_search"
+        validate_and_make_directory(self.dir)
+
         self.dir_train = os.path.join(self.dir, "train_data")
         self.logger = (
             logger
@@ -258,28 +260,24 @@ class RandomSearchCGAN(RandomSearch):
         self.param_sampler["patch_size"] = ChoiceSampler([32, 64])
         self.param_sampler["patch_offset"] = ChoiceSampler([0])
         self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=200)
-        # self.param_sampler["scatter_correction"] = ChoiceSampler([True, False])
         self.param_sampler["scatter_correction"] = ChoiceSampler([False])
         self.param_sampler["shuffle"] = ChoiceSampler([False, True])
         self.param_sampler["normalization"] = ChoiceSampler(
             [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
         )
-        self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=32)])
+        self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)])
 
         # training params
         self.param_sampler["epochs"] = IntIntervalSampler(min_val=50, max_val=200)
         self.param_sampler["batch_size"] = ChoiceSampler([64])
         self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
-        # self.param_sampler["lr"] = ChoiceSampler([0.001])
         self.param_sampler["lr_decay"] = ChoiceSampler([False, True])
         self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
         self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
         self.param_sampler["criterion_dist"] = ChoiceSampler(
             [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
         )
-        # self.param_sampler["weight_criterion_dist"] = FloatIntervalSampler(1.0, 100.0)
         self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
-        # self.param_sampler["weight_criterion_adv"] = FloatIntervalSampler(1.0, 100.0)
         self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])
 
     def run(self):
@@ -305,17 +303,16 @@ class RandomSearchCGAN(RandomSearch):
         return nmae_min
 
     def eval_run(self):
-        return random.randint(0, 200)
         self.logger.debug("Perform evaluation ...")
         torch.set_grad_enabled(False)
 
         weights_file = os.path.join(self.training.snapshot_dir, "val_min_generator.pth")
         self.logger.debug(f"Load weights from {weights_file}")
-        model = self.training.params_g.model.eval()
+        model = self.training.generator.model.eval()
         model.load_state_dict(torch.load(weights_file, map_location=self.device))
 
         transform_normalization = SequenceTransform(
-            [self.params["normalization"], PadCropTranform(dim=3, size=32)]
+            [self.params["normalization"], PadCropTranform(dim=3, size=self.n_slices)]
         )
         dataset = MuMapDataset(
             self.dataset_dir,
@@ -373,6 +370,7 @@ class RandomSearchCGAN(RandomSearch):
                 self.dataset_dir,
                 patches_per_image=self.params["patch_number"],
                 patch_size=self.params["patch_size"],
+                patch_size_z=self.n_slices,
                 patch_offset=self.params["patch_offset"],
                 shuffle=self.params["shuffle"],
                 transform_normalization=transform_normalization,
@@ -382,7 +380,7 @@ class RandomSearchCGAN(RandomSearch):
 
         self.logger.debug(f"Init discriminator ....")
         discriminator = Discriminator(
-            in_channels=2, input_size=self.params["patch_size"]
+            in_channels=2, input_size=(self.n_slices, self.params["patch_size"], self.params["patch_size"])
         )
         discriminator = discriminator.to(self.device)
         optimizer = torch.optim.Adam(
-- 
GitLab