"""
Implementation of random search for hyperparameter optimization.
"""
import json
import os
import random
import shutil
import sys
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import (
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
from mu_map.eval.measures import compute_measures
from mu_map.models.discriminator import Discriminator, PatchDiscriminator
from mu_map.models.unet import UNet
from mu_map.training.cgan import cGANTraining, DiscriminatorParams, GeneratorParams
from mu_map.training.loss import WeightedLoss
from mu_map.logging import get_logger
class ParamSampler:
"""
Abstract class to sample a parameter.
"""
"""
Create a new value for a parameter.
"""
pass
class ChoiceSampler(ParamSampler):
"""
Sample from a list of choices.
"""
def __init__(self, values: List[Any]):
"""
Create a new choice sampler.
:param values: the list of values from which a sample is drawn.
"""
super().__init__()
self.values = values
def sample(self) -> Any:
"""
Retrieve a random value from the list of choices.
"""
idx = random.randrange(0, len(self.values))
return self.values[idx]
class DependentChoiceSampler(ChoiceSampler):
"""
A choice sampler that depends on other parameters.
"""
def __init__(self, build_choices: Callable[[Any], List[Any]]):
"""
Create a dependent choice sampler.
:param build_choices: a callable to create a list of choices depending on other parameters
"""
super().__init__(values=[])
self.build_choices = build_choices
self.dependency_names = list(build_choices.__annotations__.keys())
def sample(self, dependencies: Dict[str, Any]) -> List[Any]:
"""
Sample a choice based on given dependencies.
:param dependencies: a dict of name value pairs for all dependencies
note that the name has to match the parameter name in the build_choices callable
:return: a new sample
"""
self.validate_deps(dependencies)
self.values = self.build_choices(**dependencies)
return super().sample()
def validate_deps(self, dependencies: Dict[str, Any]):
"""
Validate a dict of dependencies by checking if all required parameters of
the build_choices callable are available.
:param dependencies: dict of name value pairs for dependencies
"""
for name in self.dependency_names:
assert (
name in dependencies.keys()
), f"Dependency {name} is missing from provided dependencies {dependencies}"
class FloatIntervalSampler(ParamSampler):
"""
Sample a value from a float interval.
"""
def __init__(self, min_val: float, max_val: float):
"""
Create a new float interval sampler.
:param min_val: the minimal value to draw
:param max_val: the maximal value to draw
"""
super().__init__()
self.min_val = min_val
self.max_val = max_val
def sample(self) -> float:
"""
Draw a new float.
"""
return random.uniform(self.min_val, self.max_val)
class IntIntervalSampler(ParamSampler):
"""
Sample a value from an integer interval.
"""
def __init__(self, min_val: int, max_val: int):
"""
Create a new int interval sampler.
:param min_val: the minimal value to draw
:param max_val: the maximal value to draw
"""
super().__init__()
self.min_val = min_val
self.max_val = max_val
def sample(self) -> int:
"""
Draw a new int.
"""
return random.randint(self.min_val, self.max_val)
def scatter_correction_by_params(params: Dict[str, str]) -> bool:
"""
    Utility function because params loaded from JSON store boolean values as strings.
"""
return params["scatter_correction"] == "True"
def normalization_by_params(params: Dict[str, str]):
"""
Utility function to load a normalization method.
"""
_norm = params["normalization"]
if "Gaussian" in _norm:
return GaussianNormTransform()
elif "Max" in _norm:
return MaxNormTransform()
elif "Mean" in _norm:
return MeanNormTransform()
else:
raise ValueError(f"Could not find normalization for param {_norm}")
"""
Abstract implementation of a random search.
"""
    def __init__(self, param_sampler: Dict[str, ParamSampler]):
"""
        Create a new random search.
:param param_sampler: a dict of name and parameter sampler pairs
all of them are sampled for a single random search run
"""
self.param_sampler = param_sampler
def sample(self):
"""
Sample all parameters.
        Dependent choice samplers receive all previously sampled parameters,
        so they have to be registered after the parameters they depend on.
:return: a dictionary of name and drawn parameter pairs.
"""
_params = {}
for key, sampler in self.param_sampler.items():
if isinstance(sampler, DependentChoiceSampler):
param = sampler.sample(_params)
else:
param = sampler.sample()
_params[key] = param
return _params
def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str:
"""
        Serialize a set of parameters to JSON and optionally dump them into a file.
:param params: dict of params to be serialized
:param filename: optional filename where the json is dumped
:return: the params as a json representation
"""
_params = {}
for key, value in params.items():
_params[key] = str(value).replace("\n", "").replace(" ", ", ")
_str = json.dumps(_params, indent=2)
if filename is not None:
with open(filename, mode="w") as f:
f.write(_str)
return _str
def validate_and_make_directory(_dir: str):
    """
    Utility function to validate that a directory exists and is empty.
    If it does not exist, it is created.
    """
if not os.path.exists(_dir):
os.mkdir(_dir)
return
if len(os.listdir(_dir)) > 0:
print(f"Directory {_dir} exists and is unexpectedly not empty!")
        sys.exit(1)
class RandomSearchCGAN(RandomSearch):
"""
Implementation of a random search for cGAN training.
"""
def __init__(self, iterations: int, logger=None):
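        """
        Create a new random search for cGAN training.
        :param iterations: the number of training runs to perform
        :param logger: optional logger; if not given, a new one logging into
                       the search directory is created
        """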
super().__init__({})
self.dataset_dir = "data/second"
self.iterations = iterations
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
self.dir = "cgan_random_search"
validate_and_make_directory(self.dir)
self.dir_train = os.path.join(self.dir, "train_data")
self.logger = (
logger
if logger is not None
else get_logger(
logfile=os.path.join(self.dir, "search.log"),
loglevel="INFO",
name=RandomSearchCGAN.__name__,
)
)
        self.training: Optional[cGANTraining] = None
        # NOTE: number of axial slices the networks expect; the value below is
        # an assumption and has to match the slice count used by the dataset
        self.n_slices = 32
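        # dataset parameters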
self.param_sampler["patch_size"] = ChoiceSampler([32, 64])
self.param_sampler["patch_offset"] = ChoiceSampler([0])
self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100)
self.param_sampler["scatter_correction"] = ChoiceSampler([False])
self.param_sampler["shuffle"] = ChoiceSampler([False, True])
self.param_sampler["normalization"] = ChoiceSampler(
[MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
)
self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)])
# model parameters
self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"])
def discriminator_conv_features(discriminator_type: str, **kwargs):
if discriminator_type == "class":
return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]]
else:
return [[32, 64, 128, 256], [64, 128, 256, 512]]
self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features)
self.param_sampler["generator_features"] = ChoiceSampler([
[64, 128, 256, 512],
[32, 64, 128, 256, 512],
])
# training parameters
self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90])
def batch_size(patch_size: int, **kwargs):
return [32] if patch_size == 64 else [64]
# self.param_sampler["batch_size"] = ChoiceSampler([32, 64])
self.param_sampler["batch_size"] = DependentChoiceSampler(batch_size)
self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001)
self.param_sampler["lr_decay"] = ChoiceSampler([False, True])
self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
self.param_sampler["criterion_dist"] = ChoiceSampler(
[WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
)
self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])
def run(self):
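        """
        Run the random search.
        Each iteration samples a new set of parameters, trains a cGAN with
        them and evaluates it by its NMAE on the validation split.
        :return: the minimal NMAE achieved over all iterations
        """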
nmae_min = sys.maxsize
for i in range(1, self.iterations + 1):
self.logger.info(f"Train iteration {i}")
seed = random.randint(0, 2**32 - 1)
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
self.logger.info(f"Random seed for iteration {i} is {seed}")
self._setup_run(i)
self.training.run()
nmae = self.eval_run()
self.logger.info(f"Iteration {i} has NMAE {nmae:.6f}")
if nmae < nmae_min:
self.logger.info(f"New best run at iteration {i}")
nmae_min = nmae
self._cleanup_run(i, link_best=(nmae_min == nmae))
return nmae_min
def eval_run(self):
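        """
        Evaluate the best generator of the current training run on the
        validation split.
        :return: the mean NMAE over the validation dataset
        """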
self.logger.debug("Perform evaluation ...")
torch.set_grad_enabled(False)
weights_file = os.path.join(self.training.snapshot_dir, "val_min_generator.pth")
self.logger.debug(f"Load weights from {weights_file}")
        model = UNet(in_channels=1, features=self.params["generator_features"]).to(self.device)
        model.load_state_dict(torch.load(weights_file, map_location=self.device))
        model.eval()
transform_normalization = SequenceTransform(
[self.params["normalization"], PadCropTranform(dim=3, size=self.n_slices)]
)
dataset = MuMapDataset(
self.dataset_dir,
split_name="validation",
transform_normalization=transform_normalization,
scatter_correction=self.params["scatter_correction"],
)
values = compute_measures(dataset, model)
values.to_csv(os.path.join(self.dir_train, "measures.csv"), index=False)
return values["NMAE"].mean()
def _setup_run(self, iteration: int):
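        """
        Set up a new training run: sample parameters, create the dataset, the
        discriminator, the generator and the cGAN training.
        :param iteration: the current iteration of the random search
        """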
self.logger.debug("Create directories...")
validate_and_make_directory(self.dir_train)
snapshot_dir = os.path.join(self.dir_train, "snapshots")
validate_and_make_directory(snapshot_dir)
self.params = self.sample()
params_file = os.path.join(self.dir_train, "params.json")
self.serialize_params(self.params, params_file)
self.logger.debug(f"Store params at {params_file}")
logfile = os.path.join(self.dir_train, "train.log")
logger = get_logger(logfile, loglevel="INFO", name=cGANTraining.__name__)
self.logger.debug(f"Training logs to {logfile}")
transforms = [self.params["normalization"], self.params["pad_crop"]]
transforms = list(filter(lambda transform: transform is not None, transforms))
transform_normalization = SequenceTransform(transforms)
self.logger.debug(f"Init dataset ...")
dataset = MuMapPatchDataset(
self.dataset_dir,
patches_per_image=self.params["patch_number"],
patch_size=self.params["patch_size"],
patch_offset=self.params["patch_offset"],
shuffle=self.params["shuffle"],
transform_normalization=transform_normalization,
scatter_correction=self.params["scatter_correction"],
logger=logger,
self.logger.debug(f"Init discriminator ....")
input_size = (self.n_slices, self.params["patch_size"], self.params["patch_size"])
if self.params["discriminator_type"] == "class":
discriminator = Discriminator(
in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"],
)
else:
discriminator = PatchDiscriminator(in_channels=2, features=self.params["discriminator_conv_features"])
discriminator = discriminator.to(self.device)
optimizer = torch.optim.Adam(
discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
)
lr_scheduler = (
torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=self.params["lr_decay_epoch"],
gamma=self.params["lr_decay_factor"],
)
if self.params["lr_decay"]
else None
)
params_d = DiscriminatorParams(
model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
)
self.logger.debug(f"Init generator ....")
generator = UNet(in_channels=1, features=self.params["generator_features"])
generator = generator.to(self.device)
optimizer = torch.optim.Adam(
generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
)
lr_scheduler = (
torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=self.params["lr_decay_epoch"],
gamma=self.params["lr_decay_factor"],
)
if self.params["lr_decay"]
else None
)
params_g = GeneratorParams(
model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
)
self.logger.debug(f"Init training ....")
self.training = cGANTraining(
epochs=self.params["epochs"],
dataset=dataset,
batch_size=self.params["batch_size"],
device=self.device,
snapshot_dir=snapshot_dir,
snapshot_epoch=self.params["epochs"],
logger=logger,
params_generator=params_g,
params_discriminator=params_d,
loss_func_dist=self.params["criterion_dist"],
weight_criterion_dist=self.params["weight_criterion_dist"],
weight_criterion_adv=self.params["weight_criterion_adv"],
)
def _cleanup_run(self, iteration: int, link_best: bool):
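        """
        Clean up after a training run: move the training directory to a
        directory named after the iteration and optionally update the link
        to the best run.
        :param iteration: the current iteration of the random search
        :param link_best: whether this run is the best so far
        """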
dir_from = self.dir_train
_dir = f"{iteration:0{len(str(self.iterations))}d}"
dir_to = os.path.join(self.dir, _dir)
self.logger.debug(f"Move iteration {iteration} from {dir_from} to {dir_to}")
shutil.move(dir_from, dir_to)
if link_best:
linkfile = os.path.join(self.dir, "best")
if os.path.exists(linkfile):
os.unlink(linkfile)
os.symlink(_dir, linkfile)
if __name__ == "__main__":
    random_search = RandomSearchCGAN(iterations=50)
    random_search.run()