From b7fad94bfb06a03ff3daecb88f66d2944ca896ac Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 6 Jan 2023 12:05:35 +0100
Subject: [PATCH] change parameters in random search

---
 mu_map/training/random_search.py | 48 ++++++++++++++++++++++----------
 1 file changed, 33 insertions(+), 15 deletions(-)

diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py
index fcba379..0b08c60 100644
--- a/mu_map/training/random_search.py
+++ b/mu_map/training/random_search.py
@@ -21,9 +21,9 @@ from mu_map.dataset.normalization import (
 )
 from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
 from mu_map.eval.measures import nmae, mse
-from mu_map.models.discriminator import Discriminator
+from mu_map.models.discriminator import Discriminator, PatchDiscriminator
 from mu_map.models.unet import UNet
-from mu_map.training.cgan2 import cGANTraining, DiscriminatorParams, GeneratorParams
+from mu_map.training.cgan import cGANTraining, DiscriminatorParams, GeneratorParams
 from mu_map.training.loss import WeightedLoss
 from mu_map.logging import get_logger
 
@@ -216,7 +216,7 @@ class RandomSearch:
 
 def validate_and_make_directory(_dir: str):
     """
-    Uility method to validate that a directory exists and is empty.
+    Utility method to validate that a directory exists and is empty.
     If is does not exist, it is created.
     """
     if not os.path.exists(_dir):
@@ -237,7 +237,7 @@ class RandomSearchCGAN(RandomSearch):
 
         self.dataset_dir = "data/second"
         self.iterations = iterations
-        self.device = torch.device("cuda")
+        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         self.n_slices = 32
         self.params = {}
 
@@ -256,10 +256,10 @@ class RandomSearchCGAN(RandomSearch):
         )
         self.training: cGANTraining = None
 
-        # dataset params
+        # dataset parameters
         self.param_sampler["patch_size"] = ChoiceSampler([32, 64])
         self.param_sampler["patch_offset"] = ChoiceSampler([0])
-        self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=200)
+        self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100)
         self.param_sampler["scatter_correction"] = ChoiceSampler([False])
         self.param_sampler["shuffle"] = ChoiceSampler([False, True])
         self.param_sampler["normalization"] = ChoiceSampler(
@@ -267,10 +267,24 @@ class RandomSearchCGAN(RandomSearch):
         )
         self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)])
 
-        # training params
-        self.param_sampler["epochs"] = IntIntervalSampler(min_val=50, max_val=200)
-        self.param_sampler["batch_size"] = ChoiceSampler([64])
-        self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
+        # model parameters
+        self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"])
+        def discriminator_conv_features(discriminator_type: str, **kwargs):
+            if discriminator_type == "class":
+                return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]]
+            else:
+                return [[32, 64, 128, 256], [64, 128, 256, 512]]
+        self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features)
+        self.param_sampler["generator_features"] = ChoiceSampler([
+            [128, 256, 512],
+            [64, 128, 256, 512],
+            [32, 64, 128, 256, 512],
+        ])
+
+        # training parameters
+        self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90])
+        self.param_sampler["batch_size"] = ChoiceSampler([32, 64])
+        self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001)
         self.param_sampler["lr_decay"] = ChoiceSampler([False, True])
         self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
         self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
@@ -379,10 +393,15 @@ class RandomSearchCGAN(RandomSearch):
         )
 
         self.logger.debug(f"Init discriminator ....")
-        discriminator = Discriminator(
-            in_channels=2, input_size=(self.n_slices, self.params["patch_size"], self.params["patch_size"])
-        )
+        input_size = (2, self.n_slices, self.params["patch_size"], self.params["patch_size"])
+        if self.params["discriminator_type"] == "class":
+            discriminator = Discriminator(
+                in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"],
+            )
+        else:
+            discriminator = PatchDiscriminator(in_channels=2, features=self.params["discriminator_conv_features"])
         discriminator = discriminator.to(self.device)
+
         optimizer = torch.optim.Adam(
             discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
         )
@@ -400,8 +419,7 @@ class RandomSearchCGAN(RandomSearch):
         )
 
         self.logger.debug(f"Init generator ....")
-        features = [64, 128, 256, 512]
-        generator = UNet(in_channels=1, features=features)
+        generator = UNet(in_channels=1, features=self.params["generator_features"])
         generator = generator.to(self.device)
         optimizer = torch.optim.Adam(
             generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
-- 
GitLab