From ac2a7ecc64bc365ca9cf4390dfb5b6a9060e0431 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 6 Jan 2023 12:15:01 +0100
Subject: [PATCH] fixes to random search

---
 mu_map/training/random_search.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py
index 0b08c60..aa39d4d 100644
--- a/mu_map/training/random_search.py
+++ b/mu_map/training/random_search.py
@@ -276,14 +276,16 @@ class RandomSearchCGAN(RandomSearch):
                 return [[32, 64, 128, 256], [64, 128, 256, 512]]
         self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features)
         self.param_sampler["generator_features"] = ChoiceSampler([
-            [128, 256, 512],
             [64, 128, 256, 512],
             [32, 64, 128, 256, 512],
         ])
 
         # training parameters
         self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90])
-        self.param_sampler["batch_size"] = ChoiceSampler([32, 64])
+        def batch_size(patch_size: int, **kwargs):
+            return [32] if patch_size == 64 else [64]
+        # self.param_sampler["batch_size"] = ChoiceSampler([32, 64])
+        self.param_sampler["batch_size"] = DependentChoiceSampler(batch_size)
         self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001)
         self.param_sampler["lr_decay"] = ChoiceSampler([False, True])
         self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
@@ -393,7 +395,7 @@ class RandomSearchCGAN(RandomSearch):
         )
 
         self.logger.debug(f"Init discriminator ....")
-        input_size = (2, self.n_slices, self.params["patch_size"], self.params["patch_size"])
+        input_size = (self.n_slices, self.params["patch_size"], self.params["patch_size"])
         if self.params["discriminator_type"] == "class":
             discriminator = Discriminator(
                 in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"],
-- 
GitLab