diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py index 5f87abadd00bf7237ddcc4b3072b2f7566ff8eab..fcba37948490a04f49e241f3fb9b2b3b69cbea3a 100644 --- a/mu_map/training/random_search.py +++ b/mu_map/training/random_search.py @@ -237,11 +237,13 @@ class RandomSearchCGAN(RandomSearch): self.dataset_dir = "data/second" self.iterations = iterations - self.dir = "cgan_random_search" - validate_and_make_directory(self.dir) self.device = torch.device("cuda") + self.n_slices = 32 self.params = {} + self.dir = "cgan_random_search" + validate_and_make_directory(self.dir) + self.dir_train = os.path.join(self.dir, "train_data") self.logger = ( logger @@ -258,28 +260,24 @@ class RandomSearchCGAN(RandomSearch): self.param_sampler["patch_size"] = ChoiceSampler([32, 64]) self.param_sampler["patch_offset"] = ChoiceSampler([0]) self.param_sampler["patch_number"] = IntIntervalSampler(min_val=50, max_val=200) - # self.param_sampler["scatter_correction"] = ChoiceSampler([True, False]) self.param_sampler["scatter_correction"] = ChoiceSampler([False]) self.param_sampler["shuffle"] = ChoiceSampler([False, True]) self.param_sampler["normalization"] = ChoiceSampler( [MeanNormTransform(), MaxNormTransform(), GaussianNormTransform()] ) - self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=32)]) + self.param_sampler["pad_crop"] = ChoiceSampler([None, PadCropTranform(dim=3, size=self.n_slices)]) # training params self.param_sampler["epochs"] = IntIntervalSampler(min_val=50, max_val=200) self.param_sampler["batch_size"] = ChoiceSampler([64]) self.param_sampler["lr"] = FloatIntervalSampler(0.01, 0.0001) - # self.param_sampler["lr"] = ChoiceSampler([0.001]) self.param_sampler["lr_decay"] = ChoiceSampler([False, True]) self.param_sampler["lr_decay_epoch"] = ChoiceSampler([1]) self.param_sampler["lr_decay_factor"] = ChoiceSampler([0.99]) self.param_sampler["criterion_dist"] = ChoiceSampler( [WeightedLoss.from_str("L1"), WeightedLoss.from_str("L2+GDL")] ) - # self.param_sampler["weight_criterion_dist"] = FloatIntervalSampler(1.0, 100.0) self.param_sampler["weight_criterion_dist"] = ChoiceSampler([1.0, 20.0, 100.0]) - # self.param_sampler["weight_criterion_adv"] = FloatIntervalSampler(1.0, 100.0) self.param_sampler["weight_criterion_adv"] = ChoiceSampler([1.0, 20.0, 100.0]) def run(self): @@ -305,17 +303,16 @@ class RandomSearchCGAN(RandomSearch): return nmae_min def eval_run(self): - return random.randint(0, 200) self.logger.debug("Perform evaluation ...") torch.set_grad_enabled(False) weights_file = os.path.join(self.training.snapshot_dir, "val_min_generator.pth") self.logger.debug(f"Load weights from {weights_file}") - model = self.training.params_g.model.eval() + model = self.training.generator.model.eval() model.load_state_dict(torch.load(weights_file, map_location=self.device)) transform_normalization = SequenceTransform( - [self.params["normalization"], PadCropTranform(dim=3, size=32)] + [self.params["normalization"], PadCropTranform(dim=3, size=self.n_slices)] ) dataset = MuMapDataset( self.dataset_dir, @@ -373,6 +370,7 @@ class RandomSearchCGAN(RandomSearch): self.dataset_dir, patches_per_image=self.params["patch_number"], patch_size=self.params["patch_size"], + patch_size_z=self.n_slices, patch_offset=self.params["patch_offset"], shuffle=self.params["shuffle"], transform_normalization=transform_normalization, @@ -382,7 +380,7 @@ class RandomSearchCGAN(RandomSearch): self.logger.debug(f"Init discriminator ....") discriminator = Discriminator( - in_channels=2, input_size=self.params["patch_size"] + in_channels=2, input_size=(self.n_slices, self.params["patch_size"], self.params["patch_size"]) ) discriminator = discriminator.to(self.device) optimizer = torch.optim.Adam(