From 37279fe9f7366c0b1c7494f37e0d012bfe9ce102 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 13 Jan 2023 09:46:12 +0100
Subject: [PATCH] random search uses new measure computation method

---
 mu_map/training/random_search.py | 27 +++------------------------
 1 file changed, 3 insertions(+), 24 deletions(-)

diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py
index 82f877b..c6c68e3 100644
--- a/mu_map/training/random_search.py
+++ b/mu_map/training/random_search.py
@@ -20,7 +20,7 @@ from mu_map.dataset.normalization import (
     GaussianNormTransform,
 )
 from mu_map.dataset.transform import PadCropTranform, Transform, SequenceTransform
-from mu_map.eval.measures import nmae, mse
+from mu_map.eval.measures import compute_measures
 from mu_map.models.discriminator import Discriminator, PatchDiscriminator
 from mu_map.models.unet import UNet
 from mu_map.training.cgan import cGANTraining, DiscriminatorParams, GeneratorParams
@@ -337,28 +337,7 @@ class RandomSearchCGAN(RandomSearch):
             scatter_correction=self.params["scatter_correction"],
         )
 
-        measures = {"NMAE": nmae, "MSE": mse}
-        values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys())))
-        for i, (recon, mu_map) in enumerate(dataset):
-            print(
-                f"Process input {str(i):>{len(str(len(dataset)))}}/{len(dataset)}",
-                end="\r",
-            )
-            prediction = model(recon.unsqueeze(dim=0).to(self.device))
-            prediction = prediction.squeeze().cpu().numpy()
-            mu_map = mu_map.squeeze().cpu().numpy()
-
-            row = pd.DataFrame(
-                dict(
-                    map(
-                        lambda item: (item[0], [item[1](prediction, mu_map)]),
-                        measures.items(),
-                    )
-                )
-            )
-            values = pd.concat((values, row), ignore_index=True)
-        print(f" " * 100, end="\r")
-
+        values = compute_measures(dataset, model)
         values.to_csv(os.path.join(self.dir_train, "measures.csv"), index=False)
         return values["NMAE"].mean()
 
@@ -470,5 +449,5 @@ class RandomSearchCGAN(RandomSearch):
 
 
 if __name__ == "__main__":
-    random_search = RandomSearchCGAN(iterations=10)
+    random_search = RandomSearchCGAN(iterations=50)
     random_search.run()
-- 
GitLab