Skip to content
Snippets Groups Projects
Commit 6cff2d59 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

implement function to retrieve parameters by name

parent 19046346
No related branches found
No related tags found
No related merge requests found
......@@ -212,6 +212,28 @@ class AbstractTraining:
self.logger.debug(f"Store snapshot at {snapshot_file}")
torch.save(param.model.state_dict(), snapshot_file)
def get_param_by_name(self, name: str) -> TrainingParams:
"""
Get a training parameter by its name.
Parameters
----------
name: str
Returns
-------
TrainingParams
Raises
------
ValueError
if parameters cannot be found
"""
_param = list(filter(lambda training_param: training_param.name.lower() == name.lower(), self.training_params))
if len(_param) == 0:
raise ValueError(f"Cannot find training_parameter with name {name}")
return _param[0]
def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
"""
Implementation of training a single batch.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment