Commit cdc7ac16 authored by Tamino Huxohl

add class for random search of best parameters

parent 311121fb
import json
import os
import random
import shutil
import sys
from typing import Any, Callable, Dict, List, Optional
import torch
import pandas as pd
from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import (
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
from mu_map.eval.measures import nmae, mse
from mu_map.models.discriminator import Discriminator
from mu_map.models.unet import UNet
from mu_map.training.cgan import cGANTraining, TrainingParams
from mu_map.training.loss import WeightedLoss
from mu_map.logging import get_logger
class ParamSampler:
    """Base class for sampling a random value for a single hyperparameter."""

    def sample(self) -> Any:
        raise NotImplementedError()
class ChoiceSampler(ParamSampler):
    """Sample uniformly from a fixed list of values."""

    def __init__(self, values: List[Any]):
        super().__init__()
        self.values = values

    def sample(self) -> Any:
        idx = random.randrange(0, len(self.values))
        return self.values[idx]
class DependentChoiceSampler(ChoiceSampler):
    """Choice sampler whose available values depend on previously sampled parameters.

    The annotated argument names of build_choices are treated as required
    dependencies; on sampling, the provided dependencies are passed to
    build_choices as keyword arguments to rebuild the list of choices.
    """

    def __init__(self, build_choices: Callable[..., List[Any]]):
        super().__init__(values=[])
        self.build_choices = build_choices
        self.dependency_names = list(build_choices.__annotations__.keys())

    def sample(self, dependencies: Dict[str, Any]) -> Any:
        self.validate_deps(dependencies)
        self.values = self.build_choices(**dependencies)
        return super().sample()

    def validate_deps(self, dependencies: Dict[str, Any]) -> None:
        for name in self.dependency_names:
            assert (
                name in dependencies.keys()
            ), f"Dependency {name} is missing from provided dependencies {dependencies}"
class FloatIntervalSampler(ParamSampler):
    """Sample a float uniformly from the interval [min_val, max_val]."""

    def __init__(self, min_val: float, max_val: float):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> float:
        return random.uniform(self.min_val, self.max_val)


class IntIntervalSampler(ParamSampler):
    """Sample an int uniformly from the interval [min_val, max_val] (both ends inclusive)."""

    def __init__(self, min_val: int, max_val: int):
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> int:
        return random.randint(self.min_val, self.max_val)
class RandomSearch:
    """Base class for a random search over a dictionary of parameter samplers."""

    def __init__(self, param_sampler: Dict[str, ParamSampler]):
        self.param_sampler = param_sampler

    def sample(self) -> Dict[str, Any]:
        """Draw a value from every sampler. Dependent samplers receive all parameters
        sampled before them, so the insertion order of the samplers matters."""
        _params = {}
        for key, sampler in self.param_sampler.items():
            if isinstance(sampler, DependentChoiceSampler):
                param = sampler.sample(_params)
            else:
                param = sampler.sample()
            _params[key] = param
        return _params
    def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str:
        """Serialize sampled parameters into a JSON string and optionally write it to a file."""
        _params = {}
        for key, value in params.items():
            _params[key] = str(value).replace("\n", "").replace(" ", ", ")

        _str = json.dumps(_params, indent=2)
        if filename is not None:
            with open(filename, mode="w") as f:
                f.write(_str)
        return _str
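
# Example (illustrative only): a minimal search space built from the samplers above.
#
#     search = RandomSearch({
#         "batch_size": ChoiceSampler([32, 64]),
#         "lr": FloatIntervalSampler(0.0001, 0.01),
#     })
#     params = search.sample()
#     print(search.serialize_params(params))
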
def validate_and_make_directory(_dir: str):
    """Create the directory if it does not exist and abort if it exists but is not empty."""
    if not os.path.exists(_dir):
        os.mkdir(_dir)
        return

    if len(os.listdir(_dir)) > 0:
        print(f"Directory {_dir} exists and is unexpectedly not empty!")
        sys.exit(1)
class RandomSearchCGAN(RandomSearch):
    """Random search over dataset and training hyperparameters of the cGAN training."""

    def __init__(self, iterations: int, logger=None):
super().__init__({})
self.dataset_dir = "data/second"
self.iterations = iterations
self.dir = "cgan_random_search"
validate_and_make_directory(self.dir)
self.device = torch.device("cuda")
self.params = {}
self.dir_train = os.path.join(self.dir, "train_data")
self.logger = (
logger
if logger is not None
else get_logger(
logfile=os.path.join(self.dir, "search.log"),
loglevel="INFO",
name=RandomSearchCGAN.__name__,
)
)
        self.training: Optional[cGANTraining] = None
# dataset params
self.param_sampler["patch_size"] = ChoiceSampler([32])
self.param_sampler["patch_offset"] = IntIntervalSampler(0, 32)
self.param_sampler["patch_number"] = IntIntervalSampler(50, 200)
self.param_sampler["scatter_correction"] = ChoiceSampler([True, False])
self.param_sampler["shuffle"] = ChoiceSampler([True, False])
self.param_sampler["normalization"] = ChoiceSampler(
[MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
)
self.param_sampler["pad_crop"] = ChoiceSampler(
[None, PadCropTranform(dim=3, size=32)]
)
# training params
self.param_sampler["epochs"] = ChoiceSampler([100])
self.param_sampler["batch_size"] = ChoiceSampler([64])
self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
self.param_sampler["lr_decay"] = ChoiceSampler([True, False])
self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
self.param_sampler["criterion_dist"] = ChoiceSampler(
[WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
)
self.param_sampler["weight_criterion_dist"] = FloatIntervalSampler(1.0, 100.0)
self.param_sampler["weight_criterion_adv"] = FloatIntervalSampler(1.0, 100.0)
    def run(self) -> float:
        """Run the random search and return the best (lowest) NMAE achieved."""
        nmae_min = sys.maxsize
        for i in range(1, self.iterations + 1):
            self.logger.info(f"Train iteration {i}")

            seed = random.randint(0, 2**32 - 1)
            random.seed(seed)
            self.logger.info(f"Random seed for iteration {i} is {seed}")

            self._setup_run(i)
            self.training.run()

            nmae_iter = self.eval_run()
            self.logger.info(f"Iteration {i} has NMAE {nmae_iter:.6f}")
            if nmae_iter < nmae_min:
                self.logger.info(f"New best run at iteration {i}")
                nmae_min = nmae_iter

            self._cleanup_run(i, link_best=(nmae_min == nmae_iter))
        return nmae_min
    def eval_run(self) -> float:
        """Evaluate the generator of the finished run on the validation split and return its mean NMAE."""
        self.logger.debug("Perform evaluation ...")
        torch.set_grad_enabled(False)
weights_file = os.path.join(self.training.snapshot_dir, "val_min_generator.pth")
self.logger.debug(f"Load weights from {weights_file}")
model = self.training.params_g.model.eval()
model.load_state_dict(torch.load(weights_file, map_location=self.device))
transform_normalization = SequenceTransform(
[self.params["normalization"], PadCropTranform(dim=3, size=32)]
)
dataset = MuMapDataset(
self.dataset_dir,
split_name="validation",
transform_normalization=transform_normalization,
scatter_correction=self.params["scatter_correction"],
)
measures = {"NMAE": nmae, "MSE": mse}
values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys())))
for i, (recon, mu_map) in enumerate(dataset):
print(
f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}",
end="\r",
)
prediction = model(recon.unsqueeze(dim=0).to(self.device))
prediction = prediction.squeeze().cpu().numpy()
mu_map = mu_map.squeeze().cpu().numpy()
row = pd.DataFrame(
dict(
map(
lambda item: (item[0], [item[1](prediction, mu_map)]),
measures.items(),
)
)
)
values = pd.concat((values, row), ignore_index=True)
print(f" " * 100, end="\r")
values.to_csv(os.path.join(self.dir_train, "measures.csv"), index=False)
return values["NMAE"].mean()
    def _setup_run(self, iteration: int):
        """Sample new parameters and initialize directories, data loaders, models, and the cGAN training."""
        self.logger.debug("Create directories...")
validate_and_make_directory(self.dir_train)
snapshot_dir = os.path.join(self.dir_train, "snapshots")
validate_and_make_directory(snapshot_dir)
self.params = self.sample()
params_file = os.path.join(self.dir_train, "params.json")
self.serialize_params(self.params, params_file)
self.logger.debug(f"Store params at {params_file}")
logfile = os.path.join(self.dir_train, "train.log")
logger = get_logger(logfile, loglevel="INFO", name=cGANTraining.__name__)
self.logger.debug(f"Training logs to {logfile}")
transforms = [self.params["normalization"], self.params["pad_crop"]]
transforms = list(filter(lambda transform: transform is not None, transforms))
transform_normalization = SequenceTransform(transforms)
self.logger.debug(f"Init data loaders ....")
data_loaders = {}
for split in ["train", "validation"]:
dataset = MuMapPatchDataset(
self.dataset_dir,
patches_per_image=self.params["patch_number"],
patch_size=self.params["patch_size"],
patch_offset=self.params["patch_offset"],
shuffle=self.params["shuffle"] if split == "train" else False,
split_name=split,
transform_normalization=transform_normalization,
scatter_correction=self.params["scatter_correction"],
logger=logger,
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=self.params["batch_size"],
shuffle=split == "train",
pin_memory=True,
num_workers=1,
)
data_loaders[split] = data_loader
self.logger.debug(f"Init discriminator ....")
discriminator = Discriminator(
in_channels=2, input_size=self.params["patch_size"]
)
discriminator = discriminator.to(self.device)
optimizer = torch.optim.Adam(
discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
)
lr_scheduler = (
torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=self.params["lr_decay_epoch"],
gamma=self.params["lr_decay_factor"],
)
if self.params["lr_decay"]
else None
)
params_d = TrainingParams(
model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
)
self.logger.debug(f"Init generator ....")
features = [64, 128, 256, 512]
generator = UNet(in_channels=1, features=features)
generator = generator.to(self.device)
optimizer = torch.optim.Adam(
generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
)
lr_scheduler = (
torch.optim.lr_scheduler.StepLR(
optimizer,
step_size=self.params["lr_decay_epoch"],
gamma=self.params["lr_decay_factor"],
)
if self.params["lr_decay"]
else None
)
params_g = TrainingParams(
model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
)
self.logger.debug(f"Init training ....")
self.training = cGANTraining(
data_loaders=data_loaders,
epochs=self.params["epochs"],
device=self.device,
snapshot_dir=snapshot_dir,
snapshot_epoch=self.params["epochs"],
logger=logger,
params_generator=params_g,
params_discriminator=params_d,
loss_func_dist=self.params["criterion_dist"],
weight_criterion_dist=self.params["weight_criterion_dist"],
weight_criterion_adv=self.params["weight_criterion_adv"],
)
    def _cleanup_run(self, iteration: int, link_best: bool):
        """Move the results of the finished run into a numbered directory and optionally mark it as the best run so far."""
        dir_from = self.dir_train
_dir = f"{iteration:0{len(str(self.iterations))}d}"
dir_to = os.path.join(self.dir, _dir)
self.logger.debug(f"Move iteration {iteration} from {dir_from} to {dir_to}")
shutil.move(dir_from, dir_to)
if link_best:
linkfile = os.path.join(self.dir, "best")
if os.path.exists(linkfile):
os.unlink(linkfile)
os.symlink(_dir, linkfile)
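
# Resulting directory layout (sketch, for the default of 20 iterations):
#
#     cgan_random_search/
#         search.log
#         01/ ... 20/       # one directory per finished iteration
#             params.json
#             train.log
#             measures.csv
#             snapshots/
#         best -> <number of the iteration with the lowest NMAE>
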
if __name__ == "__main__":
random_search = RandomSearchCGAN(iterations=20)
random_search.run()