From 6cff2d595f978d839b6362f1e03dd4567229afd7 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 5 Jan 2023 10:35:35 +0100
Subject: [PATCH] implement function to retrieve parameters by name

---
 mu_map/training/lib.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index 1ae2b9f..6b9b0d9 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -212,6 +212,28 @@ class AbstractTraining:
             self.logger.debug(f"Store snapshot at {snapshot_file}")
             torch.save(param.model.state_dict(), snapshot_file)
 
+    def get_param_by_name(self, name: str) -> TrainingParams:
+        """
+        Get a training parameter by its name.
+
+        Parameters
+        ----------
+        name: str
+
+        Returns
+        -------
+        TrainingParams
+
+        Raises
+        ------
+        ValueError
+            if parameters cannot be found
+        """
+        _param = list(filter(lambda training_param: training_param.name.lower() == name.lower(), self.training_params))
+        if len(_param) == 0:
+            raise ValueError(f"Cannot find training_parameter with name {name}")
+        return _param[0]
+
     def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
         """
         Implementation of training a single batch.
-- 
GitLab