diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py index ccb2eda79871685b398b6a0d2415554dae06b4df..b43f1bfd619311671a192bc6fdef021b32fb1e0e 100644 --- a/mu_map/training/lib.py +++ b/mu_map/training/lib.py @@ -112,8 +112,8 @@ class AbstractTraining: def _eval_epoch(self): torch.set_grad_enabled(False) - for model in self.models: - model.eval() + for param in self.training_params: + param.model.eval() loss = 0.0 data_loader = self.data_loaders["validation"]