From ac2a7ecc64bc365ca9cf4390dfb5b6a9060e0431 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Fri, 6 Jan 2023 12:15:01 +0100 Subject: [PATCH] fixes to random search --- mu_map/training/random_search.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py index 0b08c60..aa39d4d 100644 --- a/mu_map/training/random_search.py +++ b/mu_map/training/random_search.py @@ -276,14 +276,16 @@ class RandomSearchCGAN(RandomSearch): return [[32, 64, 128, 256], [64, 128, 256, 512]] self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features) self.param_sampler["generator_features"] = ChoiceSampler([ - [128, 256, 512], [64, 128, 256, 512], [32, 64, 128, 256, 512], ]) # training parameters self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90]) - self.param_sampler["batch_size"] = ChoiceSampler([32, 64]) + def batch_size(patch_size: int, **kwargs): + return [32] if patch_size == 64 else [64] + # self.param_sampler["batch_size"] = ChoiceSampler([32, 64]) + self.param_sampler["batch_size"] = DependentChoiceSampler(batch_size) self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001) self.param_sampler["lr_decay"] = ChoiceSampler([False, True]) self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1]) @@ -393,7 +395,7 @@ class RandomSearchCGAN(RandomSearch): ) self.logger.debug(f"Init discriminator ....") - input_size = (2, self.n_slices, self.params["patch_size"], self.params["patch_size"]) + input_size = (self.n_slices, self.params["patch_size"], self.params["patch_size"]) if self.params["discriminator_type"] == "class": discriminator = Discriminator( in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"], -- GitLab