From b7fad94bfb06a03ff3daecb88f66d2944ca896ac Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Fri, 6 Jan 2023 12:05:35 +0100 Subject: [PATCH] change parameters in random search --- mu_map/training/random_search.py | 48 ++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py index fcba379..0b08c60 100644 --- a/mu_map/training/random_search.py +++ b/mu_map/training/random_search.py @@ -21,9 +21,9 @@ from mu_map.dataset.normalization import ( ) from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform from mu_map.eval.measures import nmae, mse -from mu_map.models.discriminator import Discriminator +from mu_map.models.discriminator import Discriminator, PatchDiscriminator from mu_map.models.unet import UNet -from mu_map.training.cgan2 import cGANTraining, DiscriminatorParams, GeneratorParams +from mu_map.training.cgan import cGANTraining, DiscriminatorParams, GeneratorParams from mu_map.training.loss import WeightedLoss from mu_map.logging import get_logger @@ -216,7 +216,7 @@ class RandomSearch: def validate_and_make_directory(_dir: str): """ - Uility method to validate that a directory exists and is empty. + Utility method to validate that a directory exists and is empty. If is does not exist, it is created. """ if not os.path.exists(_dir): @@ -237,7 +237,7 @@ class RandomSearchCGAN(RandomSearch): self.dataset_dir = "data/second" self.iterations = iterations - self.device = torch.device("cuda") + self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") self.n_slices = 32 self.params = {} @@ -256,10 +256,10 @@ class RandomSearchCGAN(RandomSearch): ) self.training: cGANTraining = None - # dataset params + # dataset parameters self.param_sampler["patch_size"] = ChoiceSampler([32, 64]) self.param_sampler["patch_offset"] = ChoiceSampler([0]) - self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=200) + self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100) self.param_sampler["scatter_correction"] = ChoiceSampler([False]) self.param_sampler["shuffle"] = ChoiceSampler([False, True]) self.param_sampler["normalization"] = ChoiceSampler( @@ -267,10 +267,24 @@ class RandomSearchCGAN(RandomSearch): ) self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)]) - # training params - self.param_sampler["epochs"] = IntIntervalSampler(min_val=50, max_val=200) - self.param_sampler["batch_size"] = ChoiceSampler([64]) - self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001) + # model parameters + self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"]) + def discriminator_conv_features(discriminator_type: str, **kwargs): + if discriminator_type == "class": + return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]] + else: + return [[32, 64, 128, 256], [64, 128, 256, 512]] + self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features) + self.param_sampler["generator_features"] = ChoiceSampler([ + [128, 256, 512], + [64, 128, 256, 512], + [32, 64, 128, 256, 512], + ]) + + # training parameters + self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90]) + self.param_sampler["batch_size"] = ChoiceSampler([32, 64]) + self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001) self.param_sampler["lr_decay"] = ChoiceSampler([False, True]) self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1]) self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99]) @@ -379,10 +393,15 @@ class RandomSearchCGAN(RandomSearch): ) self.logger.debug(f"Init discriminator ....") - discriminator = Discriminator( - in_channels=2, input_size=(self.n_slices, self.params["patch_size"], self.params["patch_size"]) - ) + input_size = (2, self.n_slices, self.params["patch_size"], self.params["patch_size"]) + if self.params["discriminator_type"] == "class": + discriminator = Discriminator( + in_channels=2, input_size=input_size, conv_features=self.params["discriminator_conv_features"], + ) + else: + discriminator = PatchDiscriminator(in_channels=2, features=self.params["discriminator_conv_features"]) discriminator = discriminator.to(self.device) + optimizer = torch.optim.Adam( discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999) ) @@ -400,8 +419,7 @@ class RandomSearchCGAN(RandomSearch): ) self.logger.debug(f"Init generator ....") - features = [64, 128, 256, 512] - generator = UNet(in_channels=1, features=features) + generator = UNet(in_channels=1, features=self.params["generator_features"]) generator = generator.to(self.device) optimizer = torch.optim.Adam( generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999) -- GitLab