    """
    Implementation of random search for hyperparameter optimization.
    """
    
    import json
    import os
    import random
    import shutil
    import sys
    from typing import Any, Callable, Dict, List, Optional

    import numpy as np
    import pandas as pd
    import torch
    
    from mu_map.dataset.default import MuMapDataset
    from mu_map.dataset.patches import MuMapPatchDataset
    from mu_map.dataset.normalization import (
        MeanNormTransform,
        MaxNormTransform,
        GaussianNormTransform,
    )
    from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
    from mu_map.eval.measures import compute_measures
    from mu_map.models.discriminator import Discriminator, PatchDiscriminator
    from mu_map.models.unet import UNet
    from mu_map.training.cgan import cGANTraining, DiscriminatorParams, GeneratorParams
    from mu_map.training.loss import WeightedLoss
    from mu_map.logging import get_logger
    
    
    class ParamSampler:
        """
        Abstract class to sample a parameter.
        """

        def sample(self) -> Any:
            """
            Create a new value for a parameter.
            """
            pass
    
    
    class ChoiceSampler(ParamSampler):
        """
        Sample from a list of choices.
        """

        def __init__(self, values: List[Any]):
            """
            Create a new choice sampler.

            :param values: the list of values from which a sample is drawn.
            """
            super().__init__()
            self.values = values

        def sample(self) -> Any:
            """
            Retrieve a random value from the list of choices.
            """
            idx = random.randrange(0, len(self.values))
            return self.values[idx]
    
    
    class DependentChoiceSampler(ChoiceSampler):
        """
        A choice sampler that depends on other parameters.
        """

        def __init__(self, build_choices: Callable[..., List[Any]]):
            """
            Create a dependent choice sampler.

            :param build_choices: a callable that creates a list of choices based on other parameters
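
            Example (illustrative) of a build_choices callable whose annotated
            parameters declare its dependencies:

                def batch_size(patch_size: int, **kwargs):
                    return [32] if patch_size == 64 else [64]

                sampler = DependentChoiceSampler(batch_size)
                sampler.sample({"patch_size": 64})  # always 32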
            """
    
            super().__init__(values=[])

            self.build_choices = build_choices
            # the annotated parameter names of build_choices define its dependencies
            self.dependency_names = list(build_choices.__annotations__.keys())

        def sample(self, dependencies: Dict[str, Any]) -> Any:
            """
            Sample a choice based on the given dependencies.

            :param dependencies: a dict of name/value pairs for all dependencies;
                                 each name has to match a parameter name of the build_choices callable
            :return: a new sample
            """
            self.validate_deps(dependencies)
            self.values = self.build_choices(**dependencies)
            return super().sample()
    
    
        def validate_deps(self, dependencies: Dict[str, Any]):
            """
            Validate a dict of dependencies by checking if all required parameters of
            the build_choices callable are available.
    
            :param dependencies: dict of name value pairs for dependencies
            """
    
            for name in self.dependency_names:
                assert (
                    name in dependencies.keys()
                ), f"Dependency {name} is missing from provided dependencies {dependencies}"
    
    
    class FloatIntervalSampler(ParamSampler):
        """
        Sample a value from a float interval.
        """

        def __init__(self, min_val: float, max_val: float):
            """
            Create a new float interval sampler.

            :param min_val: the minimal value to draw
            :param max_val: the maximal value to draw
            """
            super().__init__()
            self.min_val = min_val
            self.max_val = max_val

        def sample(self) -> float:
            return random.uniform(self.min_val, self.max_val)
    
    
    class IntIntervalSampler(ParamSampler):
        """
        Sample a value from an integer interval.
        """

        def __init__(self, min_val: int, max_val: int):
            """
            Create a new int interval sampler.

            :param min_val: the minimal value to draw
            :param max_val: the maximal value to draw (inclusive)
            """
            super().__init__()
            self.min_val = min_val
            self.max_val = max_val

        def sample(self) -> int:
            return random.randint(self.min_val, self.max_val)
    
    
    
    def scatter_correction_by_params(params: Dict[str, str]) -> bool:
        """
        Utility function because parameters are serialized as strings, so loading
        the JSON does not restore boolean values.
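
        Example (illustrative): scatter_correction_by_params({"scatter_correction": "True"}) evaluates to True.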
        """
    
        return params["scatter_correction"] == "True"
    
    
    def normalization_by_params(params: Dict[str, str]) -> Transform:
        """
        Utility function to load a normalization method from its serialized string representation.
        """
        _norm = params["normalization"]
        if "Gaussian" in _norm:
            return GaussianNormTransform()
        elif "Max" in _norm:
            return MaxNormTransform()
        elif "Mean" in _norm:
            return MeanNormTransform()
        else:
            raise ValueError(f"Could not find normalization for param {_norm}")
    
    
    
    class RandomSearch:
        """
        Abstract implementation of a random search.
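
        Example (illustrative):

            search = RandomSearch({
                "batch_size": ChoiceSampler([32, 64]),
                "lr": FloatIntervalSampler(0.0001, 0.1),
            })
            params = search.sample()  # e.g. {"batch_size": 64, "lr": 0.003}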
        """
    
        def __init__(self, param_sampler: Dict[str, ParamSampler]):
            """
            Create a new random search.

            :param param_sampler: a dict of name and parameter sampler pairs;
                                  all of them are sampled for a single random search run
            """
            self.param_sampler = param_sampler
    
        def sample(self) -> Dict[str, Any]:
            """
            Sample all parameters.
            Dependent choice samplers receive all previously sampled parameters as
            dependencies, so samplers have to be registered in dependency order.

            :return: a dictionary of name and drawn parameter pairs
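
            Example (illustrative): if "batch_size" depends on "patch_size", the
            "patch_size" sampler has to be registered before the "batch_size" one.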
            """
    
            _params = {}
            for key, sampler in self.param_sampler.items():
                if isinstance(sampler, DependentChoiceSampler):
                    param = sampler.sample(_params)
                else:
                    param = sampler.sample()
                _params[key] = param
            return _params
    
    
        def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str:
            """
            Serialize a set of parameters to JSON and optionally dump them into a file.
    
            :param params: dict of params to be serialized
            :param filename: optional filename where the json is dumped
            :return: the params as a json representation
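
            Example (illustrative): serialize_params({"lr": 0.01}) returns the
            string '{\n  "lr": "0.01"\n}' - note that all values become strings.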
            """
    
            # parameter values may be arbitrary objects (e.g. transforms or losses),
            # so they are stringified and their newlines/indentation flattened
            _params = {}
            for key, value in params.items():
                _params[key] = str(value).replace("\n", "").replace("    ", ", ")
    
            _str = json.dumps(_params, indent=2)
            if filename is not None:
                with open(filename, mode="w") as f:
                    f.write(_str)
            return _str
    
    
    def validate_and_make_directory(_dir: str):
        """
        Utility function to validate that a directory exists and is empty.

        If it does not exist, it is created.
        """
        if not os.path.exists(_dir):
            os.mkdir(_dir)
            return

        if len(os.listdir(_dir)) > 0:
            print(f"Directory {_dir} exists and is unexpectedly not empty!")
            sys.exit(1)
    
    
    class RandomSearchCGAN(RandomSearch):
        """
        Implementation of a random search for cGAN training.
        """

        def __init__(self, iterations: int, logger=None):
            super().__init__({})

            self.dataset_dir = "data/second"
            self.iterations = iterations
            self.device = (
                torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
            )
            self.n_slices = 32

            self.dir = "cgan_random_search"
            validate_and_make_directory(self.dir)
    
    
            self.dir_train = os.path.join(self.dir, "train_data")
            self.logger = (
                logger
                if logger is not None
                else get_logger(
                    logfile=os.path.join(self.dir, "search.log"),
                    loglevel="INFO",
                    name=RandomSearchCGAN.__name__,
                )
            )
            self.training: Optional[cGANTraining] = None

            # dataset parameters
            self.param_sampler["patch_size"] = ChoiceSampler([32, 64])
            self.param_sampler["patch_offset"] = ChoiceSampler([0])
            self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100)
            self.param_sampler["scatter_correction"] = ChoiceSampler([False])
            self.param_sampler["shuffle"] = ChoiceSampler([False, True])
            self.param_sampler["normalization"] = ChoiceSampler(
                [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
            )
            # None means that no pad/crop transform is applied; it is filtered out in _setup_run
            self.param_sampler["pad_crop"] = ChoiceSampler(
                [None, PadCropTranform(dim=3, size=self.n_slices)]
            )

            # model parameters
            self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"])

            def discriminator_conv_features(discriminator_type: str, **kwargs):
                if discriminator_type == "class":
                    return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]]
                else:
                    return [[32, 64, 128, 256], [64, 128, 256, 512]]

            self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(
                discriminator_conv_features
            )
            self.param_sampler["generator_features"] = ChoiceSampler(
                [
                    [64, 128, 256, 512],
                    [32, 64, 128, 256, 512],
                ]
            )
    
            # training parameters
            self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90])
    
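            # larger patches get a smaller batch size, presumably to bound GPU memory usage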
            def batch_size(patch_size: int, **kwargs):
                return [32] if patch_size == 64 else [64]

            # self.param_sampler["batch_size"] = ChoiceSampler([32, 64])
            self.param_sampler["batch_size"] = DependentChoiceSampler(batch_size)
    
            self.param_sampler["lr"] = FloatIntervalSampler(0.0001, 0.1)
            self.param_sampler["lr_decay"] = ChoiceSampler([False, True])
            self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
            self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
            self.param_sampler["criterion_dist"] = ChoiceSampler(
                [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
            )
            self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
            self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])
    
    
        def run(self) -> float:
            """
            Run the random search and return the minimal NMAE achieved.
            """
            nmae_min = sys.maxsize
            for i in range(1, self.iterations + 1):
                self.logger.info(f"Train iteration {i}")

                # re-seed all random number generators so that each iteration is reproducible
                seed = random.randint(0, 2**32 - 1)
                random.seed(seed)
                torch.manual_seed(seed)
                np.random.seed(seed)
                self.logger.info(f"Random seed for iteration {i} is {seed}")
    
                self._setup_run(i)
                self.training.run()
    
                nmae = self.eval_run()
                self.logger.info(f"Iteration {i} has NMAE {nmae:.6f}")
                if nmae < nmae_min:
                    self.logger.info(f"New best run at iteration {i}")
                    nmae_min = nmae
                self._cleanup_run(i, link_best=(nmae_min == nmae))
            return nmae_min
    
        def eval_run(self) -> float:
            """
            Evaluate a run by computing the mean NMAE of the generator with the
            minimal validation loss on the validation split.
            """
            self.logger.debug("Perform evaluation ...")
            torch.set_grad_enabled(False)

            weights_file = os.path.join(self.training.snapshot_dir, "val_min_generator.pth")
            self.logger.debug(f"Load weights from {weights_file}")
            model = self.training.generator.eval()
            model.load_state_dict(torch.load(weights_file, map_location=self.device))

            transform_normalization = SequenceTransform(
                [self.params["normalization"], PadCropTranform(dim=3, size=self.n_slices)]
            )
            dataset = MuMapDataset(
                self.dataset_dir,
                split_name="validation",
                transform_normalization=transform_normalization,
                scatter_correction=self.params["scatter_correction"],
            )

            values = compute_measures(dataset, model)
            values.to_csv(os.path.join(self.dir_train, "measures.csv"), index=False)

            # re-enable gradients so that subsequent training iterations are unaffected
            torch.set_grad_enabled(True)
            return values["NMAE"].mean()
    
        def _setup_run(self, iteration: int):
            """
            Sample parameters and initialize dataset, models and training for an iteration.
            """
            self.logger.debug("Create directories ...")
            validate_and_make_directory(self.dir_train)
            snapshot_dir = os.path.join(self.dir_train, "snapshots")
            validate_and_make_directory(snapshot_dir)
    
            self.params = self.sample()
            params_file = os.path.join(self.dir_train, "params.json")
            self.serialize_params(self.params, params_file)
            self.logger.debug(f"Store params at {params_file}")
    
            logfile = os.path.join(self.dir_train, "train.log")
            logger = get_logger(logfile, loglevel="INFO", name=cGANTraining.__name__)
            self.logger.debug(f"Training logs to {logfile}")
    
            transforms = [self.params["normalization"], self.params["pad_crop"]]
            transforms = list(filter(lambda transform: transform is not None, transforms))
            transform_normalization = SequenceTransform(transforms)
    
    
            self.logger.debug("Init dataset ...")
            dataset = MuMapPatchDataset(
                self.dataset_dir,
                patches_per_image=self.params["patch_number"],
                patch_size=self.params["patch_size"],
                patch_size_z=self.n_slices,
                patch_offset=self.params["patch_offset"],
                shuffle=self.params["shuffle"],
                transform_normalization=transform_normalization,
                scatter_correction=self.params["scatter_correction"],
                logger=logger,
            )
    
            self.logger.debug("Init discriminator ...")
            input_size = (self.n_slices, self.params["patch_size"], self.params["patch_size"])
    
            if self.params["discriminator_type"] == "class":
                discriminator = Discriminator(
                    in_channels=2,
                    input_size=input_size,
                    conv_features=self.params["discriminator_conv_features"],
                )
            else:
                discriminator = PatchDiscriminator(
                    in_channels=2, features=self.params["discriminator_conv_features"]
                )
    
            discriminator = discriminator.to(self.device)
    
            optimizer = torch.optim.Adam(
                discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
            )
            lr_scheduler = (
                torch.optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=self.params["lr_decay_epoch"],
                    gamma=self.params["lr_decay_factor"],
                )
                if self.params["lr_decay"]
                else None
            )
    
            params_d = DiscriminatorParams(
                model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
            )
    
            self.logger.debug("Init generator ...")
    
            generator = UNet(in_channels=1, features=self.params["generator_features"])
    
            generator = generator.to(self.device)
            optimizer = torch.optim.Adam(
                generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
            )
            lr_scheduler = (
                torch.optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=self.params["lr_decay_epoch"],
                    gamma=self.params["lr_decay_factor"],
                )
                if self.params["lr_decay"]
                else None
            )
    
            params_g = GeneratorParams(
                model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
            )
    
            self.logger.debug("Init training ...")
            self.training = cGANTraining(
                epochs=self.params["epochs"],
                dataset=dataset,
                batch_size=self.params["batch_size"],
                device=self.device,
                snapshot_dir=snapshot_dir,
                snapshot_epoch=self.params["epochs"],
                logger=logger,
                params_generator=params_g,
                params_discriminator=params_d,
                loss_func_dist=self.params["criterion_dist"],
                weight_criterion_dist=self.params["weight_criterion_dist"],
                weight_criterion_adv=self.params["weight_criterion_adv"],
                early_stopping=10,
            )
    
        def _cleanup_run(self, iteration: int, link_best: bool):
            """
            Move the training directory of an iteration to its final, zero-padded
            location and optionally update the "best" symlink.
            """
            dir_from = self.dir_train
            _dir = f"{iteration:0{len(str(self.iterations))}d}"
            dir_to = os.path.join(self.dir, _dir)
            self.logger.debug(f"Move iteration {iteration} from {dir_from} to {dir_to}")
            shutil.move(dir_from, dir_to)
    
            if link_best:
                linkfile = os.path.join(self.dir, "best")
                if os.path.exists(linkfile):
                    os.unlink(linkfile)
                os.symlink(_dir, linkfile)
    
    
    if __name__ == "__main__":
        random_search = RandomSearchCGAN(iterations=50)
        random_search.run()