From 9c8ba6bbd22522c7b573a8ce5c868592cba17945 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Fri, 13 Jan 2023 13:46:30 +0100 Subject: [PATCH] fix remaining bug regarding random search --- mu_map/eval/measures.py | 2 ++ mu_map/training/random_search.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mu_map/eval/measures.py b/mu_map/eval/measures.py index ddacc24..ad3c29e 100644 --- a/mu_map/eval/measures.py +++ b/mu_map/eval/measures.py @@ -61,6 +61,8 @@ def compute_measures(dataset: MuMapDataset, model: UNet) -> pd.DataFrame: pd.DataFrame a dataframe containing containing the measures for each image in the dataset """ + device = next(model.parameters()).device + measures = {"NMAE": nmae, "MSE": mse} values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys()))) for i, (recon, mu_map) in enumerate(dataset): diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py index 0688ef3..cc9b82d 100644 --- a/mu_map/training/random_search.py +++ b/mu_map/training/random_search.py @@ -374,7 +374,7 @@ class RandomSearchCGAN(RandomSearch): if nmae < self.nmae_min: self.logger.info(f"New best run at iteration {i}") self.nmae_min = nmae - self._cleanup_run(i, link_best=(nmae_min == nmae)) + self._cleanup_run(i, link_best=(self.nmae_min == nmae)) return self.nmae_min def eval_run(self): -- GitLab