From 656911ccd3ee1a122e7912d1daab1f6d8be33027 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 13 Jan 2023 11:27:15 +0100
Subject: [PATCH] lots of updates to cgan random search

---
 mu_map/training/random_search.py | 141 ++++++++++++++++++++++---------
 1 file changed, 103 insertions(+), 38 deletions(-)

diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py
index e2c2e2f..275b316 100644
--- a/mu_map/training/random_search.py
+++ b/mu_map/training/random_search.py
@@ -2,11 +2,12 @@
 Implementation of random search for hyperparameter optimization.
 """
 import json
+import logging
 import os
 import random
 import shutil
 import sys
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import numpy as np
 import pandas as pd
@@ -32,6 +33,7 @@ class ParamSampler:
     """
     Abstract class to sample a parameter.
     """
+
     def sample(self) -> Any:
         """
         Create a new value for a parameter.
@@ -43,6 +45,7 @@ class ChoiceSampler(ParamSampler):
     """
     Sample from a list of choices.
     """
+
     def __init__(self, values: List[Any]):
         """
         Create a new choice sampler.
@@ -64,6 +67,7 @@ class DependentChoiceSampler(ChoiceSampler):
     """
     A choice sampler that depends on other parameters.
     """
+
     def __init__(self, build_choices: Callable[[Any], List[Any]]):
         """
         Create a dependent choice sampler.
@@ -104,6 +108,7 @@ class FloatIntervalSampler(ParamSampler):
     """
     Sample a value from a float interval.
     """
+
     def __init__(self, min_val: float, max_val: float):
         """
         Create a new float interval sampler.
@@ -126,6 +131,7 @@ class IntIntervalSampler(ParamSampler):
     """
     Sample a value from an integer interval.
     """
+
     def __init__(self, min_val: int, max_val: int):
         """
         Create a new int interval sampler.
@@ -170,11 +176,12 @@ class RandomSearch:
     """
     Abstract implementation of a random search.
     """
+
     def __init__(self, param_sampler=Dict[str, ParamSampler]):
         """
         Create a new random sampler.

-        :param param_sampler: a dict of name and parameter sampler pairs 
+        :param param_sampler: a dict of name and parameter sampler pairs
            all of them are sampled for a single random search run
         """
         self.param_sampler = param_sampler
@@ -195,7 +202,9 @@ class RandomSearch:
             _params[key] = param
         return _params

-    def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str:
+    def serialize_params(
+        self, params: Dict[str, Any], filename: Optional[str] = None
+    ) -> str:
         """
         Serialize a set of parameters to json and dump them into a file.

@@ -227,24 +236,42 @@ def validate_and_make_directory(_dir: str):
         print(f"Directory {_dir} exists and is unexpectedly not empty!")
         exit(1)


+def init_from_dir(_dir: str, logger: logging.Logger) -> Tuple[int, float]:
+    last_run = os.listdir(_dir)
+    # filter non-directories
+    last_run = filter(lambda f: os.path.isdir(os.path.join(_dir, f)), last_run)
+    # filter symlinks
+    last_run = filter(lambda f: not os.path.islink(os.path.join(_dir, f)), last_run)
+    last_run = list(map(int, last_run))
+
+    if len(last_run) == 0:
+        return 1, sys.maxsize
+
+    last_run = max(last_run)
+
+    min_loss = pd.read_csv(os.path.join(_dir, "best", "measures.csv"))
+    min_loss = min_loss["NMAE"].mean()
+    logger.info(f"Continue existing random search from run {last_run} and minimal NMAE {min_loss}")
+    return last_run + 1, min_loss
+
 class RandomSearchCGAN(RandomSearch):
     """
     Implementation of a random search for cGAN training.
""" + def __init__(self, iterations: int, logger=None): super().__init__({}) self.dataset_dir = "data/second" self.iterations = iterations - self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") - self.n_slices = 32 - self.params = {} - + self.device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) self.dir = "cgan_random_search" - validate_and_make_directory(self.dir) + if not os.path.exists(self.dir): + os.mkdir(self.dir) - self.dir_train = os.path.join(self.dir, "train_data") self.logger = ( logger if logger is not None @@ -254,36 +281,57 @@ class RandomSearchCGAN(RandomSearch): name=RandomSearchCGAN.__name__, ) ) + self.n_slices = 32 + self.params = {} + + self.dir_train = os.path.join(self.dir, "train_data") + self.start, self.nmae_min = init_from_dir(self.dir, self.logger) self.training: cGANTraining = None + def patch_number(patch_size: int, **kwargs): + if patch_size == 32: + return list(range(50, 101)) + else: + return list(range(25, 51)) + # dataset parameters self.param_sampler["patch_size"] = ChoiceSampler([32, 64]) self.param_sampler["patch_offset"] = ChoiceSampler([0]) - self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100) + self.param_sampler["patch_number"] = DependentChoiceSampler(patch_number) self.param_sampler["scatter_correction"] = ChoiceSampler([False]) self.param_sampler["shuffle"] = ChoiceSampler([False, True]) self.param_sampler["normalization"] = ChoiceSampler( [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()] ) - self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)]) + self.param_sampler["pad_crop"] = ChoiceSampler( + [None, PadCropTranform(dim=3, size=self.n_slices)] + ) # model parameters self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"]) + def discriminator_conv_features(discriminator_type: str, **kwargs): if discriminator_type == "class": return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]] else: return [[32, 64, 128, 256], [64, 128, 256, 512]] - self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features) - self.param_sampler["generator_features"] = ChoiceSampler([ - [64, 128, 256, 512], - [32, 64, 128, 256, 512], - ]) + + self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler( + discriminator_conv_features + ) + self.param_sampler["generator_features"] = ChoiceSampler( + [ + [64, 128, 256, 512], + [32, 64, 128, 256, 512], + ] + ) # training parameters - self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90]) + self.param_sampler["epochs"] = ChoiceSampler([100]) + def batch_size(patch_size: int, **kwargs): return [32] if patch_size == 64 else [64] + # self.param_sampler["batch_size"] = ChoiceSampler([32, 64]) self.param_sampler["batch_size"] = DependentChoiceSampler(batch_size) self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001) @@ -293,12 +341,20 @@ class RandomSearchCGAN(RandomSearch): self.param_sampler["criterion_dist"] = ChoiceSampler( [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")] ) - self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0]) - self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0]) + self.param_sampler["weight_criterions"] = ChoiceSampler( + [ + (100.0, 1.0), + (20.0, 1.0), + (5.0, 1.0), + (1.0, 1.0), + (1.0, 5.0), + (1.0, 20.0), + (1.0, 100.0), + ] + ) def run(self): - nmae_min = sys.maxsize - for i in 
-        for i in range(1, self.iterations + 1):
+        for i in range(self.start, self.start + self.iterations):
             self.logger.info(f"Train iteration {i}")

             seed = random.randint(0, 2**32 - 1)
@@ -312,11 +368,11 @@

             nmae = self.eval_run()
             self.logger.info(f"Iteration {i} has NMAE {nmae:.6f}")
-            if nmae < nmae_min:
+            if nmae < self.nmae_min:
                 self.logger.info(f"New best run at iteration {i}")
-                nmae_min = nmae
-            self._cleanup_run(i, link_best=(nmae_min == nmae))
-        return nmae_min
+                self.nmae_min = nmae
+            self._cleanup_run(i, link_best=(self.nmae_min == nmae))
+        return self.nmae_min

     def eval_run(self):
         self.logger.debug("Perform evaluation ...")
@@ -362,25 +418,33 @@

         self.logger.debug(f"Init dataset ...")
         dataset = MuMapPatchDataset(
-          self.dataset_dir,
-          patches_per_image=self.params["patch_number"],
-          patch_size=self.params["patch_size"],
-          patch_size_z=self.n_slices,
-          patch_offset=self.params["patch_offset"],
-          shuffle=self.params["shuffle"],
-          transform_normalization=transform_normalization,
-          scatter_correction=self.params["scatter_correction"],
-          logger=logger,
+            self.dataset_dir,
+            patches_per_image=self.params["patch_number"],
+            patch_size=self.params["patch_size"],
+            patch_size_z=self.n_slices,
+            patch_offset=self.params["patch_offset"],
+            shuffle=self.params["shuffle"],
+            transform_normalization=transform_normalization,
+            scatter_correction=self.params["scatter_correction"],
+            logger=logger,
         )

         self.logger.debug(f"Init discriminator ....")
-        input_size = (self.n_slices, self.params["patch_size"], self.params["patch_size"])
+        input_size = (
+            self.n_slices,
+            self.params["patch_size"],
+            self.params["patch_size"],
+        )
         if self.params["discriminator_type"] == "class":
             discriminator = Discriminator(
-                in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"],
+                in_channels=2,
+                input_size=input_size,
+                conv_features=self.params["discriminator_conv_features"],
             )
         else:
-            discriminator = PatchDiscriminator(in_channels=2, features=self.params["discriminator_conv_features"])
+            discriminator = PatchDiscriminator(
+                in_channels=2, features=self.params["discriminator_conv_features"]
+            )
         discriminator = discriminator.to(self.device)

         optimizer = torch.optim.Adam(
@@ -418,6 +482,7 @@
             model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
         )

+        weight_crit_dist, weight_crit_adv = self.params["weight_criterions"]
         self.logger.debug(f"Init training ....")
         self.training = cGANTraining(
             epochs=self.params["epochs"],
@@ -430,8 +495,8 @@
             params_generator=params_g,
             params_discriminator=params_d,
             loss_func_dist=self.params["criterion_dist"],
-            weight_criterion_dist=self.params["weight_criterion_dist"],
-            weight_criterion_adv=self.params["weight_criterion_adv"],
+            weight_criterion_dist=weight_crit_dist,
+            weight_criterion_adv=weight_crit_adv,
             early_stopping=10,
         )
-- 
GitLab
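Note on the sampler mechanism the patch relies on: "patch_number" and "batch_size" are coupled to "patch_size" through DependentChoiceSampler, which rebuilds its choice list from parameters sampled earlier in the same run. The sketch below illustrates that resolution order with simplified stand-ins; the class bodies are assumptions inferred from the signatures visible in this diff, not the code actually in mu_map/training/random_search.py.

# Illustrative sketch only: simplified stand-ins for the samplers above,
# inferred from the signatures in the diff (not the repository's code).
import random
from typing import Any, Callable, Dict, List


class ChoiceSampler:
    """Sample uniformly from a fixed list of choices."""

    def __init__(self, values: List[Any]):
        self.values = values

    def sample(self, **kwargs) -> Any:
        return random.choice(self.values)


class DependentChoiceSampler(ChoiceSampler):
    """Rebuild the choice list from parameters sampled earlier."""

    def __init__(self, build_choices: Callable[..., List[Any]]):
        super().__init__(values=[])
        self.build_choices = build_choices

    def sample(self, **params) -> Any:
        # the callback receives the partially built parameter dict and
        # returns the choices valid under those values
        self.values = self.build_choices(**params)
        return super().sample()


def patch_number(patch_size: int, **kwargs) -> List[int]:
    # mirrors the dependency in the patch: smaller patches allow more
    # patches per image
    return list(range(50, 101)) if patch_size == 32 else list(range(25, 51))


param_sampler = {
    "patch_size": ChoiceSampler([32, 64]),
    "patch_number": DependentChoiceSampler(patch_number),
}

# Sampling in insertion order lets dependent samplers see earlier values.
params: Dict[str, Any] = {}
for name, sampler in param_sampler.items():
    params[name] = sampler.sample(**params)
print(params)  # e.g. {'patch_size': 32, 'patch_number': 73}

Because plain dicts preserve insertion order in Python 3.7+, registering independent parameters before dependent ones is enough for the build_choices callbacks to see the values they need.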