diff --git a/mu_map/eval/measures.py b/mu_map/eval/measures.py index ddacc2415ed734fbe9212e02e70147cb61bb8030..ad3c29ecc6624ba0e2480a15ec1417b28c9e2092 100644 --- a/mu_map/eval/measures.py +++ b/mu_map/eval/measures.py @@ -61,6 +61,8 @@ def compute_measures(dataset: MuMapDataset, model: UNet) -> pd.DataFrame: pd.DataFrame a dataframe containing containing the measures for each image in the dataset """ + device = next(model.parameters()).device + measures = {"NMAE": nmae, "MSE": mse} values = pd.DataFrame(dict(map(lambda x: (x, []), measures.keys()))) for i, (recon, mu_map) in enumerate(dataset): diff --git a/mu_map/training/random_search.py b/mu_map/training/random_search.py index 0688ef39d3b25576302115b09a327d8e15ad74e7..cc9b82d917079640f98ac114816042be7bb55bcb 100644 --- a/mu_map/training/random_search.py +++ b/mu_map/training/random_search.py @@ -374,7 +374,7 @@ class RandomSearchCGAN(RandomSearch): if nmae < self.nmae_min: self.logger.info(f"New best run at iteration {i}") self.nmae_min = nmae - self._cleanup_run(i, link_best=(nmae_min == nmae)) + self._cleanup_run(i, link_best=(self.nmae_min == nmae)) return self.nmae_min def eval_run(self):