Commit 656911cc authored by Tamino Huxohl

lots of updates to cgan random search

parent 0d2cde7e
@@ -2,11 +2,12 @@
 Implementation of random search for hyperparameter optimization.
 """
 import json
+import logging
 import os
 import random
 import shutil
 import sys
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple

 import numpy as np
 import pandas as pd
@@ -32,6 +33,7 @@ class ParamSampler:
     """
     Abstract class to sample a parameter.
     """
+
     def sample(self) -> Any:
         """
         Create a new value for a parameter.
@@ -43,6 +45,7 @@ class ChoiceSampler(ParamSampler):
     """
     Sample from a list of choices.
     """
+
     def __init__(self, values: List[Any]):
         """
         Create a new choice sampler.
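
The sampler bodies are collapsed in this view. Assuming, as the name suggests, that ChoiceSampler.sample draws uniformly from the configured list, its behaviour corresponds to this sketch:

import random

values = [32, 64]               # e.g. the patch_size choices used below
sample = random.choice(values)  # ChoiceSampler([32, 64]).sample() -> 32 or 64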
@@ -64,6 +67,7 @@ class DependentChoiceSampler(ChoiceSampler):
     """
     A choice sampler that depends on other parameters.
     """
+
     def __init__(self, build_choices: Callable[[Any], List[Any]]):
         """
         Create a dependent choice sampler.
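
How build_choices is invoked is not visible in this diff, but the callbacks registered below (patch_number, discriminator_conv_features, batch_size) all accept previously sampled parameters as keyword arguments, so sampling presumably works roughly like this sketch:

import random

def patch_number(patch_size: int, **kwargs):
    # the available choices depend on an already sampled parameter
    return list(range(50, 101)) if patch_size == 32 else list(range(25, 51))

sampled = {"patch_size": 32}
# a DependentChoiceSampler presumably rebuilds its choices from the params so far
sampled["patch_number"] = random.choice(patch_number(**sampled))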
@@ -104,6 +108,7 @@ class FloatIntervalSampler(ParamSampler):
     """
     Sample a value from a float interval.
     """
+
     def __init__(self, min_val: float, max_val: float):
         """
         Create a new float interval sampler.
@@ -126,6 +131,7 @@ class IntIntervalSampler(ParamSampler):
     """
     Sample a value from an integer interval.
     """
+
     def __init__(self, min_val: int, max_val: int):
         """
         Create a new int interval sampler.
@@ -170,11 +176,12 @@ class RandomSearch:
     """
     Abstract implementation of a random search.
     """
     def __init__(self, param_sampler: Dict[str, ParamSampler]):
         """
         Create a new random sampler.

-        :param param_sampler: a dict of name and parameter sampler pairs
+        :param param_sampler: a dict of name and parameter sampler pairs;
+            all of them are sampled for a single random search run
         """
         self.param_sampler = param_sampler
@@ -195,7 +202,9 @@ class RandomSearch:
             _params[key] = param
         return _params

-    def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str:
+    def serialize_params(
+        self, params: Dict[str, Any], filename: Optional[str] = None
+    ) -> str:
         """
         Serialize a set of parameters to JSON and dump them into a file.
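
The method body is collapsed here. Since sampled values include objects such as normalization transforms and loss functions, dumping them to JSON presumably falls back to string representations; a minimal sketch under that assumption (not the commit's actual implementation):

import json
from typing import Any, Dict, Optional

def serialize_params_sketch(params: Dict[str, Any], filename: Optional[str] = None) -> str:
    # default=str renders non-JSON values (transforms, losses) as strings
    text = json.dumps(params, indent=2, default=str)
    if filename is not None:
        with open(filename, "w") as f:
            f.write(text)
    return text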
@@ -227,24 +236,42 @@ def validate_and_make_directory(_dir: str):
         print(f"Directory {_dir} exists and is unexpectedly not empty!")
         exit(1)

+
+def init_from_dir(_dir: str, logger: logging.Logger) -> Tuple[int, float]:
+    last_run = os.listdir(_dir)
+    # filter non-directories
+    last_run = filter(lambda f: os.path.isdir(os.path.join(_dir, f)), last_run)
+    # filter symlinks (e.g. the "best" link)
+    last_run = filter(lambda f: not os.path.islink(os.path.join(_dir, f)), last_run)
+    # keep only directories whose name is a run number (skips e.g. train_data)
+    last_run = filter(str.isdigit, last_run)
+    last_run = list(map(int, last_run))
+    if len(last_run) == 0:
+        return 1, sys.maxsize
+
+    last_run = max(last_run)
+    min_loss = pd.read_csv(os.path.join(_dir, "best", "measures.csv"))
+    min_loss = min_loss["NMAE"].mean()
+    logger.info(
+        f"Continue existing random search from run {last_run} and minimal NMAE {min_loss}"
+    )
+    return last_run + 1, min_loss
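
This makes the search resumable: each run writes into a subdirectory named by its run number, and the best run is symlinked as "best" (see the link_best argument of _cleanup_run below), so a search directory that has completed two runs looks roughly like this (illustrative layout, not shown in the diff):

cgan_random_search/
    1/
    2/
        measures.csv
    best -> 2/        (symlink, skipped when computing the next run number)
    train_data/       (not a numbered run directory)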
 class RandomSearchCGAN(RandomSearch):
     """
     Implementation of a random search for cGAN training.
     """
+
     def __init__(self, iterations: int, logger=None):
         super().__init__({})

         self.dataset_dir = "data/second"
         self.iterations = iterations
-        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-        self.n_slices = 32
-        self.params = {}
+        self.device = (
+            torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+        )

         self.dir = "cgan_random_search"
-        validate_and_make_directory(self.dir)
+        if not os.path.exists(self.dir):
+            os.mkdir(self.dir)
-        self.dir_train = os.path.join(self.dir, "train_data")

         self.logger = (
             logger
             if logger is not None
@@ -254,36 +281,57 @@ class RandomSearchCGAN(RandomSearch):
                 name=RandomSearchCGAN.__name__,
             )
         )
+        self.n_slices = 32
+        self.params = {}
+        self.dir_train = os.path.join(self.dir, "train_data")
+        self.start, self.nmae_min = init_from_dir(self.dir, self.logger)

         self.training: cGANTraining = None

+        def patch_number(patch_size: int, **kwargs):
+            if patch_size == 32:
+                return list(range(50, 101))
+            else:
+                return list(range(25, 51))
+
         # dataset parameters
         self.param_sampler["patch_size"] = ChoiceSampler([32, 64])
         self.param_sampler["patch_offset"] = ChoiceSampler([0])
-        self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100)
+        self.param_sampler["patch_number"] = DependentChoiceSampler(patch_number)
         self.param_sampler["scatter_correction"] = ChoiceSampler([False])
         self.param_sampler["shuffle"] = ChoiceSampler([False, True])
         self.param_sampler["normalization"] = ChoiceSampler(
             [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
         )
-        self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)])
+        self.param_sampler["pad_crop"] = ChoiceSampler(
+            [None, PadCropTranform(dim=3, size=self.n_slices)]
+        )

         # model parameters
         self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"])

         def discriminator_conv_features(discriminator_type: str, **kwargs):
             if discriminator_type == "class":
                 return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]]
             else:
                 return [[32, 64, 128, 256], [64, 128, 256, 512]]

-        self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features)
-        self.param_sampler["generator_features"] = ChoiceSampler([
-            [64, 128, 256, 512],
-            [32, 64, 128, 256, 512],
-        ])
+        self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(
+            discriminator_conv_features
+        )
+        self.param_sampler["generator_features"] = ChoiceSampler(
+            [
+                [64, 128, 256, 512],
+                [32, 64, 128, 256, 512],
+            ]
+        )

         # training parameters
-        self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90])
+        self.param_sampler["epochs"] = ChoiceSampler([100])

         def batch_size(patch_size: int, **kwargs):
             return [32] if patch_size == 64 else [64]

         # self.param_sampler["batch_size"] = ChoiceSampler([32, 64])
         self.param_sampler["batch_size"] = DependentChoiceSampler(batch_size)
         self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001)
@@ -293,12 +341,20 @@ class RandomSearchCGAN(RandomSearch):
         self.param_sampler["criterion_dist"] = ChoiceSampler(
             [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
         )
-        self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
-        self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])
+        self.param_sampler["weight_criterions"] = ChoiceSampler(
+            [
+                (100.0, 1.0),
+                (20.0, 1.0),
+                (5.0, 1.0),
+                (1.0, 1.0),
+                (1.0, 5.0),
+                (1.0, 20.0),
+                (1.0, 100.0),
+            ]
+        )
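
Sampling the two loss weights as a single tuple couples them deliberately: only the ratio between distance and adversarial loss varies, instead of two independent draws. Illustrative only (the tuple is unpacked into the training further down):

import random

weight_crit_dist, weight_crit_adv = random.choice(
    [(100.0, 1.0), (1.0, 1.0), (1.0, 100.0)]  # subset of the ratios above
)
print(weight_crit_dist / weight_crit_adv)  # the dist:adv ratio being explored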
     def run(self):
-        nmae_min = sys.maxsize
-        for i in range(1, self.iterations + 1):
+        for i in range(self.start, self.start + self.iterations):
             self.logger.info(f"Train iteration {i}")

             seed = random.randint(0, 2**32 - 1)
@@ -312,11 +368,11 @@ class RandomSearchCGAN(RandomSearch):
             nmae = self.eval_run()
             self.logger.info(f"Iteration {i} has NMAE {nmae:.6f}")

-            if nmae < nmae_min:
+            if nmae < self.nmae_min:
                 self.logger.info(f"New best run at iteration {i}")
-                nmae_min = nmae
+                self.nmae_min = nmae
-            self._cleanup_run(i, link_best=(nmae_min == nmae))
+            self._cleanup_run(i, link_best=(self.nmae_min == nmae))
-        return nmae_min
+        return self.nmae_min

     def eval_run(self):
         self.logger.debug("Perform evaluation ...")
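
The body of eval_run is collapsed in this view, but init_from_dir above reads the mean of the NMAE column from a run's measures.csv, which suggests the evaluation writes per-image measures and averages NMAE. A sketch under that assumption (file layout and column name taken from init_from_dir):

import os
import pandas as pd

def eval_run_sketch(run_dir: str) -> float:
    # average the per-image NMAE recorded during evaluation of a run
    measures = pd.read_csv(os.path.join(run_dir, "measures.csv"))
    return measures["NMAE"].mean()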
@@ -362,25 +418,33 @@ class RandomSearchCGAN(RandomSearch):
         self.logger.debug(f"Init dataset ...")
         dataset = MuMapPatchDataset(
             self.dataset_dir,
             patches_per_image=self.params["patch_number"],
             patch_size=self.params["patch_size"],
             patch_size_z=self.n_slices,
             patch_offset=self.params["patch_offset"],
             shuffle=self.params["shuffle"],
             transform_normalization=transform_normalization,
             scatter_correction=self.params["scatter_correction"],
             logger=logger,
         )

         self.logger.debug(f"Init discriminator ....")
-        input_size = (self.n_slices, self.params["patch_size"], self.params["patch_size"])
+        input_size = (
+            self.n_slices,
+            self.params["patch_size"],
+            self.params["patch_size"],
+        )
         if self.params["discriminator_type"] == "class":
             discriminator = Discriminator(
-                in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"],
+                in_channels=2,
+                input_size=input_size,
+                conv_features=self.params["discriminator_conv_features"],
             )
         else:
-            discriminator = PatchDiscriminator(in_channels=2, features=self.params["discriminator_conv_features"])
+            discriminator = PatchDiscriminator(
+                in_channels=2, features=self.params["discriminator_conv_features"]
+            )
         discriminator = discriminator.to(self.device)

         optimizer = torch.optim.Adam(
@@ -418,6 +482,7 @@ class RandomSearchCGAN(RandomSearch):
             model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
         )

+        weight_crit_dist, weight_crit_adv = self.params["weight_criterions"]
         self.logger.debug(f"Init training ....")
         self.training = cGANTraining(
             epochs=self.params["epochs"],
@@ -430,8 +495,8 @@ class RandomSearchCGAN(RandomSearch):
             params_generator=params_g,
             params_discriminator=params_d,
             loss_func_dist=self.params["criterion_dist"],
-            weight_criterion_dist=self.params["weight_criterion_dist"],
-            weight_criterion_adv=self.params["weight_criterion_adv"],
+            weight_criterion_dist=weight_crit_dist,
+            weight_criterion_adv=weight_crit_adv,
             early_stopping=10,
         )
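Taken together, the commit makes the random search resumable and couples dependent hyperparameters. A minimal driver (hypothetical; the repository's actual entry point is not part of this diff) could look like:

from mu_map.random_search import RandomSearchCGAN  # import path assumed

# picks up the run counter and best NMAE from cgan_random_search/ on restart
search = RandomSearchCGAN(iterations=50)
best_nmae = search.run()
print(f"Best NMAE over all runs: {best_nmae:.6f}")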