diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py index 1ae2b9fa374b70d87b3de7a16ff9f5216a74239f..6b9b0d9d130598a1f1fe76fad952e020c8d9c495 100644 --- a/mu_map/training/lib.py +++ b/mu_map/training/lib.py @@ -212,6 +212,28 @@ class AbstractTraining: self.logger.debug(f"Store snapshot at {snapshot_file}") torch.save(param.model.state_dict(), snapshot_file) + def get_param_by_name(self, name: str) -> TrainingParams: + """ + Get a training parameter by its name. + + Parameters + ---------- + name: str + + Returns + ------- + TrainingParams + + Raises + ------ + ValueError + if parameters cannot be found + """ + _param = list(filter(lambda training_param: training_param.name.lower() == name.lower(), self.training_params)) + if len(_param) == 0: + raise ValueError(f"Cannot find training_parameter with name {name}") + return _param[0] + def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float: """ Implementation of training a single batch.