diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py
index 82f877b714f074e2d821a17e4a7805803d5b24d5..c6c68e36fc76761f7d17d9766c9e57327863e9a8 100644
--- a/mu_map/training/random_search.py
+++ b/mu_map/training/random_search.py
@@ -20,7 +20,7 @@ from mu_map.dataset.normalization import (
     GaussianNormTransform,
 )
 from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
-from mu_map.eval.measures import nmae, mse
+from mu_map.eval.measures import compute_measures
 from mu_map.models.discriminator import Discriminator, PatchDiscriminator
 from mu_map.models.unet import UNet
 from mu_map.training.cgan import cGANTraining, DiscriminatorParams, GeneratorParams
@@ -337,28 +337,7 @@ class RandomSearchCGAN(RandomSearch):
             scatter_correction=self.params["scatter_correction"],
         )
 
-        measures = {"NMAE": nmae, "MSE": mse}
-        values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys())))
-        for i, (recon, mu_map) in enumerate(dataset):
-            print(
-                f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}",
-                end="\r",
-            )
-            prediction = model(recon.unsqueeze(dim=0).to(self.device))
-            prediction = prediction.squeeze().cpu().numpy()
-            mu_map = mu_map.squeeze().cpu().numpy()
-
-            row = pd.DataFrame(
-                dict(
-                    map(
-                        lambda item: (item[0], [item[1](prediction, mu_map)]),
-                        measures.items(),
-                    )
-                )
-            )
-            values = pd.concat((values, row), ignore_index=True)
-        print(f" " * 100, end="\r")
-
+        values = compute_measures(dataset, model)
         values.to_csv(os.path.join(self.dir_train, "measures.csv"), index=False)
         return values["NMAE"].mean()
 
@@ -470,5 +449,5 @@
 
 
 if __name__ == "__main__":
-    random_search = RandomSearchCGAN(iterations=10)
+    random_search = RandomSearchCGAN(iterations=50)
     random_search.run()
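
The refactor folds the per-sample evaluation loop into `compute_measures` from `mu_map.eval.measures`, so the call site only keeps the CSV export and `values["NMAE"].mean()`. Below is a minimal sketch of what that helper is assumed to do, reconstructed from the inline loop it replaces; the repository's actual implementation (progress reporting, device handling, exact signature) may differ, and the continued availability of `nmae`/`mse` in `mu_map.eval.measures` as well as the device lookup via `next(model.parameters())` are assumptions.

```python
# Sketch only: assumed behaviour of mu_map.eval.measures.compute_measures,
# reconstructed from the inline loop removed in the diff above.
import pandas as pd
import torch

# Assumption: the per-volume measures used before this change still exist here.
from mu_map.eval.measures import nmae, mse


def compute_measures(dataset, model) -> pd.DataFrame:
    """Evaluate a model on every (reconstruction, mu_map) pair of a dataset.

    Returns a DataFrame with one row per sample and one column per measure
    ("NMAE", "MSE"), matching what the removed inline code produced.
    """
    device = next(model.parameters()).device  # assumption: model lives on a single device
    measures = {"NMAE": nmae, "MSE": mse}

    rows = []
    with torch.no_grad():
        for recon, mu_map in dataset:
            prediction = model(recon.unsqueeze(dim=0).to(device))
            prediction = prediction.squeeze().cpu().numpy()
            mu_map = mu_map.squeeze().cpu().numpy()
            rows.append({name: fn(prediction, mu_map) for name, fn in measures.items()})

    return pd.DataFrame(rows)
```

Returning a plain DataFrame keeps the call site's `values.to_csv(...)` and `values["NMAE"].mean()` unchanged, and lets other evaluation scripts reuse the same helper.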