From 656911ccd3ee1a122e7912d1daab1f6d8be33027 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 13 Jan 2023 11:27:15 +0100
Subject: [PATCH] lots of updates to cGAN random search

---
 mu_map/training/random_search.py | 141 ++++++++++++++++++++++---------
 1 file changed, 103 insertions(+), 38 deletions(-)

diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py
index e2c2e2f..275b316 100644
--- a/mu_map/training/random_search.py
+++ b/mu_map/training/random_search.py
@@ -2,11 +2,12 @@
 Implementation of random search for hyper parameter optimization.
 """
 import json
+import logging
 import os
 import random
 import shutil
 import sys
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import numpy as np
 import pandas as pd
@@ -32,6 +33,7 @@ class ParamSampler:
     """
     Abstract class to sample a parameter.
     """
+
     def sample(self) -> Any:
         """
         Create a new value for a parameter.
@@ -43,6 +45,7 @@ class ChoiceSampler(ParamSampler):
     """
     Sample from a list of choices.
     """
+
     def __init__(self, values: List[Any]):
         """
         Create a new choice sampler.
@@ -64,6 +67,7 @@ class DependentChoiceSampler(ChoiceSampler):
     """
     A choice sampler that depends on other parameters.
     """
+
     def __init__(self, build_choices: Callable[[Any], List[Any]]):
         """
         Create a dependent choice sampler.
@@ -104,6 +108,7 @@ class FloatIntervalSampler(ParamSampler):
     """
     Sample a value from a float interval.
     """
+
     def __init__(self, min_val: float, max_val: float):
         """
         Create a new float interval sampler.
@@ -126,6 +131,7 @@ class IntIntervalSampler(ParamSampler):
     """
     Sample a value from an integer interval.
     """
+
     def __init__(self, min_val: int, max_val: int):
         """
         Create a new int interval sampler.
@@ -170,11 +176,12 @@ class RandomSearch:
     """
     Abstract implementation of a random search.
     """
+
-    def __init__(self, param_sampler=Dict[str, ParamSampler]):
+    def __init__(self, param_sampler: Dict[str, ParamSampler]):
         """
         Create a new random sampler.
 
-        :param param_sampler: a dict of name and parameter sampler pairs 
+        :param param_sampler: a dict of name and parameter sampler pairs
                               all of them are sampled for a single random search run
         """
         self.param_sampler = param_sampler
@@ -195,7 +202,9 @@ class RandomSearch:
             _params[key] = param
         return _params
 
-    def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str:
+    def serialize_params(
+        self, params: Dict[str, Any], filename: Optional[str] = None
+    ) -> str:
         """
         Serialize a set of parameters to JSON and dump them into a file.
 
@@ -227,24 +236,42 @@ def validate_and_make_directory(_dir: str):
         print(f"Directory {_dir} exists and is unexpectedly not empty!")
         exit(1)
 
+
+def init_from_dir(_dir: str, logger: logging.Logger) -> Tuple[int, float]:
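+    """
+    Initialize a random search from an existing directory.
+
+    Determine the index of the next run and the minimal NMAE achieved so
+    far. This assumes that previous runs are stored in directories named
+    by their integer run number and that the best run so far is linked
+    as "best".
+
+    :param _dir: the directory of an existing random search
+    :param logger: logger used to report that the search is continued
+    :return: a tuple of the next run index and the minimal NMAE so far
+    """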
+    last_run = os.listdir(_dir)
+    # filter non-directories
+    last_run = filter(lambda f: os.path.isdir(os.path.join(_dir, f)), last_run)
+    # filter symlinks
+    last_run = filter(lambda f: not os.path.islink(os.path.join(_dir, f)), last_run)
+    last_run = list(map(int, last_run))
+
+    if len(last_run) == 0:
+        return 1, sys.maxsize
+
+    last_run = max(last_run)
+
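+    # the minimal NMAE so far is the mean NMAE over the measures of the best run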
+    min_loss = pd.read_csv(os.path.join(_dir, "best", "measures.csv"))
+    min_loss = min_loss["NMAE"].mean()
+    logger.info(f"Continue existing random search from run {last_run} and minimal NMAE {min_loss}")
+    return last_run + 1, min_loss
+
 
 class RandomSearchCGAN(RandomSearch):
     """
     Implementation of a random search for cGAN training.
     """
+
     def __init__(self, iterations: int, logger=None):
         super().__init__({})
 
         self.dataset_dir = "data/second"
         self.iterations = iterations
-        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        self.n_slices = 32
-        self.params = {}
-
+        self.device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
         self.dir = "cgan_random_search"
-        validate_and_make_directory(self.dir)
+        if not os.path.exists(self.dir):
+            os.mkdir(self.dir)
 
-        self.dir_train = os.path.join(self.dir, "train_data")
         self.logger = (
             logger
             if logger is not None
@@ -254,36 +281,57 @@ class RandomSearchCGAN(RandomSearch):
                 name=RandomSearchCGAN.__name__,
             )
         )
+        self.n_slices = 32
+        self.params = {}
+
+        self.dir_train = os.path.join(self.dir, "train_data")
+        self.start, self.nmae_min = init_from_dir(self.dir, self.logger)
         self.training: cGANTraining = None
 
+        def patch_number(patch_size: int, **kwargs):
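+            # sample more patches per image for the smaller patch size so
+            # that the amount of training data stays roughly comparable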
+            if patch_size == 32:
+                return list(range(50, 101))
+            else:
+                return list(range(25, 51))
+
         # dataset parameters
         self.param_sampler["patch_size"] = ChoiceSampler([32, 64])
         self.param_sampler["patch_offset"] = ChoiceSampler([0])
-        self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100)
+        self.param_sampler["patch_number"] = DependentChoiceSampler(patch_number)
         self.param_sampler["scatter_correction"] = ChoiceSampler([False])
         self.param_sampler["shuffle"] = ChoiceSampler([False, True])
         self.param_sampler["normalization"] = ChoiceSampler(
             [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
         )
-        self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)])
+        self.param_sampler["pad_crop"] = ChoiceSampler(
+            [None, PadCropTranform(dim=3, size=self.n_slices)]
+        )
 
         # model parameters
         self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"])
+
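+        # the sensible conv feature configurations depend on the discriminator type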
         def discriminator_conv_features(discriminator_type: str, **kwargs):
             if discriminator_type == "class":
                 return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]]
             else:
                 return [[32, 64, 128, 256], [64, 128, 256, 512]]
-        self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features)
-        self.param_sampler["generator_features"] = ChoiceSampler([
-            [64, 128, 256, 512],
-            [32, 64, 128, 256, 512],
-        ])
+
+        self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(
+            discriminator_conv_features
+        )
+        self.param_sampler["generator_features"] = ChoiceSampler(
+            [
+                [64, 128, 256, 512],
+                [32, 64, 128, 256, 512],
+            ]
+        )
 
         # training parameters
-        self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90])
+        self.param_sampler["epochs"] = ChoiceSampler([100])
+
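+        # use a smaller batch size for the larger patch size, presumably to
+        # keep GPU memory usage in check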
         def batch_size(patch_size: int, **kwargs):
             return [32] if patch_size == 64 else [64]
+
         # self.param_sampler["batch_size"] = ChoiceSampler([32, 64])
         self.param_sampler["batch_size"] = DependentChoiceSampler(batch_size)
         self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001)
@@ -293,12 +341,20 @@ class RandomSearchCGAN(RandomSearch):
         self.param_sampler["criterion_dist"] = ChoiceSampler(
             [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
         )
-        self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
-        self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])
+        self.param_sampler["weight_criterions"] = ChoiceSampler(
+            [
+                (100.0, 1.0),
+                (20.0, 1.0),
+                (5.0, 1.0),
+                (1.0, 1.0),
+                (1.0, 5.0),
+                (1.0, 20.0),
+                (1.0, 100.0),
+            ]
+        )
 
     def run(self):
-        nmae_min = sys.maxsize
-        for i in range(1, self.iterations + 1):
+        for i in range(self.start, self.start + self.iterations):
             self.logger.info(f"Train iteration {i}")
 
             seed = random.randint(0, 2**32 - 1)
@@ -312,11 +368,11 @@ class RandomSearchCGAN(RandomSearch):
 
             nmae = self.eval_run()
             self.logger.info(f"Iteration {i} has NMAE {nmae:.6f}")
-            if nmae < nmae_min:
+            if nmae < self.nmae_min:
                 self.logger.info(f"New best run at iteration {i}")
-                nmae_min = nmae
+                self.nmae_min = nmae
-            self._cleanup_run(i, link_best=(nmae_min == nmae))
+            self._cleanup_run(i, link_best=(self.nmae_min == nmae))
-        return nmae_min
+        return self.nmae_min
 
     def eval_run(self):
         self.logger.debug("Perform evaluation ...")
@@ -362,25 +418,33 @@ class RandomSearchCGAN(RandomSearch):
 
         self.logger.debug(f"Init dataset ...")
         dataset = MuMapPatchDataset(
-                self.dataset_dir,
-                patches_per_image=self.params["patch_number"],
-                patch_size=self.params["patch_size"],
-                patch_size_z=self.n_slices,
-                patch_offset=self.params["patch_offset"],
-                shuffle=self.params["shuffle"],
-                transform_normalization=transform_normalization,
-                scatter_correction=self.params["scatter_correction"],
-                logger=logger,
+            self.dataset_dir,
+            patches_per_image=self.params["patch_number"],
+            patch_size=self.params["patch_size"],
+            patch_size_z=self.n_slices,
+            patch_offset=self.params["patch_offset"],
+            shuffle=self.params["shuffle"],
+            transform_normalization=transform_normalization,
+            scatter_correction=self.params["scatter_correction"],
+            logger=logger,
         )
 
         self.logger.debug(f"Init discriminator ....")
-        input_size = (self.n_slices, self.params["patch_size"], self.params["patch_size"])
+        input_size = (
+            self.n_slices,
+            self.params["patch_size"],
+            self.params["patch_size"],
+        )
         if self.params["discriminator_type"] == "class":
             discriminator = Discriminator(
-                in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"],
+                in_channels=2,
+                input_size=input_size,
+                conv_features=self.params["discriminator_conv_features"],
             )
         else:
-            discriminator = PatchDiscriminator(in_channels=2, features=self.params["discriminator_conv_features"])
+            discriminator = PatchDiscriminator(
+                in_channels=2, features=self.params["discriminator_conv_features"]
+            )
         discriminator = discriminator.to(self.device)
 
         optimizer = torch.optim.Adam(
@@ -418,6 +482,7 @@ class RandomSearchCGAN(RandomSearch):
             model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
         )
 
+        weight_crit_dist, weight_crit_adv = self.params["weight_criterions"]
         self.logger.debug(f"Init training ....")
         self.training = cGANTraining(
             epochs=self.params["epochs"],
@@ -430,8 +495,8 @@ class RandomSearchCGAN(RandomSearch):
             params_generator=params_g,
             params_discriminator=params_d,
             loss_func_dist=self.params["criterion_dist"],
-            weight_criterion_dist=self.params["weight_criterion_dist"],
-            weight_criterion_adv=self.params["weight_criterion_adv"],
+            weight_criterion_dist=weight_crit_dist,
+            weight_criterion_adv=weight_crit_adv,
             early_stopping=10,
         )
 
-- 
GitLab