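"""Random hyperparameter search for cGAN-based μ-map prediction.

Parameter samplers (fixed choices, choices depending on other parameters and
numeric intervals) are combined by RandomSearchCGAN, which repeatedly samples
a training configuration, trains a UNet generator against a Discriminator on
patch data, evaluates the generator with NMAE/MSE on the validation split and
keeps a symlink to the best run.
"""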
import json
import os
import random
import shutil
import sys
from typing import Any, Callable, Dict, List, Optional

import torch
import pandas as pd

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import (
    MeanNormTransform,
    MaxNormTransform,
    GaussianNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
from mu_map.eval.measures import nmae, mse
from mu_map.models.discriminator import Discriminator
from mu_map.models.unet import UNet
from mu_map.training.cgan import cGANTraining, TrainingParams
from mu_map.training.loss import WeightedLoss
from mu_map.logging import get_logger


class ParamSampler:
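    """Base class of all random parameter samplers."""
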
    def sample(self) -> Any:
        raise NotImplementedError()


class ChoiceSampler(ParamSampler):
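    """Sample a parameter value uniformly from a fixed list of choices."""
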
    def __init__(self, values: List[Any]):
        super().__init__()
        self.values = values

    def sample(self) -> Any:
        return random.choice(self.values)


class DependentChoiceSampler(ChoiceSampler):
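    """Choice sampler whose available values depend on previously sampled parameters.

    The names of the required dependencies are read from the parameter
    annotations of the build_choices callable, which re-builds the list of
    choices every time sample is called.
    """
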
    def __init__(self, build_choices: Callable[..., List[Any]]):
        super().__init__(values=[])

        self.build_choices = build_choices
        # dependency names are the annotated parameters of build_choices
        self.dependency_names = [
            name for name in build_choices.__annotations__ if name != "return"
        ]

    def sample(self, dependencies: Dict[str, Any]) -> Any:
        self.validate_deps(dependencies)
        self.values = self.build_choices(**dependencies)
        return super().sample()

    def validate_deps(self, dependencies: Dict[str, Any]) -> None:
        for name in self.dependency_names:
            assert (
                name in dependencies
            ), f"Dependency {name} is missing from provided dependencies {dependencies}"

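# Example (illustrative only, not used by the search below): a sampler whose
# choices depend on a previously sampled "patch_size" parameter. The dependency
# is declared through the parameter annotation of the build_choices callable:
#
#     def build_offset_choices(patch_size: int) -> List[int]:
#         return [0, patch_size // 2]
#
#     offset_sampler = DependentChoiceSampler(build_offset_choices)
#     offset_sampler.sample({"patch_size": 32})  # -> 0 or 16
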

class FloatIntervalSampler(ParamSampler):
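    """Sample a float uniformly from the interval [min_val, max_val]."""
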
    def __init__(self, min_val: float, max_val: float):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> float:
        return random.uniform(self.min_val, self.max_val)


class IntIntervalSampler(ParamSampler):
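    """Sample an integer uniformly from the interval [min_val, max_val] (both bounds inclusive)."""
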
    def __init__(self, min_val: int, max_val: int):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> int:
        return random.randint(self.min_val, self.max_val)


def scatter_correction_by_params(params: Dict[str, str]) -> bool:
    """Read the scatter correction flag back from serialized (stringified) params."""
    return params["scatter_correction"] == "True"


def normalization_by_params(params: Dict[str, str]) -> Transform:
    """Re-create the normalization transform from serialized (stringified) params."""
    _norm = params["normalization"]
    if "Gaussian" in _norm:
        return GaussianNormTransform()
    elif "Max" in _norm:
        return MaxNormTransform()
    elif "Mean" in _norm:
        return MeanNormTransform()
    else:
        raise ValueError(f"Could not find normalization for param {_norm}")


class RandomSearch:
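    """Generic random hyperparameter search over a dict of named parameter samplers."""
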
    def __init__(self, param_sampler: Dict[str, ParamSampler]):
        self.param_sampler = param_sampler

    def sample(self) -> Dict[str, Any]:
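        """Sample a value for every registered parameter.

        Samplers are evaluated in insertion order so that dependent samplers
        receive all previously sampled parameters as their dependencies.
        """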
        _params = {}
        for key, sampler in self.param_sampler.items():
            if isinstance(sampler, DependentChoiceSampler):
                param = sampler.sample(_params)
            else:
                param = sampler.sample()
            _params[key] = param
        return _params

    def serialize_params(
        self, params: Dict[str, Any], filename: Optional[str] = None
    ) -> str:
        """Serialize sampled params to a JSON string and optionally write it to a file."""
        _params = {}
        for key, value in params.items():
            _params[key] = str(value).replace("\n", "").replace("    ", ", ")

        _str = json.dumps(_params, indent=2)
        if filename is not None:
            with open(filename, mode="w") as f:
                f.write(_str)
        return _str


def validate_and_make_directory(_dir: str):
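    """Create the directory if it does not exist and abort if it exists but is not empty."""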
    if not os.path.exists(_dir):
        os.mkdir(_dir)
        return

    if len(os.listdir(_dir)) > 0:
        print(f"Directory {_dir} exists and is unexpectedly not empty!")
        sys.exit(1)


class RandomSearchCGAN(RandomSearch):
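    """Random hyperparameter search for the cGAN training of μ-map prediction."""
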
    def __init__(self, iterations: int, logger=None):
        super().__init__({})

        self.dataset_dir = "data/second"
        self.iterations = iterations
        self.dir = "cgan_random_search"
        validate_and_make_directory(self.dir)
        self.device = torch.device("cuda")
        self.params = {}

        self.dir_train = os.path.join(self.dir, "train_data")
        self.logger = (
            logger
            if logger is not None
            else get_logger(
                logfile=os.path.join(self.dir, "search.log"),
                loglevel="INFO",
                name=RandomSearchCGAN.__name__,
            )
        )
        self.training: Optional[cGANTraining] = None

        # dataset params
        self.param_sampler["patch_size"] = ChoiceSampler([32])
        self.param_sampler["patch_offset"] = ChoiceSampler([0])
        self.param_sampler["patch_number"] = ChoiceSampler([100])
        # self.param_sampler["scatter_correction"] = ChoiceSampler([True, False])
        self.param_sampler["scatter_correction"] = ChoiceSampler([True])
        self.param_sampler["shuffle"] = ChoiceSampler([False, True])
        self.param_sampler["normalization"] = ChoiceSampler(
            [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
        )
        self.param_sampler["pad_crop"] = ChoiceSampler(
            [None, PadCropTranform(dim=3, size=32)]
        )

        # training params
        self.param_sampler["epochs"] = ChoiceSampler([100])
        self.param_sampler["batch_size"] = ChoiceSampler([64])
        # self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
        self.param_sampler["lr"] = ChoiceSampler([0.001])
        self.param_sampler["lr_decay"] = ChoiceSampler([False])
        self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
        self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
        self.param_sampler["criterion_dist"] = ChoiceSampler(
            [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
        )
        # self.param_sampler["weight_criterion_dist"] = FloatIntervalSampler(1.0, 100.0)
        self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
        # self.param_sampler["weight_criterion_adv"] = FloatIntervalSampler(1.0, 100.0)
        self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])

    def run(self):
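        """Run all search iterations and return the best (lowest) validation NMAE."""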
        nmae_min = float("inf")
        for i in range(1, self.iterations + 1):
            self.logger.info(f"Train iteration {i}")

            seed = random.randint(0, 2**32 - 1)
            random.seed(seed)
            self.logger.info(f"Random seed for iteration {i} is {seed}")

            self._setup_run(i)
            self.training.run()

            _nmae = self.eval_run()
            self.logger.info(f"Iteration {i} has NMAE {_nmae:.6f}")
            if _nmae < nmae_min:
                self.logger.info(f"New best run at iteration {i}")
                nmae_min = _nmae
            self._cleanup_run(i, link_best=(nmae_min == _nmae))
        return nmae_min

    def eval_run(self):
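        """Evaluate the trained generator on the validation split and return its mean NMAE."""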
        self.logger.debug("Perform evaluation ...")
        torch.set_grad_enabled(False)

        weights_file = os.path.join(self.training.snapshot_dir, "val_min_generator.pth")
        self.logger.debug(f"Load weights from {weights_file}")
        model = self.training.params_g.model.eval()
        model.load_state_dict(torch.load(weights_file, map_location=self.device))

        transform_normalization = SequenceTransform(
            [self.params["normalization"], PadCropTranform(dim=3, size=32)]
        )
        dataset = MuMapDataset(
            self.dataset_dir,
            split_name="validation",
            transform_normalization=transform_normalization,
            scatter_correction=self.params["scatter_correction"],
        )

        measures = {"NMAE": nmae, "MSE": mse}
        values = pd.DataFrame({measure: [] for measure in measures})
        for i, (recon, mu_map) in enumerate(dataset):
            print(
                f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}",
                end="\r",
            )
            prediction = model(recon.unsqueeze(dim=0).to(self.device))
            prediction = prediction.squeeze().cpu().numpy()
            mu_map = mu_map.squeeze().cpu().numpy()

            row = pd.DataFrame(
                {
                    measure: [compute(prediction, mu_map)]
                    for measure, compute in measures.items()
                }
            )
            values = pd.concat((values, row), ignore_index=True)
        print(" " * 100, end="\r")

        values.to_csv(os.path.join(self.dir_train, "measures.csv"), index=False)

        # re-enable gradients so that subsequent training iterations are unaffected
        torch.set_grad_enabled(True)
        return values["NMAE"].mean()

    def _setup_run(self, iteration: int):
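        """Sample a new parameter set and set up data loaders, models and the cGAN training for one iteration."""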
        self.logger.debug("Create directories...")
        validate_and_make_directory(self.dir_train)
        snapshot_dir = os.path.join(self.dir_train, "snapshots")
        validate_and_make_directory(snapshot_dir)

        self.params = self.sample()
        params_file = os.path.join(self.dir_train, "params.json")
        self.serialize_params(self.params, params_file)
        self.logger.debug(f"Store params at {params_file}")

        logfile = os.path.join(self.dir_train, "train.log")
        logger = get_logger(logfile, loglevel="INFO", name=cGANTraining.__name__)
        self.logger.debug(f"Training logs to {logfile}")

        transforms = [self.params["normalization"], self.params["pad_crop"]]
        transforms = list(filter(lambda transform: transform is not None, transforms))
        transform_normalization = SequenceTransform(transforms)

        self.logger.debug("Init data loaders ....")
        data_loaders = {}
        for split in ["train", "validation"]:
            dataset = MuMapPatchDataset(
                self.dataset_dir,
                patches_per_image=self.params["patch_number"],
                patch_size=self.params["patch_size"],
                patch_offset=self.params["patch_offset"],
                shuffle=self.params["shuffle"] if split == "train" else False,
                split_name=split,
                transform_normalization=transform_normalization,
                scatter_correction=self.params["scatter_correction"],
                logger=logger,
            )
            data_loader = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=self.params["batch_size"],
                shuffle=split == "train",
                pin_memory=True,
                num_workers=1,
            )
            data_loaders[split] = data_loader

        self.logger.debug("Init discriminator ....")
        discriminator = Discriminator(
            in_channels=2, input_size=self.params["patch_size"]
        )
        discriminator = discriminator.to(self.device)
        optimizer = torch.optim.Adam(
            discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
        )
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=self.params["lr_decay_epoch"],
                gamma=self.params["lr_decay_factor"],
            )
            if self.params["lr_decay"]
            else None
        )
        params_d = TrainingParams(
            model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )

        self.logger.debug("Init generator ....")
        features = [64, 128, 256, 512]
        generator = UNet(in_channels=1, features=features)
        generator = generator.to(self.device)
        optimizer = torch.optim.Adam(
            generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
        )
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=self.params["lr_decay_epoch"],
                gamma=self.params["lr_decay_factor"],
            )
            if self.params["lr_decay"]
            else None
        )
        params_g = TrainingParams(
            model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )

        self.logger.debug("Init training ....")
        self.training = cGANTraining(
            data_loaders=data_loaders,
            epochs=self.params["epochs"],
            device=self.device,
            snapshot_dir=snapshot_dir,
            snapshot_epoch=self.params["epochs"],
            logger=logger,
            params_generator=params_g,
            params_discriminator=params_d,
            loss_func_dist=self.params["criterion_dist"],
            weight_criterion_dist=self.params["weight_criterion_dist"],
            weight_criterion_adv=self.params["weight_criterion_adv"],
        )

    def _cleanup_run(self, iteration: int, link_best: bool):
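        """Move the finished training directory to its numbered location and, if this run is the best so far, point the 'best' symlink at it."""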
        dir_from = self.dir_train
        _dir = f"{iteration:0{len(str(self.iterations))}d}"
        dir_to = os.path.join(self.dir, _dir)
        self.logger.debug(f"Move iteration {iteration} from {dir_from} to {dir_to}")
        shutil.move(dir_from, dir_to)

        if link_best:
            linkfile = os.path.join(self.dir, "best")
            if os.path.exists(linkfile):
                os.unlink(linkfile)
            os.symlink(_dir, linkfile)


if __name__ == "__main__":
    random_search = RandomSearchCGAN(iterations=6)
    random_search.run()