From 9c8ba6bbd22522c7b573a8ce5c868592cba17945 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 13 Jan 2023 13:46:30 +0100
Subject: [PATCH] fix remaining bug regarding random search

---
 mu_map/eval/measures.py          | 2 ++
 mu_map/training/random_search.py | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/mu_map/eval/measures.py b/mu_map/eval/measures.py
index ddacc24..ad3c29e 100644
--- a/mu_map/eval/measures.py
+++ b/mu_map/eval/measures.py
@@ -61,6 +61,8 @@ def compute_measures(dataset: MuMapDataset, model: UNet) -> pd.DataFrame:
     pd.DataFrame
         a dataframe containing containing the measures for each image in the dataset
     """
+    device = next(model.parameters()).device
+
     measures = {"NMAE": nmae, "MSE": mse}
     values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys())))
     for i, (recon, mu_map) in enumerate(dataset):
diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py
index 0688ef3..cc9b82d 100644
--- a/mu_map/training/random_search.py
+++ b/mu_map/training/random_search.py
@@ -374,7 +374,7 @@ class RandomSearchCGAN(RandomSearch):
             if nmae < self.nmae_min:
                 self.logger.info(f"New best run at iteration {i}")
                 self.nmae_min = nmae
-            self._cleanup_run(i, link_best=(nmae_min == nmae))
+            self._cleanup_run(i, link_best=(self.nmae_min == nmae))
         return self.nmae_min
 
     def eval_run(self):
-- 
GitLab