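    """Random hyperparameter search for the cGAN-based mu map prediction model.

    Defines small parameter sampler utilities and a RandomSearchCGAN class that
    repeatedly samples a training configuration, trains a UNet generator against
    a patch discriminator and keeps track of the run with the lowest NMAE on the
    validation split.
    """
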
    import json
    import os
    import random
    import shutil
    import sys
    from typing import Any, Callable, Dict, List, Optional
    
    import torch
    import pandas as pd
    
    from mu_map.dataset.default import MuMapDataset
    from mu_map.dataset.patches import MuMapPatchDataset
    from mu_map.dataset.normalization import (
        MeanNormTransform,
        MaxNormTransform,
        GaussianNormTransform,
    )
    from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
    from mu_map.eval.measures import nmae, mse
    from mu_map.models.discriminator import Discriminator
    from mu_map.models.unet import UNet
    from mu_map.training.cgan import cGANTraining, TrainingParams
    from mu_map.training.loss import WeightedLoss
    from mu_map.logging import get_logger
    
    
    class ParamSampler:
        """Base class of all parameter samplers."""

        def sample(self) -> Any:
            raise NotImplementedError()
    
    
    class ChoiceSampler(ParamSampler):
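        """Sampler that picks a value uniformly at random from a fixed list of choices."""
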
        def __init__(self, values: List[Any]):
            super().__init__()
            self.values = values
    
        def sample(self) -> Any:
            idx = random.randrange(0, len(self.values))
            return self.values[idx]
    
    
    class DependentChoiceSampler(ChoiceSampler):
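        """ChoiceSampler whose available choices depend on previously sampled parameters.

        The ``build_choices`` callable receives the required parameters as keyword
        arguments (their names are taken from its annotations) and returns the list
        of values to sample from.

        Illustrative example (names are made up for demonstration only):

            def build(patch_size: int):
                return [patch_size, 2 * patch_size]

            sampler = DependentChoiceSampler(build)
            sampler.sample({"patch_size": 16})  # returns either 16 or 32
        """
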
        def __init__(self, build_choices: Callable[[Any], List[Any]]):
            super().__init__(values=[])
    
            self.build_choices = build_choices
            # parameter names of build_choices; a return annotation is ignored
            self.dependency_names = [n for n in build_choices.__annotations__ if n != "return"]
    
        def sample(self, dependencies: Dict[str, Any]) -> Any:
            self.validate_deps(dependencies)
            self.values = self.build_choices(**dependencies)
            return super().sample()
    
        def validate_deps(self, dependencies: Dict[str, Any]) -> None:
            for name in self.dependency_names:
                assert (
                    name in dependencies.keys()
                ), f"Dependency {name} is missing from provided dependencies {dependencies}"
    
    
    class FloatIntervalSampler(ParamSampler):
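        """Sampler that draws a float uniformly from the interval [min_val, max_val]."""
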
        def __init__(self, min_val: float, max_val: float):
            super().__init__()
            self.min_val = min_val
            self.max_val = max_val
    
        def sample(self) -> float:
            return random.uniform(self.min_val, self.max_val)
    
    
    class IntIntervalSampler(ParamSampler):
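        """Sampler that draws an integer uniformly from [min_val, max_val], both bounds inclusive."""
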
        def __init__(self, min_val: int, max_val: int):
            super().__init__()
            self.min_val = min_val
            self.max_val = max_val
    
        def sample(self) -> int:
            return random.randint(self.min_val, self.max_val)
    
    
    
    def scatter_correction_by_params(params: Dict[str, str]):
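        """Recover the scatter correction flag from serialized (stringified) parameters."""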
        return params["scatter_correction"] == "True"
    
    
    def normalization_by_params(params: Dict[str, str]):
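        """Recover the normalization transform from serialized (stringified) parameters."""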
        _norm = params["normalization"]
        if "Gaussian" in _norm:
            return GaussianNormTransform()
        elif "Max" in _norm:
            return MaxNormTransform()
        elif "Mean" in _norm:
            return MeanNormTransform()
        else:
            raise ValueError(f"Could not find normalization for param {_norm}")
    
    
    
    class RandomSearch:
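        """Random search over a dictionary of named parameter samplers."""
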
        def __init__(self, param_sampler: Dict[str, ParamSampler]):
            self.param_sampler = param_sampler
    
        def sample(self):
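            """Sample one value for every registered parameter sampler.

            Samplers are evaluated in registration order and a DependentChoiceSampler
            receives all previously sampled parameters as its dependencies.
            """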
            _params = {}
            for key, sampler in self.param_sampler.items():
                if isinstance(sampler, DependentChoiceSampler):
                    param = sampler.sample(_params)
                else:
                    param = sampler.sample()
                _params[key] = param
            return _params
    
        def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None):
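            """Convert all parameters to strings and dump them as JSON, optionally writing to ``filename``."""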
            _params = {}
            for key, value in params.items():
                _params[key] = str(value).replace("\n", "").replace("    ", ", ")
    
            _str = json.dumps(_params, indent=2)
            if filename is not None:
                with open(filename, mode="w") as f:
                    f.write(_str)
            return _str
    
    
    def validate_and_make_directory(_dir: str):
        """Create the directory if it does not exist and abort if it exists but is not empty."""
        if not os.path.exists(_dir):
            os.mkdir(_dir)
            return

        if len(os.listdir(_dir)) > 0:
            print(f"Directory {_dir} exists and is unexpectedly not empty!")
            sys.exit(1)
    
    
    class RandomSearchCGAN(RandomSearch):
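        """Random hyperparameter search for cGAN training of the UNet generator.

        Each iteration samples a configuration, trains on patch datasets and
        evaluates the generator with the NMAE on the validation split. Every run
        is stored in a numbered sub-directory of ``cgan_random_search`` and the
        best run is symlinked as ``best``.
        """
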
        def __init__(self, iterations: int, logger=None):
            super().__init__({})
    
            self.dataset_dir = "data/second"
            self.iterations = iterations
            self.dir = "cgan_random_search"
            validate_and_make_directory(self.dir)
            self.device = torch.device("cuda")
            self.params = {}
    
            self.dir_train = os.path.join(self.dir, "train_data")
            self.logger = (
                logger
                if logger is not None
                else get_logger(
                    logfile=os.path.join(self.dir, "search.log"),
                    loglevel="INFO",
                    name=RandomSearchCGAN.__name__,
                )
            )
            self.training: Optional[cGANTraining] = None
    
            # dataset params
            self.param_sampler["patch_size"] = ChoiceSampler([32])
    
            self.param_sampler["patch_offset"] = ChoiceSampler([0])
            self.param_sampler["patch_number"] = ChoiceSampler([100])
    
            # self.param_sampler["scatter_correction"] = ChoiceSampler([True, False])
            self.param_sampler["scatter_correction"] = ChoiceSampler([True])
            self.param_sampler["shuffle"] = ChoiceSampler([False, True])
    
            self.param_sampler["normalization"] = ChoiceSampler(
                [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
            )
    
            self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=32)])
    
    
            # training params
            self.param_sampler["epochs"] = ChoiceSampler([100])
            self.param_sampler["batch_size"] = ChoiceSampler([64])
    
            # self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
            self.param_sampler["lr"] = ChoiceSampler([0.001])
            self.param_sampler["lr_decay"] = ChoiceSampler([False])
    
            self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
            self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
            self.param_sampler["criterion_dist"] = ChoiceSampler(
                [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
            )
    
            # self.param_sampler["weight_criterion_dist"] = FloatIntervalSampler(1.0, 100.0)
            self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
            # self.param_sampler["weight_criterion_adv"] = FloatIntervalSampler(1.0, 100.0)
            self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])
    
    
        def run(self):
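            """Run all search iterations and return the lowest NMAE achieved."""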
            nmae_min = sys.maxsize
            for i in range(1, self.iterations + 1):
                self.logger.info(f"Train iteration {i}")

                seed = random.randint(0, 2**32 - 1)
                random.seed(seed)
                self.logger.info(f"Random seed for iteration {i} is {seed}")

                self._setup_run(i)
                self.training.run()

                _nmae = self.eval_run()
                self.logger.info(f"Iteration {i} has NMAE {_nmae:.6f}")
                if _nmae < nmae_min:
                    self.logger.info(f"New best run at iteration {i}")
                    nmae_min = _nmae
                self._cleanup_run(i, link_best=(nmae_min == _nmae))
            return nmae_min
    
        def eval_run(self):
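            """Evaluate the trained generator on the validation split and return its mean NMAE."""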
            self.logger.debug("Perform evaluation ...")
            torch.set_grad_enabled(False)
    
            weights_file = os.path.join(self.training.snapshot_dir, "val_min_generator.pth")
            self.logger.debug(f"Load weights from {weights_file}")
            model = self.training.params_g.model.eval()
            model.load_state_dict(torch.load(weights_file, map_location=self.device))
    
            transform_normalization = SequenceTransform(
                [self.params["normalization"], PadCropTranform(dim=3, size=32)]
            )
            dataset = MuMapDataset(
                self.dataset_dir,
                split_name="validation",
                transform_normalization=transform_normalization,
                scatter_correction=self.params["scatter_correction"],
            )
    
            measures = {"NMAE": nmae, "MSE": mse}
            values = pd.DataFrame({name: [] for name in measures})
            for i, (recon, mu_map) in enumerate(dataset):
                print(
                    f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}",
                    end="\r",
                )
                prediction = model(recon.unsqueeze(dim=0).to(self.device))
                prediction = prediction.squeeze().cpu().numpy()
                mu_map = mu_map.squeeze().cpu().numpy()
    
                row = pd.DataFrame(
                    {name: [fn(prediction, mu_map)] for name, fn in measures.items()}
                )
                values = pd.concat((values, row), ignore_index=True)
            print(f" " * 100, end="\r")
    
            values.to_csv(os.path.join(self.dir_train, "measures.csv"), index=False)
            # re-enable gradients in case the next training iteration relies on the default
            torch.set_grad_enabled(True)
            return values["NMAE"].mean()
    
        def _setup_run(self, iteration: int):
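            """Sample a new configuration and set up directories, data loaders, models and the cGAN training."""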
            self.logger.debug("Create directories...")
            validate_and_make_directory(self.dir_train)
            snapshot_dir = os.path.join(self.dir_train, "snapshots")
            validate_and_make_directory(snapshot_dir)
    
            self.params = self.sample()
            params_file = os.path.join(self.dir_train, "params.json")
            self.serialize_params(self.params, params_file)
            self.logger.debug(f"Store params at {params_file}")
    
            logfile = os.path.join(self.dir_train, "train.log")
            logger = get_logger(logfile, loglevel="INFO", name=cGANTraining.__name__)
            self.logger.debug(f"Training logs to {logfile}")
    
            transforms = [self.params["normalization"], self.params["pad_crop"]]
            transforms = list(filter(lambda transform: transform is not None, transforms))
            transform_normalization = SequenceTransform(transforms)
    
            self.logger.debug(f"Init data loaders ....")
            data_loaders = {}
            for split in ["train", "validation"]:
                dataset = MuMapPatchDataset(
                    self.dataset_dir,
                    patches_per_image=self.params["patch_number"],
                    patch_size=self.params["patch_size"],
                    patch_offset=self.params["patch_offset"],
                    shuffle=self.params["shuffle"] if split == "train" else False,
                    split_name=split,
                    transform_normalization=transform_normalization,
                    scatter_correction=self.params["scatter_correction"],
                    logger=logger,
                )
                data_loader = torch.utils.data.DataLoader(
                    dataset=dataset,
                    batch_size=self.params["batch_size"],
                    shuffle=split == "train",
                    pin_memory=True,
                    num_workers=1,
                )
                data_loaders[split] = data_loader
    
            self.logger.debug(f"Init discriminator ....")
            discriminator = Discriminator(
                in_channels=2, input_size=self.params["patch_size"]
            )
            discriminator = discriminator.to(self.device)
            optimizer = torch.optim.Adam(
                discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
            )
            lr_scheduler = (
                torch.optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=self.params["lr_decay_epoch"],
                    gamma=self.params["lr_decay_factor"],
                )
                if self.params["lr_decay"]
                else None
            )
            params_d = TrainingParams(
                model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
            )
    
            self.logger.debug(f"Init generator ....")
            features = [64, 128, 256, 512]
            generator = UNet(in_channels=1, features=features)
            generator = generator.to(self.device)
            optimizer = torch.optim.Adam(
                generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
            )
            lr_scheduler = (
                torch.optim.lr_scheduler.StepLR(
                    optimizer,
                    step_size=self.params["lr_decay_epoch"],
                    gamma=self.params["lr_decay_factor"],
                )
                if self.params["lr_decay"]
                else None
            )
            params_g = TrainingParams(
                model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
            )
    
            self.logger.debug(f"Init training ....")
            self.training = cGANTraining(
                data_loaders=data_loaders,
                epochs=self.params["epochs"],
                device=self.device,
                snapshot_dir=snapshot_dir,
                snapshot_epoch=self.params["epochs"],
                logger=logger,
                params_generator=params_g,
                params_discriminator=params_d,
                loss_func_dist=self.params["criterion_dist"],
                weight_criterion_dist=self.params["weight_criterion_dist"],
                weight_criterion_adv=self.params["weight_criterion_adv"],
            )
    
        def _cleanup_run(self, iteration: int, link_best: bool):
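            """Move the finished run into a numbered directory and optionally point the ``best`` symlink at it."""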
            dir_from = self.dir_train
            _dir = f"{iteration:0{len(str(self.iterations))}d}"
            dir_to = os.path.join(self.dir, _dir)
            self.logger.debug(f"Move iteration {iteration} from {dir_from} to {dir_to}")
            shutil.move(dir_from, dir_to)
    
            if link_best:
                linkfile = os.path.join(self.dir, "best")
                if os.path.exists(linkfile):
                    os.unlink(linkfile)
                os.symlink(_dir, linkfile)
    
    
    if __name__ == "__main__":
        random_search = RandomSearchCGAN(iterations=6)
        random_search.run()