Commit 656911cc authored by Tamino Huxohl

lots of updates to cgan random search

parent 0d2cde7e
@@ -2,11 +2,12 @@
 Implementation of random search for hyper parameter optimization.
 """
 import json
+import logging
 import os
 import random
 import shutil
 import sys
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import numpy as np
 import pandas as pd
@@ -32,6 +33,7 @@ class ParamSampler:
     """
     Abstract class to sample a parameter.
     """
+
     def sample(self) -> Any:
         """
         Create a new value for a parameter.
@@ -43,6 +45,7 @@ class ChoiceSampler(ParamSampler):
     """
     Sample from a list of choices.
     """
+
     def __init__(self, values: List[Any]):
         """
         Create a new choice sampler.
@@ -64,6 +67,7 @@ class DependentChoiceSampler(ChoiceSampler):
     """
     A choice sampler that depends on other parameters.
     """
+
    def __init__(self, build_choices: Callable[[Any], List[Any]]):
         """
         Create a dependent choice sampler.
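A dependent sampler rebuilds its choice list from parameters that were sampled earlier in the same run. A minimal usage sketch, assuming `sample` forwards previously drawn parameters as keyword arguments (the call signature is not part of this hunk; the `batch_size` builder is taken from a later hunk in this diff):

    def batch_size(patch_size: int, **kwargs):
        # the choice list depends on an already-sampled parameter
        return [32] if patch_size == 64 else [64]

    sampler = DependentChoiceSampler(batch_size)
    value = sampler.sample(patch_size=64)  # assumed call style; draws from [32]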
@@ -104,6 +108,7 @@ class FloatIntervalSampler(ParamSampler):
     """
     Sample a value from a float interval.
     """
+
     def __init__(self, min_val: float, max_val: float):
         """
         Create a new float interval sampler.
@@ -126,6 +131,7 @@ class IntIntervalSampler(ParamSampler):
     """
     Sample a value from an integer interval.
     """
+
     def __init__(self, min_val: int, max_val: int):
         """
         Create a new int interval sampler.
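Both interval samplers draw a single value between `min_val` and `max_val`. A short sketch of the assumed behaviour (the `sample` bodies are not shown in this diff, so the uniform draw is an assumption):

    lr_sampler = FloatIntervalSampler(0.1, 0.0001)
    lr = lr_sampler.sample()  # e.g. 0.0273; some value between the two bounds

    count_sampler = IntIntervalSampler(min_val=50, max_val=100)
    n = count_sampler.sample()  # e.g. 73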
@@ -170,11 +176,12 @@ class RandomSearch:
     """
     Abstract implementation of a random search.
     """
+
     def __init__(self, param_sampler: Dict[str, ParamSampler]):
         """
         Create a new random sampler.

         :param param_sampler: a dict of name and parameter sampler pairs;
                all of them are sampled for a single random search run
         """
         self.param_sampler = param_sampler
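`RandomSearch` holds one sampler per hyperparameter name; drawing from every entry once yields the configuration for a single run, as the method tail in the next hunk suggests. A rough, hypothetical standalone sketch of that flow (the sampling method's name is cut off by the hunk):

    samplers = {
        "patch_size": ChoiceSampler([32, 64]),
        "lr": FloatIntervalSampler(0.1, 0.0001),
    }
    search = RandomSearch(samplers)
    params = search.sample_params()  # hypothetical name; e.g. {"patch_size": 64, "lr": 0.003}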
@@ -195,7 +202,9 @@ class RandomSearch:
             _params[key] = param
         return _params

-    def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str:
+    def serialize_params(
+        self, params: Dict[str, Any], filename: Optional[str] = None
+    ) -> str:
         """
         Serialize a set of parameters to json and dump them into a file.
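A sketch of calling it, reusing the hypothetical `search` instance from the sketch above (non-primitive parameters such as transform instances presumably need a string representation before `json` accepts them):

    params = {"patch_size": 64, "epochs": 100, "lr": 0.003}
    json_str = search.serialize_params(params, filename="params.json")
    # returns the JSON string and, when a filename is given, also writes it to disk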
@@ -227,24 +236,42 @@ def validate_and_make_directory(_dir: str):
         print(f"Directory {_dir} exists and is unexpectedly not empty!")
         exit(1)


+def init_from_dir(_dir: str, logger: logging.Logger) -> Tuple[int, float]:
+    last_run = os.listdir(_dir)
+    # filter non-directories
+    last_run = filter(lambda f: os.path.isdir(os.path.join(_dir, f)), last_run)
+    # filter symlinks
+    last_run = filter(lambda f: not os.path.islink(os.path.join(_dir, f)), last_run)
+    last_run = list(map(int, last_run))
+
+    if len(last_run) == 0:
+        return 1, sys.maxsize
+
+    last_run = max(last_run)
+    min_loss = pd.read_csv(os.path.join(_dir, "best", "measures.csv"))
+    min_loss = min_loss["NMAE"].mean()
+    logger.info(f"Continue existing random search from run {last_run} and minimal NMAE {min_loss}")
+    return last_run + 1, min_loss
+
+
 class RandomSearchCGAN(RandomSearch):
     """
     Implementation of a random search for cGAN training.
     """
+
     def __init__(self, iterations: int, logger=None):
         super().__init__({})

         self.dataset_dir = "data/second"
         self.iterations = iterations
-        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        self.n_slices = 32
-        self.params = {}
+        self.device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )
         self.dir = "cgan_random_search"
-        validate_and_make_directory(self.dir)
-        self.dir_train = os.path.join(self.dir, "train_data")
+        if not os.path.exists(self.dir):
+            os.mkdir(self.dir)
         self.logger = (
             logger
             if logger is not None
@@ -254,36 +281,57 @@ class RandomSearchCGAN(RandomSearch):
                 name=RandomSearchCGAN.__name__,
             )
         )
+        self.n_slices = 32
+        self.params = {}
+        self.dir_train = os.path.join(self.dir, "train_data")
+        self.start, self.nmae_min = init_from_dir(self.dir, self.logger)
         self.training: cGANTraining = None

+        def patch_number(patch_size: int, **kwargs):
+            if patch_size == 32:
+                return list(range(50, 101))
+            else:
+                return list(range(25, 51))
+
         # dataset parameters
         self.param_sampler["patch_size"] = ChoiceSampler([32, 64])
         self.param_sampler["patch_offset"] = ChoiceSampler([0])
-        self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100)
+        self.param_sampler["patch_number"] = DependentChoiceSampler(patch_number)
         self.param_sampler["scatter_correction"] = ChoiceSampler([False])
         self.param_sampler["shuffle"] = ChoiceSampler([False, True])
         self.param_sampler["normalization"] = ChoiceSampler(
             [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
         )
-        self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)])
+        self.param_sampler["pad_crop"] = ChoiceSampler(
+            [None, PadCropTranform(dim=3, size=self.n_slices)]
+        )

         # model parameters
         self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"])

         def discriminator_conv_features(discriminator_type: str, **kwargs):
             if discriminator_type == "class":
                 return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]]
             else:
                 return [[32, 64, 128, 256], [64, 128, 256, 512]]

-        self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features)
-        self.param_sampler["generator_features"] = ChoiceSampler([
-            [64, 128, 256, 512],
-            [32, 64, 128, 256, 512],
-        ])
+        self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(
+            discriminator_conv_features
+        )
+        self.param_sampler["generator_features"] = ChoiceSampler(
+            [
+                [64, 128, 256, 512],
+                [32, 64, 128, 256, 512],
+            ]
+        )

         # training parameters
-        self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90])
+        self.param_sampler["epochs"] = ChoiceSampler([100])

         def batch_size(patch_size: int, **kwargs):
             return [32] if patch_size == 64 else [64]

         # self.param_sampler["batch_size"] = ChoiceSampler([32, 64])
         self.param_sampler["batch_size"] = DependentChoiceSampler(batch_size)
         self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001)
@@ -293,12 +341,20 @@ class RandomSearchCGAN(RandomSearch):
         self.param_sampler["criterion_dist"] = ChoiceSampler(
             [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
         )
-        self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
-        self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])
+        self.param_sampler["weight_criterions"] = ChoiceSampler(
+            [
+                (100.0, 1.0),
+                (20.0, 1.0),
+                (5.0, 1.0),
+                (1.0, 1.0),
+                (1.0, 5.0),
+                (1.0, 20.0),
+                (1.0, 100.0),
+            ]
+        )

     def run(self):
-        nmae_min = sys.maxsize
-        for i in range(1, self.iterations + 1):
+        for i in range(self.start, self.start + self.iterations + 1):
             self.logger.info(f"Train iteration {i}")

             seed = random.randint(0, 2**32 - 1)
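Sampling the distance and adversarial loss weights as one tuple replaces the two independent `[1.0, 20.0, 100.0]` samplers, so only the seven listed dist/adv ratios can occur instead of all nine products of the old grids. The pair is unpacked before the training is assembled, as a later hunk in this diff shows:

    weight_crit_dist, weight_crit_adv = self.params["weight_criterions"]
    # e.g. (100.0, 1.0): distance term dominates; (1.0, 100.0): adversarial term dominates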
@@ -312,11 +368,11 @@ class RandomSearchCGAN(RandomSearch):
             nmae = self.eval_run()
             self.logger.info(f"Iteration {i} has NMAE {nmae:.6f}")

-            if nmae < nmae_min:
+            if nmae < self.nmae_min:
                 self.logger.info(f"New best run at iteration {i}")
-                nmae_min = nmae
-            self._cleanup_run(i, link_best=(nmae_min == nmae))
-        return nmae_min
+                self.nmae_min = nmae
+            self._cleanup_run(i, link_best=(self.nmae_min == nmae))
+        return self.nmae_min

     def eval_run(self):
         self.logger.debug("Perform evaluation ...")
@@ -362,25 +418,33 @@ class RandomSearchCGAN(RandomSearch):
         self.logger.debug(f"Init dataset ...")
         dataset = MuMapPatchDataset(
             self.dataset_dir,
             patches_per_image=self.params["patch_number"],
             patch_size=self.params["patch_size"],
             patch_size_z=self.n_slices,
             patch_offset=self.params["patch_offset"],
             shuffle=self.params["shuffle"],
             transform_normalization=transform_normalization,
             scatter_correction=self.params["scatter_correction"],
             logger=logger,
         )

         self.logger.debug(f"Init discriminator ....")
-        input_size = (self.n_slices, self.params["patch_size"], self.params["patch_size"])
+        input_size = (
+            self.n_slices,
+            self.params["patch_size"],
+            self.params["patch_size"],
+        )
         if self.params["discriminator_type"] == "class":
             discriminator = Discriminator(
-                in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"],
+                in_channels=2,
+                input_size=input_size,
+                conv_features=self.params["discriminator_conv_features"],
             )
         else:
-            discriminator = PatchDiscriminator(in_channels=2, features=self.params["discriminator_conv_features"])
+            discriminator = PatchDiscriminator(
+                in_channels=2, features=self.params["discriminator_conv_features"]
+            )
         discriminator = discriminator.to(self.device)

         optimizer = torch.optim.Adam(
@@ -418,6 +482,7 @@ class RandomSearchCGAN(RandomSearch):
             model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
         )

+        weight_crit_dist, weight_crit_adv = self.params["weight_criterions"]
         self.logger.debug(f"Init training ....")
         self.training = cGANTraining(
             epochs=self.params["epochs"],
@@ -430,8 +495,8 @@ class RandomSearchCGAN(RandomSearch):
             params_generator=params_g,
             params_discriminator=params_d,
             loss_func_dist=self.params["criterion_dist"],
-            weight_criterion_dist=self.params["weight_criterion_dist"],
-            weight_criterion_adv=self.params["weight_criterion_adv"],
+            weight_criterion_dist=weight_crit_dist,
+            weight_criterion_adv=weight_crit_adv,
             early_stopping=10,
         )
...
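Taken together, a resumable search could be driven roughly like this (hypothetical snippet, not part of the commit):

    search = RandomSearchCGAN(iterations=50)
    best_nmae = search.run()  # picks up from cgan_random_search/ if earlier runs exist
    print(f"Best NMAE: {best_nmae:.6f}")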