From 6cff2d595f978d839b6362f1e03dd4567229afd7 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Thu, 5 Jan 2023 10:35:35 +0100 Subject: [PATCH] implement function to retrieve parameters by name --- mu_map/training/lib.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py index 1ae2b9f..6b9b0d9 100644 --- a/mu_map/training/lib.py +++ b/mu_map/training/lib.py @@ -212,6 +212,28 @@ class AbstractTraining: self.logger.debug(f"Store snapshot at {snapshot_file}") torch.save(param.model.state_dict(), snapshot_file) + def get_param_by_name(self, name: str) -> TrainingParams: + """ + Get a training parameter by its name. + + Parameters + ---------- + name: str + + Returns + ------- + TrainingParams + + Raises + ------ + ValueError + if parameters cannot be found + """ + _param = list(filter(lambda training_param: training_param.name.lower() == name.lower(), self.training_params)) + if len(_param) == 0: + raise ValueError(f"Cannot find training_parameter with name {name}") + return _param[0] + def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float: """ Implementation of training a single batch. -- GitLab