# random_search.py
import json
import os
import random
import shutil
import sys
from typing import Any, Callable, Dict, List, Optional

import torch
import pandas as pd

from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import (
    MeanNormTransform,
    MaxNormTransform,
    GaussianNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
from mu_map.eval.measures import nmae, mse
from mu_map.models.discriminator import Discriminator
from mu_map.models.unet import UNet
from mu_map.training.cgan import cGANTraining, TrainingParams
from mu_map.training.loss import WeightedLoss
from mu_map.logging import get_logger


class ParamSampler:
    """Interface for sampling a single hyperparameter value."""

    def sample(self) -> Any:
        raise NotImplementedError
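# Sketch of a custom sampler (hypothetical, not part of this script): a subclass only
# needs to implement sample(). For example, a log-uniform sampler for learning rates
# could look like this:
#
#   import math
#
#   class LogUniformSampler(ParamSampler):
#       def __init__(self, min_val: float, max_val: float):
#           super().__init__()
#           self.min_val = min_val
#           self.max_val = max_val
#
#       def sample(self) -> float:
#           return math.exp(random.uniform(math.log(self.min_val), math.log(self.max_val)))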


class ChoiceSampler(ParamSampler):
    """Sample uniformly from a fixed list of values."""

    def __init__(self, values: List[Any]):
        super().__init__()
        self.values = values

    def sample(self) -> Any:
        idx = random.randrange(0, len(self.values))
        return self.values[idx]


class DependentChoiceSampler(ChoiceSampler):
    """Choice sampler whose list of values depends on previously sampled parameters.

    The dependency names are taken from the annotations of the `build_choices`
    callable, which rebuilds the value list before every sample.
    """

    def __init__(self, build_choices: Callable[..., List[Any]]):
        super().__init__(values=[])

        self.build_choices = build_choices
        self.dependency_names = list(build_choices.__annotations__.keys())

    def sample(self, dependencies: Dict[str, Any]) -> Any:
        self.validate_deps(dependencies)
        self.values = self.build_choices(**dependencies)
        return super().sample()

    def validate_deps(self, dependencies: Dict[str, Any]) -> None:
        for name in self.dependency_names:
            assert (
                name in dependencies.keys()
            ), f"Dependency {name} is missing from provided dependencies {dependencies}"
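# Illustrative sketch (hypothetical, not part of the search space below): the
# dependency names of a DependentChoiceSampler come from the annotations of the
# wrapped callable, so the callable should annotate its parameters and omit a
# return annotation. A sampler for "patch_number" depending on "patch_size"
# could look like this:
#
#   def build_patch_numbers(patch_size: int):
#       return [50, 100] if patch_size >= 32 else [100, 200]
#
#   patch_number_sampler = DependentChoiceSampler(build_patch_numbers)
#   patch_number_sampler.sample({"patch_size": 32})  # choices built for size 32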


class FloatIntervalSampler(ParamSampler):
    def __init__(self, min_val: float, max_val: float):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> float:
        return random.uniform(self.min_val, self.max_val)


class IntIntervalSampler(ParamSampler):
    def __init__(self, min_val: int, max_val: int):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> int:
        return random.randint(self.min_val, self.max_val)


def scatter_correction_by_params(params: Dict[str, str]) -> bool:
    """Parse the scatter correction flag from serialized parameters."""
    return params["scatter_correction"] == "True"


def normalization_by_params(params: Dict[str, str]) -> Transform:
    """Re-create a normalization transform from serialized parameters."""
    _norm = params["normalization"]
    if "Gaussian" in _norm:
        return GaussianNormTransform()
    elif "Max" in _norm:
        return MaxNormTransform()
    elif "Mean" in _norm:
        return MeanNormTransform()
    else:
        raise ValueError(f"Could not find normalization for param {_norm}")
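# Illustrative sketch: the two helpers above map the stringified values written by
# RandomSearch.serialize_params back to usable objects, e.g. when reloading the
# params.json of a finished run (run_dir is a placeholder for an iteration directory):
#
#   with open(os.path.join(run_dir, "params.json")) as f:
#       params = json.load(f)
#   scatter_correction = scatter_correction_by_params(params)  # -> bool
#   normalization = normalization_by_params(params)            # -> e.g. MaxNormTransform()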


class RandomSearch:
    """Base class for a random hyperparameter search driven by a dict of samplers."""

    def __init__(self, param_sampler: Dict[str, ParamSampler]):
        self.param_sampler = param_sampler

    def sample(self) -> Dict[str, Any]:
        _params = {}
        for key, sampler in self.param_sampler.items():
            # dependent samplers receive all parameters sampled so far
            if isinstance(sampler, DependentChoiceSampler):
                param = sampler.sample(_params)
            else:
                param = sampler.sample()
            _params[key] = param
        return _params

    def serialize_params(
        self, params: Dict[str, Any], filename: Optional[str] = None
    ) -> str:
        _params = {}
        for key, value in params.items():
            _params[key] = str(value).replace("\n", "").replace(" ", ", ")
        _str = json.dumps(_params, indent=2)

        if filename is not None:
            with open(filename, mode="w") as f:
                f.write(_str)
        return _str
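# Minimal usage sketch (hypothetical values, not the search space configured below):
#
#   search = RandomSearch(
#       {
#           "batch_size": ChoiceSampler([32, 64]),
#           "lr": FloatIntervalSampler(1e-4, 1e-2),
#           "epochs": IntIntervalSampler(50, 100),
#       }
#   )
#   params = search.sample()                        # e.g. {"batch_size": 64, "lr": 0.0023, "epochs": 87}
#   search.serialize_params(params, "params.json")  # writes the stringified values as JSON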


def validate_and_make_directory(_dir: str):
    """Create a directory if it does not exist and abort if it exists but is not empty."""
    if not os.path.exists(_dir):
        os.mkdir(_dir)
        return

    if len(os.listdir(_dir)) > 0:
        print(f"Directory {_dir} exists and is unexpectedly not empty!")
        exit(1)


class RandomSearchCGAN(RandomSearch):
    """Random hyperparameter search for the cGAN training of a µ-map generator."""

    def __init__(self, iterations: int, logger=None):
        super().__init__({})

        self.dataset_dir = "data/second"
        self.iterations = iterations
        self.dir = "cgan_random_search"
        validate_and_make_directory(self.dir)
        self.device = torch.device("cuda")
        self.params = {}

        self.dir_train = os.path.join(self.dir, "train_data")
        self.logger = (
            logger
            if logger is not None
            else get_logger(
                logfile=os.path.join(self.dir, "search.log"),
                loglevel="INFO",
                name=RandomSearchCGAN.__name__,
            )
        )
        self.training: Optional[cGANTraining] = None

        # dataset params
        self.param_sampler["patch_size"] = ChoiceSampler([32])
        self.param_sampler["patch_offset"] = ChoiceSampler([0])
        self.param_sampler["patch_number"] = ChoiceSampler([100])
        # self.param_sampler["scatter_correction"] = ChoiceSampler([True, False])
        self.param_sampler["scatter_correction"] = ChoiceSampler([True])
        self.param_sampler["shuffle"] = ChoiceSampler([False, True])
        self.param_sampler["normalization"] = ChoiceSampler(
            [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
        )
        self.param_sampler["pad_crop"] = ChoiceSampler(
            [None, PadCropTranform(dim=3, size=32)]
        )

        # training params
        self.param_sampler["epochs"] = ChoiceSampler([100])
        self.param_sampler["batch_size"] = ChoiceSampler([64])
        # self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
        self.param_sampler["lr"] = ChoiceSampler([0.001])
        self.param_sampler["lr_decay"] = ChoiceSampler([False])
        self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
        self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
        self.param_sampler["criterion_dist"] = ChoiceSampler(
            [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
        )
        # self.param_sampler["weight_criterion_dist"] = FloatIntervalSampler(1.0, 100.0)
        self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
        # self.param_sampler["weight_criterion_adv"] = FloatIntervalSampler(1.0, 100.0)
        self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])

    def run(self) -> float:
        """Run all search iterations and return the best (lowest) validation NMAE."""
        nmae_min = sys.maxsize
        for i in range(1, self.iterations + 1):
            self.logger.info(f"Train iteration {i}")

            seed = random.randint(0, 2**32 - 1)
            random.seed(seed)
            self.logger.info(f"Random seed for iteration {i} is {seed}")

            self._setup_run(i)
            self.training.run()

            nmae = self.eval_run()
            self.logger.info(f"Iteration {i} has NMAE {nmae:.6f}")
            if nmae < nmae_min:
                self.logger.info(f"New best run at iteration {i}")
                nmae_min = nmae
            self._cleanup_run(i, link_best=(nmae_min == nmae))
        return nmae_min

    def eval_run(self) -> float:
        """Evaluate the best generator snapshot of this iteration on the validation split."""
        self.logger.debug("Perform evaluation ...")
        torch.set_grad_enabled(False)

        weights_file = os.path.join(
            self.training.snapshot_dir, "val_min_generator.pth"
        )
        self.logger.debug(f"Load weights from {weights_file}")
        model = self.training.params_g.model.eval()
        model.load_state_dict(torch.load(weights_file, map_location=self.device))

        transform_normalization = SequenceTransform(
            [self.params["normalization"], PadCropTranform(dim=3, size=32)]
        )
        dataset = MuMapDataset(
            self.dataset_dir,
            split_name="validation",
            transform_normalization=transform_normalization,
            scatter_correction=self.params["scatter_correction"],
        )

        measures = {"NMAE": nmae, "MSE": mse}
        values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys())))
        for i, (recon, mu_map) in enumerate(dataset):
            print(
                f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}",
                end="\r",
            )

            prediction = model(recon.unsqueeze(dim=0).to(self.device))
            prediction = prediction.squeeze().cpu().numpy()
            mu_map = mu_map.squeeze().cpu().numpy()

            row = pd.DataFrame(
                dict(
                    map(
                        lambda item: (item[0], [item[1](prediction, mu_map)]),
                        measures.items(),
                    )
                )
            )
            values = pd.concat((values, row), ignore_index=True)
        print(" " * 100, end="\r")

        values.to_csv(os.path.join(self.dir_train, "measures.csv"), index=False)
        # re-enable gradients so the training of the next iteration is unaffected
        torch.set_grad_enabled(True)
        return values["NMAE"].mean()

    def _setup_run(self, iteration: int):
        """Sample parameters and initialize data loaders, models and training for one iteration."""
        self.logger.debug("Create directories...")
        validate_and_make_directory(self.dir_train)
        snapshot_dir = os.path.join(self.dir_train, "snapshots")
        validate_and_make_directory(snapshot_dir)

        self.params = self.sample()
        params_file = os.path.join(self.dir_train, "params.json")
        self.serialize_params(self.params, params_file)
        self.logger.debug(f"Store params at {params_file}")

        logfile = os.path.join(self.dir_train, "train.log")
        logger = get_logger(logfile, loglevel="INFO", name=cGANTraining.__name__)
        self.logger.debug(f"Training logs to {logfile}")

        transforms = [self.params["normalization"], self.params["pad_crop"]]
        transforms = list(filter(lambda transform: transform is not None, transforms))
        transform_normalization = SequenceTransform(transforms)

        self.logger.debug("Init data loaders ....")
        data_loaders = {}
        for split in ["train", "validation"]:
            dataset = MuMapPatchDataset(
                self.dataset_dir,
                patches_per_image=self.params["patch_number"],
                patch_size=self.params["patch_size"],
                patch_offset=self.params["patch_offset"],
                shuffle=self.params["shuffle"] if split == "train" else False,
                split_name=split,
                transform_normalization=transform_normalization,
                scatter_correction=self.params["scatter_correction"],
                logger=logger,
            )
            data_loader = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=self.params["batch_size"],
                shuffle=split == "train",
                pin_memory=True,
                num_workers=1,
            )
            data_loaders[split] = data_loader

        self.logger.debug("Init discriminator ....")
        discriminator = Discriminator(
            in_channels=2, input_size=self.params["patch_size"]
        )
        discriminator = discriminator.to(self.device)
        optimizer = torch.optim.Adam(
            discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
        )
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=self.params["lr_decay_epoch"],
                gamma=self.params["lr_decay_factor"],
            )
            if self.params["lr_decay"]
            else None
        )
        params_d = TrainingParams(
            model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )

        self.logger.debug("Init generator ....")
        features = [64, 128, 256, 512]
        generator = UNet(in_channels=1, features=features)
        generator = generator.to(self.device)
        optimizer = torch.optim.Adam(
            generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
        )
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
                optimizer,
                step_size=self.params["lr_decay_epoch"],
                gamma=self.params["lr_decay_factor"],
            )
            if self.params["lr_decay"]
            else None
        )
        params_g = TrainingParams(
            model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )

        self.logger.debug("Init training ....")
        self.training = cGANTraining(
            data_loaders=data_loaders,
            epochs=self.params["epochs"],
            device=self.device,
            snapshot_dir=snapshot_dir,
            snapshot_epoch=self.params["epochs"],
            logger=logger,
            params_generator=params_g,
            params_discriminator=params_d,
            loss_func_dist=self.params["criterion_dist"],
            weight_criterion_dist=self.params["weight_criterion_dist"],
            weight_criterion_adv=self.params["weight_criterion_adv"],
        )

    def _cleanup_run(self, iteration: int, link_best: bool):
        """Move the finished iteration into a numbered directory and link the best run."""
        dir_from = self.dir_train
        _dir = f"{iteration:0{len(str(self.iterations))}d}"
        dir_to = os.path.join(self.dir, _dir)
        self.logger.debug(f"Move iteration {iteration} from {dir_from} to {dir_to}")
        shutil.move(dir_from, dir_to)

        if link_best:
            linkfile = os.path.join(self.dir, "best")
            if os.path.exists(linkfile):
                os.unlink(linkfile)
            os.symlink(_dir, linkfile)
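# Resulting directory layout (sketch): each finished iteration is moved from
# cgan_random_search/train_data into a numbered directory (zero-padded to the width
# of the iteration count), and "best" is a symlink to the run with the lowest
# validation NMAE, e.g.:
#
#   cgan_random_search/
#   ├── search.log
#   ├── 1/        # params.json, train.log, measures.csv, snapshots/
#   ├── 2/
#   ├── ...
#   └── best -> 2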


if __name__ == "__main__":
    random_search = RandomSearchCGAN(iterations=6)
    random_search.run()