Commit b7fad94b authored by Tamino Huxohl

change parameters in random search

parent 730f3623
@@ -21,9 +21,9 @@ from mu_map.dataset.normalization import (
 )
 from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
 from mu_map.eval.measures import nmae, mse
-from mu_map.models.discriminator import Discriminator
+from mu_map.models.discriminator import Discriminator, PatchDiscriminator
 from mu_map.models.unet import UNet
-from mu_map.training.cgan2 import cGANTraining, DiscriminatorParams, GeneratorParams
+from mu_map.training.cgan import cGANTraining, DiscriminatorParams, GeneratorParams
 from mu_map.training.loss import WeightedLoss
 from mu_map.logging import get_logger
@@ -216,7 +216,7 @@ class RandomSearch:
     def validate_and_make_directory(_dir: str):
         """
-        Uility method to validate that a directory exists and is empty.
+        Utility method to validate that a directory exists and is empty.
         If it does not exist, it is created.
         """
         if not os.path.exists(_dir):
@@ -237,7 +237,7 @@ class RandomSearchCGAN(RandomSearch):
         self.dataset_dir = "data/second"
         self.iterations = iterations
-        self.device = torch.device("cuda")
+        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         self.n_slices = 32
         self.params = {}
@@ -256,10 +256,10 @@ class RandomSearchCGAN(RandomSearch):
         )
         self.training: cGANTraining = None
 
-        # dataset params
+        # dataset parameters
         self.param_sampler["patch_size"] = ChoiceSampler([32, 64])
         self.param_sampler["patch_offset"] = ChoiceSampler([0])
-        self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=200)
+        self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=100)
         self.param_sampler["scatter_correction"] = ChoiceSampler([False])
         self.param_sampler["shuffle"] = ChoiceSampler([False, True])
         self.param_sampler["normalization"] = ChoiceSampler(
@@ -267,10 +267,24 @@ class RandomSearchCGAN(RandomSearch):
         )
         self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)])
 
-        # training params
-        self.param_sampler["epochs"] = IntIntervalSampler(min_val=50, max_val=200)
-        self.param_sampler["batch_size"] = ChoiceSampler([64])
-        self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
+        # model parameters
+        self.param_sampler["discriminator_type"] = ChoiceSampler(["class", "patch"])
+        def discriminator_conv_features(discriminator_type: str, **kwargs):
+            if discriminator_type == "class":
+                return [[32, 64, 128], [64, 128, 256], [32, 64, 128, 256]]
+            else:
+                return [[32, 64, 128, 256], [64, 128, 256, 512]]
+        self.param_sampler["discriminator_conv_features"] = DependentChoiceSampler(discriminator_conv_features)
+        self.param_sampler["generator_features"] = ChoiceSampler([
+            [128, 256, 512],
+            [64, 128, 256, 512],
+            [32, 64, 128, 256, 512],
+        ])
+
+        # training parameters
+        self.param_sampler["epochs"] = ChoiceSampler([50, 60, 70, 80, 90])
+        self.param_sampler["batch_size"] = ChoiceSampler([32, 64])
+        self.param_sampler["lr"] = FloatIntervalSampler(0.1, 0.0001)
         self.param_sampler["lr_decay"] = ChoiceSampler([False, True])
         self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
         self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
@@ -379,10 +393,15 @@ class RandomSearchCGAN(RandomSearch):
         )
 
         self.logger.debug(f"Init discriminator ....")
-        discriminator = Discriminator(
-            in_channels=2, input_size=(self.n_slices, self.params["patch_size"], self.params["patch_size"])
-        )
+        input_size = (2, self.n_slices, self.params["patch_size"], self.params["patch_size"])
+        if self.params["discriminator_type"] == "class":
+            discriminator = Discriminator(
+                in_channels=2, input_size=input_size,
+                conv_features=self.params["discriminator_conv_features"],
+            )
+        else:
+            discriminator = PatchDiscriminator(in_channels=2, features=self.params["discriminator_conv_features"])
         discriminator = discriminator.to(self.device)
         optimizer = torch.optim.Adam(
             discriminator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
         )
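
Note on the two branches: the "class" discriminator collapses the whole input volume to a single real/fake score, which is why it needs input_size, while the "patch" variant scores local regions and is agnostic to the input extent, so PatchDiscriminator is built without it. For illustration, a minimal 3D PatchGAN-style discriminator under common conventions; the repository's PatchDiscriminator may differ in depth, normalization, and padding.

# Illustrative PatchGAN-style discriminator (3D); an assumption for
# illustration, not the repository's PatchDiscriminator.
import torch
import torch.nn as nn

class PatchGAN3D(nn.Module):
    """Outputs a grid of real/fake scores, one per local patch of the input."""

    def __init__(self, in_channels: int = 2, features=(32, 64, 128, 256)):
        super().__init__()
        layers = []
        prev = in_channels
        for feats in features:
            # each stage halves the spatial extent and widens the channels
            layers += [
                nn.Conv3d(prev, feats, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            prev = feats
        # a final 1-channel conv yields the patch score map instead of a scalar
        layers.append(nn.Conv3d(prev, 1, kernel_size=4, stride=1, padding=1))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

The adversarial loss is then averaged over the score map, so each local patch of the generated volume is judged separately.
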
@@ -400,8 +419,7 @@ class RandomSearchCGAN(RandomSearch):
         )
         self.logger.debug(f"Init generator ....")
-        features = [64, 128, 256, 512]
-        generator = UNet(in_channels=1, features=features)
+        generator = UNet(in_channels=1, features=self.params["generator_features"])
         generator = generator.to(self.device)
 
         optimizer = torch.optim.Adam(
             generator.parameters(), lr=self.params["lr"], betas=(0.5, 0.999)
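
With generator_features now part of the search space, the sampled list replaces the previously hard-coded [64, 128, 256, 512] and varies both the depth and the width of the UNet generator. A hypothetical instantiation for one sampled configuration, assuming the UNet constructor used above:

# Hypothetical example: one sampled configuration yields a deeper generator.
import torch
from mu_map.models.unet import UNet

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
generator = UNet(in_channels=1, features=[32, 64, 128, 256, 512]).to(device)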