Skip to content
Snippets Groups Projects
Commit 29db38c6 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

test a few different parameters for random search

parent cdc7ac16
No related branches found
No related tags found
No related merge requests found
......@@ -139,29 +139,30 @@ class RandomSearchCGAN(RandomSearch):
# dataset params
self.param_sampler["patch_size"] = ChoiceSampler([32])
self.param_sampler["patch_offset"] = IntIntervalSampler(0, 32)
self.param_sampler["patch_number"] = IntIntervalSampler(50, 200)
self.param_sampler["patch_offset"] = ChoiceSampler([0])
self.param_sampler["patch_number"] = ChoiceSampler([100])
self.param_sampler["scatter_correction"] = ChoiceSampler([True, False])
self.param_sampler["shuffle"] = ChoiceSampler([True, False])
self.param_sampler["shuffle"] = ChoiceSampler([False])
self.param_sampler["normalization"] = ChoiceSampler(
[MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()]
)
self.param_sampler["pad_crop"] = ChoiceSampler(
[None, PadCropTranform(dim=3, size=32)]
)
self.param_sampler["pad_crop"] = ChoiceSampler([PadCropTranform(dim=3, size=32)])
# training params
self.param_sampler["epochs"] = ChoiceSampler([100])
self.param_sampler["batch_size"] = ChoiceSampler([64])
self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
self.param_sampler["lr_decay"] = ChoiceSampler([True, False])
# self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001)
self.param_sampler["lr"] = ChoiceSampler([0.001])
self.param_sampler["lr_decay"] = ChoiceSampler([False])
self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1])
self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99])
self.param_sampler["criterion_dist"] = ChoiceSampler(
[WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")]
)
self.param_sampler["weight_criterion_dist"] = FloatIntervalSampler(1.0, 100.0)
self.param_sampler["weight_criterion_adv"] = FloatIntervalSampler(1.0, 100.0)
# self.param_sampler["weight_criterion_dist"] = FloatIntervalSampler(1.0, 100.0)
self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0])
# self.param_sampler["weight_criterion_adv"] = FloatIntervalSampler(1.0, 100.0)
self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0])
def run(self):
nmae_min = sys.maxsize
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment