""" Implementation of random search for hyper parameter optimization. """ import json import os import random import shutil import sys from typing import Any, Callable, Dict, List, Optional import numpy as np import pandas as pd import torch from mu_map.dataset.default import MuMapDataset from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.normalization import ( MeanNormTransform, MaxNormTransform, GaussianNormTransform, ) from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform from mu_map.eval.measures import nmae, mse from mu_map.models.discriminator import Discriminator, PatchDiscriminator from mu_map.models.unet import UNet from mu_map.training.cgan import cGANTraining, DiscriminatorParams, GeneratorParams from mu_map.training.loss import WeightedLoss from mu_map.logging import get_logger class ParamSampler: """ Abstract class to sample a parameter. """ def sample(self) -> Any: """ Create a new value for a parameter. """ pass class ChoiceSampler(ParamSampler): """ Sample from a list of choices. """ def __init__(self, values: List[Any]): """ Create a new choice sampler. :param values: the list of values from which a sample is drawn. """ super().__init__() self.values = values def sample(self) -> Any: """ Retrieve a random value from the list of choices. """ idx = random.randrange(0, len(self.values)) return self.values[idx] class DependentChoiceSampler(ChoiceSampler): """ A choice sampler that depends on other parameters. """ def __init__(self, build_choices: Callable[[Any], List[Any]]): """ Create a dependent choice sampler. :param build_choices: a callable to create a list of choices depending on other parameters """ super().__init__(values=[]) self.build_choices = build_choices self.dependency_names = list(build_choices.__annotations__.keys()) def sample(self, dependencies: Dict[str, Any]) -> List[Any]: """ Sample a choice based on given dependencies. :param dependencies: a dict of name value pairs for all dependencies note that the name has to match the parameter name in the build_choices callable :return: a new sample """ self.validate_deps(dependencies) self.values = self.build_choices(**dependencies) return super().sample() def validate_deps(self, dependencies: Dict[str, Any]): """ Validate a dict of dependencies by checking if all required parameters of the build_choices callable are available. :param dependencies: dict of name value pairs for dependencies """ for name in self.dependency_names: assert ( name in dependencies.keys() ), f"Dependency {name} is missing from provided dependencies {dependencies}" class FloatIntervalSampler(ParamSampler): """ Sample a value from a float interval. """ def __init__(self, min_val: float, max_val: float): """ Create a new float interval sampler. :param min_val: the minimal value to draw :param max_val: the maximal value to draw """ super().__init__() self.min_val = min_val self.max_val = max_val def sample(self) -> float: """ Draw a new float. """ return random.uniform(self.min_val, self.max_val) class IntIntervalSampler(ParamSampler): """ Sample a value from an integer interval. """ def __init__(self, min_val: int, max_val: int): """ Create a new int interval sampler. :param min_val: the minimal value to draw :param max_val: the maximal value to draw """ super().__init__() self.min_val = min_val self.max_val = max_val def sample(self) -> int: """ Draw a new int. 
""" return random.randint(self.min_val, self.max_val) def scatter_correction_by_params(params: Dict[str, str]) -> bool: """ Utility function since loading json does not map to boolean values. """ return params["scatter_correction"] == "True" def normalization_by_params(params: Dict[str, str]): """ Utility function to load a normalization method. """ _norm = params["normalization"] if "Gaussian" in _norm: return GaussianNormTransform() elif "Max" in _norm: return MaxNormTransform() elif "Mean" in _norm: return MeanNormTransform() else: raise ValueError(f"Could not find normalization for param {_norm}") class RandomSearch: """ Abstract implementation of a random search. """ def __init__(self, param_sampler=Dict[str, ParamSampler]): """ Create a new random sampler. :param param_sampler: a dict of name and parameter sampler pairs all of them are sampled for a single random search run """ self.param_sampler = param_sampler def sample(self): """ Sample all parameters. This makes sure that all dependent choice samplers get their required dependencies (need to be registered in order). :return: a dictionary of name and drawn parameter pairs. """ _params = {} for key, sampler in self.param_sampler.items(): if isinstance(sampler, DependentChoiceSampler): param = sampler.sample(_params) else: param = sampler.sample() _params[key] = param return _params def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str: """ Serialize is a set of parameters to json and dump them into a file. :param params: dict of params to be serialized :param filename: optional filename where the json is dumped :return: the params as a json representation """ _params = {} for key, value in params.items(): _params[key] = str(value).replace("\n", "").replace(" ", ", ") _str = json.dumps(_params, indent=2) if filename is not None: with open(filename, mode="w") as f: f.write(_str) return _str def validate_and_make_directory(_dir: str): """ Utility method to validate that a directory exists and is empty. If is does not exist, it is created. """ if not os.path.exists(_dir): os.mkdir(_dir) return if len(os.listdir(_dir)) > 0: print(f"Directory {_dir} exists and is unexpectedly not empty!") exit(1) class RandomSearchCGAN(RandomSearch): """ Implementation of a random search for cGAN training. 
""" def __init__(self, iterations: int, logger=None): super().__init__({}) self.dataset_dir = "data/second" self.iterations = iterations self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") self.n_slices = 32 self.params = {} self.dir = "cgan_random_search" validate_and_make_directory(self.dir) self.dir_train = os.path.join(self.dir, "train_data") self.logger = ( logger if logger is not None else get_logger( logfile=os.path.join(self.dir, "search.log"), loglevel="INFO", name=RandomSearchCGAN.__name__, ) ) self.training: cGANTraining = None # dataset parameters self.param_sampler["patch_size"] = ChoiceSampler([32, 64]) self.param_sampler["patch_offset"] = ChoiceSampler([0]) self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100) self.param_sampler["scatter_correction"] = ChoiceSampler([False]) self.param_sampler["shuffle"] = ChoiceSampler([False, True]) self.param_sampler["normalization"] = ChoiceSampler( [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()] ) self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)]) # model parameters self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"]) def discriminator_conv_features(discriminator_type: str, **kwargs): if discriminator_type == "class": return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]] else: return [[32, 64, 128, 256], [64, 128, 256, 512]] self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features) self.param_sampler["generator_features"] = ChoiceSampler([ [128, 256, 512], [64, 128, 256, 512], [32, 64, 128, 256, 512], ]) # training parameters self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90]) self.param_sampler["batch_size"] = ChoiceSampler([32, 64]) self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001) self.param_sampler["lr_decay"] = ChoiceSampler([False, True]) self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1]) self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99]) self.param_sampler["criterion_dist"] = ChoiceSampler( [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")] ) self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0]) self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0]) def run(self): nmae_min = sys.maxsize for i in range(1, self.iterations + 1): self.logger.info(f"Train iteration {i}") seed = random.randint(0, 2**32 - 1) random.seed(seed) torch.manual_seed(seed) np.random.seed(seed) self.logger.info(f"Random seed for iteration {i} is {seed}") self._setup_run(i) self.training.run() nmae = self.eval_run() self.logger.info(f"Iteration {i} has NMAE {nmae:.6f}") if nmae < nmae_min: self.logger.info(f"New best run at iteration {i}") nmae_min = nmae self._cleanup_run(i, link_best=(nmae_min == nmae)) return nmae_min def eval_run(self): self.logger.debug("Perform evaluation ...") torch.set_grad_enabled(False) weights_file = os.path.join(self.training.snapshot_dir, "val_min_generator.pth") self.logger.debug(f"Load weights from {weights_file}") model = self.training.generator.model.eval() model.load_state_dict(torch.load(weights_file, map_location=self.device)) transform_normalization = SequenceTransform( [self.params["normalization"], PadCropTranform(dim=3, size=self.n_slices)] ) dataset = MuMapDataset( self.dataset_dir, split_name="validation", transform_normalization=transform_normalization, scatter_correction=self.params["scatter_correction"], ) 
measures = {"NMAE": nmae, "MSE": mse} values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys()))) for i, (recon, mu_map) in enumerate(dataset): print( f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}", end="\r", ) prediction = model(recon.unsqueeze(dim=0).to(self.device)) prediction = prediction.squeeze().cpu().numpy() mu_map = mu_map.squeeze().cpu().numpy() row = pd.DataFrame( dict( map( lambda item: (item[0], [item[1](prediction, mu_map)]), measures.items(), ) ) ) values = pd.concat((values, row), ignore_index=True) print(f" " * 100, end="\r") values.to_csv(os.path.join(self.dir_train, "measures.csv"), index=False) return values["NMAE"].mean() def _setup_run(self, iteration: int): self.logger.debug("Create directories...") validate_and_make_directory(self.dir_train) snapshot_dir = os.path.join(self.dir_train, "snapshots") validate_and_make_directory(snapshot_dir) self.params = self.sample() params_file = os.path.join(self.dir_train, "params.json") self.serialize_params(self.params, params_file) self.logger.debug(f"Store params at {params_file}") logfile = os.path.join(self.dir_train, "train.log") logger = get_logger(logfile, loglevel="INFO", name=cGANTraining.__name__) self.logger.debug(f"Training logs to {logfile}") transforms = [self.params["normalization"], self.params["pad_crop"]] transforms = list(filter(lambda transform: transform is not None, transforms)) transform_normalization = SequenceTransform(transforms) self.logger.debug(f"Init dataset ...") dataset = MuMapPatchDataset( self.dataset_dir, patches_per_image=self.params["patch_number"], patch_size=self.params["patch_size"], patch_size_z=self.n_slices, patch_offset=self.params["patch_offset"], shuffle=self.params["shuffle"], transform_normalization=transform_normalization, scatter_correction=self.params["scatter_correction"], logger=logger, ) self.logger.debug(f"Init discriminator ....") input_size = (2, self.n_slices, self.params["patch_size"], self.params["patch_size"]) if self.params["discriminator_type"] == "class": discriminator = Discriminator( in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"], ) else: discriminator = PatchDiscriminator(in_channels=2, features=self.params["discriminator_conv_features"]) discriminator = discriminator.to(self.device) optimizer = torch.optim.Adam( discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999) ) lr_scheduler = ( torch.optim.lr_scheduler.StepLR( optimizer, step_size=self.params["lr_decay_epoch"], gamma=self.params["lr_decay_factor"], ) if self.params["lr_decay"] else None ) params_d = DiscriminatorParams( model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler ) self.logger.debug(f"Init generator ....") generator = UNet(in_channels=1, features=self.params["generator_features"]) generator = generator.to(self.device) optimizer = torch.optim.Adam( generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999) ) lr_scheduler = ( torch.optim.lr_scheduler.StepLR( optimizer, step_size=self.params["lr_decay_epoch"], gamma=self.params["lr_decay_factor"], ) if self.params["lr_decay"] else None ) params_g = GeneratorParams( model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler ) self.logger.debug(f"Init training ....") self.training = cGANTraining( epochs=self.params["epochs"], dataset=dataset, batch_size=self.params["batch_size"], device=self.device, snapshot_dir=snapshot_dir, snapshot_epoch=self.params["epochs"], logger=logger, params_generator=params_g, params_discriminator=params_d, 
            loss_func_dist=self.params["criterion_dist"],
            weight_criterion_dist=self.params["weight_criterion_dist"],
            weight_criterion_adv=self.params["weight_criterion_adv"],
        )

    def _cleanup_run(self, iteration: int, link_best: bool):
        dir_from = self.dir_train
        _dir = f"{iteration:0{len(str(self.iterations))}d}"
        dir_to = os.path.join(self.dir, _dir)
        self.logger.debug(f"Move iteration {iteration} from {dir_from} to {dir_to}")
        shutil.move(dir_from, dir_to)

        if link_best:
            linkfile = os.path.join(self.dir, "best")
            # lexists also catches a dangling symlink left over from a removed run
            if os.path.lexists(linkfile):
                os.unlink(linkfile)
            os.symlink(_dir, linkfile)


if __name__ == "__main__":
    random_search = RandomSearchCGAN(iterations=10)
    random_search.run()
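

# Expected on-disk layout after a full search (sketch, assuming 10 iterations and
# that every run completes):
#
#   cgan_random_search/
#     search.log            # log of the random search itself
#     01/ ... 10/           # one directory per iteration (moved from train_data/)
#       params.json         # sampled hyperparameters of the run
#       train.log           # training log
#       measures.csv        # per-volume NMAE/MSE on the validation split
#       snapshots/          # model snapshots, e.g. val_min_generator.pth
#     best -> <iteration>   # symlink to the run with the lowest mean NMAE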