"""Random search over hyperparameters for a cGAN training."""

import json
import os
import random
import shutil
import sys
from typing import Any, Callable, Dict, List, Optional

import torch
import pandas as pd

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import (
    MeanNormTransform,
    MaxNormTransform,
    GaussianNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
from mu_map.eval.measures import nmae, mse
from mu_map.models.discriminator import Discriminator
from mu_map.models.unet import UNet
from mu_map.training.cgan import cGANTraining, TrainingParams
from mu_map.training.loss import WeightedLoss
from mu_map.logging import get_logger


class ParamSampler:
    """Base class for random samplers of hyperparameter values."""

    def sample(self) -> Any:
        raise NotImplementedError


class ChoiceSampler(ParamSampler):
    """Sample uniformly from a fixed list of values."""

    def __init__(self, values: List[Any]):
        super().__init__()
        self.values = values

    def sample(self) -> Any:
        idx = random.randrange(0, len(self.values))
        return self.values[idx]


class DependentChoiceSampler(ChoiceSampler):
    """Sample from choices that are built from previously sampled parameters."""

    def __init__(self, build_choices: Callable[..., List[Any]]):
        super().__init__(values=[])

        self.build_choices = build_choices
        # the required dependencies are the annotated parameters of build_choices
        self.dependency_names = [
            name for name in build_choices.__annotations__ if name != "return"
        ]

    def sample(self, dependencies: Dict[str, Any]) -> Any:
        self.validate_deps(dependencies)
        self.values = self.build_choices(**dependencies)
        return super().sample()

    def validate_deps(self, dependencies: Dict[str, Any]) -> None:
        for name in self.dependency_names:
            assert (
                name in dependencies.keys()
            ), f"Dependency {name} is missing from provided dependencies {dependencies}"


class FloatIntervalSampler(ParamSampler):
    """Sample a float uniformly from the interval [min_val, max_val]."""

    def __init__(self, min_val: float, max_val: float):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> float:
        return random.uniform(self.min_val, self.max_val)


class IntIntervalSampler(ParamSampler):
    """Sample an int uniformly from the interval [min_val, max_val]."""

    def __init__(self, min_val: int, max_val: int):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> int:
        return random.randint(self.min_val, self.max_val)


def scatter_correction_by_params(params: Dict[str, str]) -> bool:
    """Parse the scatter correction flag from serialized parameters."""
    return params["scatter_correction"] == "True"


def normalization_by_params(params: Dict[str, str]):
    """Restore a normalization transform from serialized parameters."""
    _norm = params["normalization"]
    if "Gaussian" in _norm:
        return GaussianNormTransform()
    elif "Max" in _norm:
        return MaxNormTransform()
    elif "Mean" in _norm:
        return MeanNormTransform()
    else:
        raise ValueError(f"Could not find normalization for param {_norm}")


class RandomSearch:
    """Randomly sample hyperparameters from a set of parameter samplers."""

    def __init__(self, param_sampler: Dict[str, ParamSampler]):
        self.param_sampler = param_sampler

    def sample(self) -> Dict[str, Any]:
        _params = {}
        for key, sampler in self.param_sampler.items():
            # dependent samplers receive all previously sampled parameters
            if isinstance(sampler, DependentChoiceSampler):
                param = sampler.sample(_params)
            else:
                param = sampler.sample()
            _params[key] = param
        return _params

    def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str:
        _params = {}
        for key, value in params.items():
            _params[key] = str(value).replace("\n", "").replace(" ", ", ")

        _str = json.dumps(_params, indent=2)
        if filename is not None:
            with open(filename, mode="w") as f:
                f.write(_str)
        return _str


def validate_and_make_directory(_dir: str):
    """Create the directory if missing and make sure an existing one is empty."""
    if not os.path.exists(_dir):
        os.mkdir(_dir)
        return

    if len(os.listdir(_dir)) > 0:
        print(f"Directory {_dir} exists and is unexpectedly not empty!")
        sys.exit(1)


class RandomSearchCGAN(RandomSearch):
    """Random hyperparameter search for a cGAN training."""

    def __init__(self, iterations: int, logger=None):
        super().__init__({})

        self.dataset_dir = "data/second"
        self.iterations = iterations
        self.dir = "cgan_random_search"
        validate_and_make_directory(self.dir)

        self.device = torch.device("cuda")
        self.params = {}

        self.dir_train = os.path.join(self.dir, "train_data")
        self.logger = (
            logger
            if logger is not None
            else get_logger(
                logfile=os.path.join(self.dir, "search.log"),
                loglevel="INFO",
                name=RandomSearchCGAN.__name__,
            )
        )
        self.training: Optional[cGANTraining] = None

        # dataset params
        self.param_sampler["patch_size"] = ChoiceSampler([32])
        self.param_sampler["patch_offset"] = ChoiceSampler([0])
        self.param_sampler["patch_number"] = ChoiceSampler([100])
        # self.param_sampler["scatter_correction"] = ChoiceSampler([True, False])
        self.param_sampler["scatter_correction"] = ChoiceSampler([True])
        self.param_sampler["shuffle"] = ChoiceSampler([False, True])
        self.param_sampler["normalization"] = ChoiceSampler(
            [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
        )
        self.param_sampler["pad_crop"] = ChoiceSampler(
            [None, PadCropTranform(dim=3, size=32)]
        )

        # training params
        self.param_sampler["epochs"] = ChoiceSampler([100])
        self.param_sampler["batch_size"] = ChoiceSampler([64])
        # self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
        self.param_sampler["lr"] = ChoiceSampler([0.001])
        self.param_sampler["lr_decay"] = ChoiceSampler([False])
        self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
        self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
        self.param_sampler["criterion_dist"] = ChoiceSampler(
            [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
        )
        # self.param_sampler["weight_criterion_dist"] = FloatIntervalSampler(1.0, 100.0)
        self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
        # self.param_sampler["weight_criterion_adv"] = FloatIntervalSampler(1.0, 100.0)
        self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])

    def run(self):
        """Run the random search and return the best (lowest) NMAE achieved."""
        nmae_min = sys.maxsize
        for i in range(1, self.iterations + 1):
            self.logger.info(f"Train iteration {i}")

            seed = random.randint(0, 2**32 - 1)
            random.seed(seed)
            self.logger.info(f"Random seed for iteration {i} is {seed}")

            self._setup_run(i)
            self.training.run()

            nmae_value = self.eval_run()
            self.logger.info(f"Iteration {i} has NMAE {nmae_value:.6f}")
            if nmae_value < nmae_min:
                self.logger.info(f"New best run at iteration {i}")
                nmae_min = nmae_value

            self._cleanup_run(i, link_best=(nmae_min == nmae_value))
        return nmae_min

    def eval_run(self):
        """Evaluate the generator of the current run on the validation split."""
        self.logger.debug("Perform evaluation ...")
        torch.set_grad_enabled(False)

        weights_file = os.path.join(self.training.snapshot_dir, "val_min_generator.pth")
        self.logger.debug(f"Load weights from {weights_file}")
        model = self.training.params_g.model.eval()
        model.load_state_dict(torch.load(weights_file, map_location=self.device))

        transform_normalization = SequenceTransform(
            [self.params["normalization"], PadCropTranform(dim=3, size=32)]
        )
        dataset = MuMapDataset(
            self.dataset_dir,
            split_name="validation",
            transform_normalization=transform_normalization,
            scatter_correction=self.params["scatter_correction"],
        )

        measures = {"NMAE": nmae, "MSE": mse}
        values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys())))
        for i, (recon, mu_map) in enumerate(dataset):
            print(
                f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}",
                end="\r",
            )

            prediction = model(recon.unsqueeze(dim=0).to(self.device))
            prediction = prediction.squeeze().cpu().numpy()
            mu_map = mu_map.squeeze().cpu().numpy()

            row = pd.DataFrame(
                dict(
                    map(
                        lambda item: (item[0], [item[1](prediction, mu_map)]),
                        measures.items(),
                    )
                )
            )
            values = pd.concat((values, row), ignore_index=True)
        print(" " * 100, end="\r")
"measures.csv"), index=False) return values["NMAE"].mean() def _setup_run(self, iteration: int): self.logger.debug("Create directories...") validate_and_make_directory(self.dir_train) snapshot_dir = os.path.join(self.dir_train, "snapshots") validate_and_make_directory(snapshot_dir) self.params = self.sample() params_file = os.path.join(self.dir_train, "params.json") self.serialize_params(self.params, params_file) self.logger.debug(f"Store params at {params_file}") logfile = os.path.join(self.dir_train, "train.log") logger = get_logger(logfile, loglevel="INFO", name=cGANTraining.__name__) self.logger.debug(f"Training logs to {logfile}") transforms = [self.params["normalization"], self.params["pad_crop"]] transforms = list(filter(lambda transform: transform is not None, transforms)) transform_normalization = SequenceTransform(transforms) self.logger.debug(f"Init data loaders ....") data_loaders = {} for split in ["train", "validation"]: dataset = MuMapPatchDataset( self.dataset_dir, patches_per_image=self.params["patch_number"], patch_size=self.params["patch_size"], patch_offset=self.params["patch_offset"], shuffle=self.params["shuffle"] if split == "train" else False, split_name=split, transform_normalization=transform_normalization, scatter_correction=self.params["scatter_correction"], logger=logger, ) data_loader = torch.utils.data.DataLoader( dataset=dataset, batch_size=self.params["batch_size"], shuffle=split == "train", pin_memory=True, num_workers=1, ) data_loaders[split] = data_loader self.logger.debug(f"Init discriminator ....") discriminator = Discriminator( in_channels=2, input_size=self.params["patch_size"] ) discriminator = discriminator.to(self.device) optimizer = torch.optim.Adam( discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999) ) lr_scheduler = ( torch.optim.lr_scheduler.StepLR( optimizer, step_size=self.params["lr_decay_epoch"], gamma=self.params["lr_decay_factor"], ) if self.params["lr_decay"] else None ) params_d = TrainingParams( model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler ) self.logger.debug(f"Init generator ....") features = [64, 128, 256, 512] generator = UNet(in_channels=1, features=features) generator = generator.to(self.device) optimizer = torch.optim.Adam( generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999) ) lr_scheduler = ( torch.optim.lr_scheduler.StepLR( optimizer, step_size=self.params["lr_decay_epoch"], gamma=self.params["lr_decay_factor"], ) if self.params["lr_decay"] else None ) params_g = TrainingParams( model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler ) self.logger.debug(f"Init training ....") self.training = cGANTraining( data_loaders=data_loaders, epochs=self.params["epochs"], device=self.device, snapshot_dir=snapshot_dir, snapshot_epoch=self.params["epochs"], logger=logger, params_generator=params_g, params_discriminator=params_d, loss_func_dist=self.params["criterion_dist"], weight_criterion_dist=self.params["weight_criterion_dist"], weight_criterion_adv=self.params["weight_criterion_adv"], ) def _cleanup_run(self, iteration: int, link_best: bool): dir_from = self.dir_train _dir = f"{iteration:0{len(str(self.iterations))}d}" dir_to = os.path.join(self.dir, _dir) self.logger.debug(f"Move iteration {iteration} from {dir_from} to {dir_to}") shutil.move(dir_from, dir_to) if link_best: linkfile = os.path.join(self.dir, "best") if os.path.exists(linkfile): os.unlink(linkfile) os.symlink(_dir, linkfile) if __name__ == "__main__": random_search = RandomSearchCGAN(iterations=6) 
    random_search.run()
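
# Minimal usage sketch (hypothetical, not executed by this script) of
# DependentChoiceSampler: the choices for one parameter can be derived from
# previously sampled parameters, e.g. a patch offset that depends on the
# sampled patch size. The helper name build_offsets is illustrative only.
#
#     def build_offsets(patch_size: int):
#         # annotated parameter names define the required dependencies
#         return [0, patch_size // 2]
#
#     sampler = DependentChoiceSampler(build_offsets)
#     sampler.sample({"patch_size": 32})  # returns 0 or 16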