Commit 2e31fc27 authored by Tamino Huxohl

write doc for training lib

parent bc02d82e
"""
Module functioning as a library for training related code.
"""
from dataclasses import dataclass
from logging import Logger
import os
import sys
from typing import Dict, List, Optional

import torch
@@ -13,6 +16,12 @@ from mu_map.logging import get_logger
@dataclass
class TrainingParams:
"""
Dataclass to bundle parameters related to the optimization of
a single model. This includes a name, the model itself and an
optimizer. Optionally, a learning rate scheduler can be added.
"""
name: str
model: torch.nn.Module
optimizer: torch.optim.Optimizer
@@ -20,6 +29,26 @@ class TrainingParams:
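A minimal usage sketch (illustrative, not part of the commit): bundling a model
with its optimizer. The scheduler field is elided by the hunk above, so
`lr_scheduler` is an assumed field name based on the docstring.

    model = torch.nn.Linear(16, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    params = TrainingParams(
        name="Model",
        model=model,
        optimizer=optimizer,
        lr_scheduler=torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),  # assumed field
    )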
class AbstractTraining:
"""
Abstract implementation of a training.
An implementation needs to overwrite the methods `_train_batch` and `_eval_batch`.
In addition, training parameters for all models need to be added to the
`self.training_params` list as this is used to put models in the according mode
as well as using the optimizer and learning rate scheduler.
This abstract class implement a common training procedure so that
implementations can focus on the computations per batch and not iterating over
the dataset, storing snapshots, etc.
:param epochs: the number of epochs to train
:param dataset: the dataset to use for training
:param batch_size: the batch size used for training
:param device: the device on which to perform computations (cpu or cuda)
:param snapshot_dir: the directory where snapshots are stored
:param snapshot_epoch: at each of these epochs a snapshot is stored
:param logger: optional logger to print results
"""
def __init__(
self,
epochs: int,
@@ -42,7 +71,7 @@ class AbstractTraining:
logger if logger is not None else get_logger(name=self.__class__.__name__)
)
self.training_params: List[TrainingParams] = []
self.data_loaders = dict(
[
(
@@ -60,6 +89,16 @@
)
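The body of the `data_loaders` construction is elided by the hunk above. A
plausible sketch of what it builds, assuming standard
`torch.utils.data.DataLoader` objects and a hypothetical per-split accessor on
the dataset:

    self.data_loaders = dict(
        [
            (
                split_name,
                torch.utils.data.DataLoader(
                    dataset.split_copy(split_name),  # hypothetical accessor
                    batch_size=batch_size,
                    shuffle=(split_name == "train"),
                ),
            )
            for split_name in ["train", "validation"]
        ]
    )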
def run(self) -> float:
"""
Implementation of a training run.
For each epoch:
1. Train the model
2. Evaluate the model on the validation split
3. If applicable, store a snapshot
The validation loss is also kept track of to keep a snapshot
which achieves a minimal loss.
"""
loss_val_min = sys.maxsize
for epoch in range(1, self.epochs + 1):
str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
@@ -97,11 +136,19 @@
for param in self.training_params:
param.optimizer.step()
def _train_epoch(self) -> float:
"""
Implementation of the training in a single epoch.
:return: a number representing the training loss
"""
# activate gradients
torch.set_grad_enabled(True)
# set models into training mode
for param in self.training_params:
param.model.train()
# iterate over all batches in the training dataset
loss = 0.0
data_loader = self.data_loaders["train"]
for i, (inputs, targets) in enumerate(data_loader):
......@@ -110,22 +157,33 @@ class AbstractTraining:
end="\r",
)
# move data to the configured device
inputs = inputs.to(self.device)
targets = targets.to(self.device)
# reset the gradients of all optimizers
for param in self.training_params:
param.optimizer.zero_grad()
loss = loss + self._train_batch(inputs, targets)
# step all optimizers via the after-batch hook
self._after_train_batch()
return loss / len(data_loader)
def _eval_epoch(self) -> float:
"""
Implementation of the evaluation in a single epoch.
:return: a number representing the validation loss
"""
# deactivate gradients
torch.set_grad_enabled(False)
# set models into evaluation mode
for param in self.training_params:
param.model.eval()
# iterate over all batches in the validation dataset
loss = 0.0
data_loader = self.data_loaders["validation"]
for i, (inputs, targets) in enumerate(data_loader):
@@ -134,6 +192,7 @@
end="\r",
)
# move data to the configured device
inputs = inputs.to(self.device)
targets = targets.to(self.device)
@@ -141,6 +200,11 @@
return loss / len(data_loader)
def store_snapshot(self, prefix: str):
"""
Store snapshots of all models.
:param prefix: prefix for all stored snapshot files
"""
for param in self.training_params:
snapshot_file = os.path.join(
self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
@@ -149,7 +213,21 @@
torch.save(param.model.state_dict(), snapshot_file)
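For example, with `snapshot_dir="snapshots"` and a parameter set named
"Model", calling `store_snapshot("010")` writes the model weights to
`snapshots/010_model.pth` (the prefix value is an illustrative assumption;
the filename format comes from the code above).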
def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
"""
Implementation of training a single batch.
:param inputs: batch of input data
:param targets: batch of target data
:return: a number representing the loss
"""
return 0
def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
"""
Implementation of evaluating a single batch.
:param inputs: batch of input data
:param targets: batch of target data
:return: a number representing the loss
"""
return 0
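A minimal sketch (illustrative, not part of the commit) of a concrete
implementation training a single model with an MSE loss. It assumes the elided
`lr_scheduler` field of `TrainingParams` defaults to `None`; all names below
are hypothetical.

    class MSETraining(AbstractTraining):
        """
        Example training of a single model with an MSE loss.
        """

        def __init__(self, model: torch.nn.Module, **kwargs):
            super().__init__(**kwargs)
            self.criterion = torch.nn.MSELoss()
            self.model = model.to(self.device)
            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
            # register the model so the training loop can manage its mode,
            # optimizer, and snapshots
            self.training_params.append(
                TrainingParams(name="Model", model=self.model, optimizer=optimizer)
            )

        def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
            # forward and backward pass; gradients were zeroed by the loop and
            # the optimizer is stepped afterwards in _after_train_batch
            loss = self.criterion(self.model(inputs), targets)
            loss.backward()
            return loss.item()

        def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
            # gradients are already disabled by _eval_epoch
            return self.criterion(self.model(inputs), targets).item()

    training = MSETraining(
        model=torch.nn.Linear(16, 1),
        epochs=100,
        dataset=dataset,  # placeholder: a dataset with train/validation splits
        batch_size=64,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        snapshot_dir="snapshots",
        snapshot_epoch=10,
    )
    training.run()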