Commit 488452f2 authored by Tamino Huxohl

update random search to use the new cgan training

parent e3e6d7d5

"""
Implementation of random search for hyper parameter optimization.
"""
import json
import os
import random
@@ -19,39 +22,77 @@ from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransfo
from mu_map.eval.measures import nmae, mse
from mu_map.models.discriminator import Discriminator
from mu_map.models.unet import UNet
from mu_map.training.cgan2 import cGANTraining, DiscriminatorParams, GeneratorParams
from mu_map.training.loss import WeightedLoss
from mu_map.logging import get_logger
class ParamSampler:
    """
    Abstract class to sample a parameter.
    """

    def sample(self) -> Any:
        """
        Create a new value for a parameter.
        """
        pass
class ChoiceSampler(ParamSampler):
    """
    Sample from a list of choices.
    """

    def __init__(self, values: List[Any]):
        """
        Create a new choice sampler.

        :param values: the list of values from which a sample is drawn
        """
        super().__init__()
        self.values = values

    def sample(self) -> Any:
        """
        Retrieve a random value from the list of choices.
        """
        idx = random.randrange(0, len(self.values))
        return self.values[idx]
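
# A minimal usage sketch of ChoiceSampler (illustrative values only, not taken
# from this file):
#
#   sampler = ChoiceSampler([32, 64, 128])
#   sampler.sample()  # one of 32, 64 or 128, drawn uniformly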
class DependentChoiceSampler(ChoiceSampler):
    """
    A choice sampler that depends on other parameters.
    """

    def __init__(self, build_choices: Callable[..., List[Any]]):
        """
        Create a dependent choice sampler.

        :param build_choices: a callable that creates a list of choices depending on other parameters
        """
        super().__init__(values=[])
        self.build_choices = build_choices
        # the annotated parameter names of build_choices define the dependencies;
        # a return annotation is not a dependency and is therefore excluded
        self.dependency_names = [
            name for name in build_choices.__annotations__ if name != "return"
        ]
    def sample(self, dependencies: Dict[str, Any]) -> Any:
        """
        Sample a choice based on the given dependencies.

        :param dependencies: a dict of name/value pairs for all dependencies;
            note that the names have to match the parameter names of the build_choices callable
        :return: a new sample
        """
        self.validate_deps(dependencies)
        self.values = self.build_choices(**dependencies)
        return super().sample()
    def validate_deps(self, dependencies: Dict[str, Any]):
        """
        Validate a dict of dependencies by checking if all required parameters of
        the build_choices callable are available.

        :param dependencies: dict of name/value pairs for dependencies
        """
        for name in self.dependency_names:
            assert (
                name in dependencies.keys()
@@ -59,30 +100,60 @@ class DependentChoiceSampler(ChoiceSampler):
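
# Sketch of how a DependentChoiceSampler could be wired up (hypothetical
# example, not taken from this file). The annotated parameter names of
# build_choices define its dependencies:
#
#   def build_patch_offsets(patch_size: int) -> List[int]:
#       return list(range(0, patch_size, 8))
#
#   sampler = DependentChoiceSampler(build_patch_offsets)
#   sampler.sample(dependencies={"patch_size": 32})  # one of 0, 8, 16, 24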
class FloatIntervalSampler(ParamSampler):
    """
    Sample a value from a float interval.
    """

    def __init__(self, min_val: float, max_val: float):
        """
        Create a new float interval sampler.

        :param min_val: the minimal value to draw
        :param max_val: the maximal value to draw
        """
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> float:
        """
        Draw a new float.
        """
        return random.uniform(self.min_val, self.max_val)
class IntIntervalSampler(ParamSampler):
    """
    Sample a value from an integer interval.
    """

    def __init__(self, min_val: int, max_val: int):
        """
        Create a new int interval sampler.

        :param min_val: the minimal value to draw
        :param max_val: the maximal value to draw (inclusive)
        """
        super().__init__()
        self.min_val = min_val
        self.max_val = max_val

    def sample(self) -> int:
        """
        Draw a new int.
        """
        return random.randint(self.min_val, self.max_val)
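
# Both interval samplers in one sketch (bounds are illustrative):
#
#   FloatIntervalSampler(0.0001, 0.01).sample()  # uniform float in [0.0001, 0.01]
#   IntIntervalSampler(50, 200).sample()         # uniform int, both bounds inclusive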
def scatter_correction_by_params(params: Dict[str, str]) -> bool:
    """
    Utility function because loading JSON does not map the strings
    "True"/"False" back to boolean values.
    """
    return params["scatter_correction"] == "True"
def normalization_by_params(params: Dict[str, str]):
    """
    Utility function to load a normalization method.
    """
    _norm = params["normalization"]
    if "Gaussian" in _norm:
        return GaussianNormTransform()
@@ -95,10 +166,25 @@ def normalization_by_params(params: Dict[str, str]):
class RandomSearch:
    """
    Abstract implementation of a random search.
    """

    def __init__(self, param_sampler: Dict[str, ParamSampler]):
        """
        Create a new random search.

        :param param_sampler: a dict of name and parameter sampler pairs;
            all of them are sampled for a single random search run
        """
        self.param_sampler = param_sampler
    def sample(self):
        """
        Sample all parameters.

        This makes sure that all dependent choice samplers get their required
        dependencies, which therefore need to be registered in order.

        :return: a dictionary of name and drawn parameter pairs
        """
        _params = {}
        for key, sampler in self.param_sampler.items():
            if isinstance(sampler, DependentChoiceSampler):
@@ -108,7 +194,14 @@ class RandomSearch:
            _params[key] = param
        return _params
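
    # Sketch of a complete RandomSearch (hypothetical samplers; assumes the
    # elided branch above forwards the already-sampled values to dependent
    # samplers, which is why they must come after their dependencies):
    #
    #   search = RandomSearch({
    #       "patch_size": ChoiceSampler([32, 64]),
    #       "patch_offset": DependentChoiceSampler(build_patch_offsets),
    #   })
    #   search.sample()  # e.g. {"patch_size": 64, "patch_offset": 24}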
    def serialize_params(self, params: Dict[str, Any], filename: Optional[str] = None) -> str:
        """
        Serialize a set of parameters to JSON and optionally dump them into a file.

        :param params: dict of params to be serialized
        :param filename: optional filename where the JSON is dumped
        :return: the params as a JSON representation
        """
        _params = {}
        for key, value in params.items():
            _params[key] = str(value).replace("\n", "").replace(" ", ", ")
@@ -121,6 +214,10 @@ class RandomSearch:
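
# Illustration of serialize_params (approximate, since part of the method is
# elided here): every value is stringified before dumping, e.g.
#
#   search.serialize_params({"lr": 0.001, "shuffle": True})
#
# could return the JSON string '{"lr": "0.001", "shuffle": "True"}' and, if a
# filename is given, also write it to that file.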
def validate_and_make_directory(_dir: str):
    """
    Utility method to validate that a directory exists and is empty.
    If it does not exist, it is created.
    """
    if not os.path.exists(_dir):
        os.mkdir(_dir)
        return
@@ -131,6 +228,9 @@ def validate_and_make_directory(_dir: str):
class RandomSearchCGAN(RandomSearch):
    """
    Implementation of a random search for cGAN training.
    """

    def __init__(self, iterations: int, logger=None):
        super().__init__({})
@@ -138,7 +238,8 @@ class RandomSearchCGAN(RandomSearch):
        self.iterations = iterations
        self.dir = "cgan_random_search"
        validate_and_make_directory(self.dir)
        # self.device = torch.device("cuda")
        self.device = torch.device("cpu")
        self.params = {}

        self.dir_train = os.path.join(self.dir, "train_data")
@@ -154,11 +255,11 @@ class RandomSearchCGAN(RandomSearch):
        self.training: Optional[cGANTraining] = None

        # dataset params
        self.param_sampler["patch_size"] = ChoiceSampler([32, 64])
        self.param_sampler["patch_offset"] = ChoiceSampler([0])
        self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=200)
        # self.param_sampler["scatter_correction"] = ChoiceSampler([True, False])
        self.param_sampler["scatter_correction"] = ChoiceSampler([False])
        self.param_sampler["shuffle"] = ChoiceSampler([False, True])
        self.param_sampler["normalization"] = ChoiceSampler(
            [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
@@ -166,10 +267,10 @@ class RandomSearchCGAN(RandomSearch):
self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=32)]) self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=32)])
# training params # training params
self.param_sampler["epochs"] = ChoiceSampler([100]) self.param_sampler["epochs"] = IntIntervalSampler(min_val=50, max_val=200)
self.param_sampler["batch_size"] = ChoiceSampler([64]) self.param_sampler["batch_size"] = ChoiceSampler([64])
# self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001) self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
self.param_sampler["lr"] = ChoiceSampler([0.001]) # self.param_sampler["lr"] = ChoiceSampler([0.001])
self.param_sampler["lr_decay"] = ChoiceSampler([False, True]) self.param_sampler["lr_decay"] = ChoiceSampler([False, True])
self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1]) self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99]) self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
@@ -202,6 +303,7 @@ class RandomSearchCGAN(RandomSearch):
        return nmae_min

    def eval_run(self):
        # NOTE: short-circuits with a random score; the evaluation below is
        # currently unreachable debug code
        return random.randint(0, 200)

        self.logger.debug("Perform evaluation ...")
        torch.set_grad_enabled(False)
@@ -264,28 +366,17 @@ class RandomSearchCGAN(RandomSearch):
        transforms = list(filter(lambda transform: transform is not None, transforms))
        transform_normalization = SequenceTransform(transforms)

        self.logger.debug("Init dataset ...")
        dataset = MuMapPatchDataset(
            self.dataset_dir,
            patches_per_image=self.params["patch_number"],
            patch_size=self.params["patch_size"],
            patch_offset=self.params["patch_offset"],
            shuffle=self.params["shuffle"],
            transform_normalization=transform_normalization,
            scatter_correction=self.params["scatter_correction"],
            logger=logger,
        )
self.logger.debug(f"Init discriminator ....") self.logger.debug(f"Init discriminator ....")
discriminator = Discriminator( discriminator = Discriminator(
...@@ -304,7 +395,7 @@ class RandomSearchCGAN(RandomSearch): ...@@ -304,7 +395,7 @@ class RandomSearchCGAN(RandomSearch):
            if self.params["lr_decay"]
            else None
        )
        params_d = DiscriminatorParams(
            model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )
@@ -324,14 +415,15 @@ class RandomSearchCGAN(RandomSearch):
            if self.params["lr_decay"]
            else None
        )
        params_g = GeneratorParams(
            model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )

        self.logger.debug("Init training ....")
        self.training = cGANTraining(
            epochs=self.params["epochs"],
            dataset=dataset,
            batch_size=self.params["batch_size"],
            device=self.device,
            snapshot_dir=snapshot_dir,
            snapshot_epoch=self.params["epochs"],
...