"""
Module functioning as a library for training related code.
"""

from dataclasses import dataclass
from logging import Logger
from typing import List, Optional

import os
import sys

import numpy as np
import torch
from torch import Tensor
from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger


@dataclass
class TrainingParams:
    """
    Dataclass to bundle parameters related to the optimization of
    a single model. This includes a name, the model itself and an
    optimizer. Optionally, a learning rate scheduler can be added.
    """

    name: str
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
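
# A minimal sketch of how such a bundle could be assembled (the model, learning
# rate and scheduler below are illustrative assumptions, not values prescribed
# by this module):
#
#   model = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   params = TrainingParams(
#       name="Model",
#       model=model,
#       optimizer=optimizer,
#       lr_scheduler=torch.optim.lr_scheduler.StepLR(optimizer, step_size=10),
#   )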

class AbstractTraining:
"""
Abstract implementation of a training.
An implementation needs to overwrite the methods `_train_batch` and `_eval_batch`.
In addition, training parameters for all models need to be added to the
`self.training_params` list as this is used to put models in the according mode
as well as using the optimizer and learning rate scheduler.
This abstract class implement a common training procedure so that
implementations can focus on the computations per batch and not iterating over
the dataset, storing snapshots, etc.
Parameters
----------
epochs: int
the number of epochs to train
dataset: MuMapDataset
the dataset to use for training
batch_size: int
the batch size used for training
device: torch.device
the device on which to perform computations (cpu or cuda)
snapshot_dir: str
the directory where snapshots are stored
snapshot_epoch: int
at each of these epochs a snapshot is stored
early_stopping: int, optional
if defined, training is stopped if the validation loss did not improve
for this many epochs
logger: Logger, optional
optional logger to print results
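
    # A minimal sketch of a concrete implementation (illustrative only: the
    # model, loss function and constructor arguments are assumptions, not part
    # of this class):
    #
    #   class ExampleTraining(AbstractTraining):
    #       def __init__(self, model: torch.nn.Module, **kwargs):
    #           super().__init__(**kwargs)
    #           self.model = model.to(self.device)
    #           self.criterion = torch.nn.MSELoss()
    #           self.training_params.append(
    #               TrainingParams(
    #                   name="Model",
    #                   model=self.model,
    #                   optimizer=torch.optim.Adam(self.model.parameters(), lr=1e-3),
    #                   lr_scheduler=None,
    #               )
    #           )
    #
    #       def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
    #           loss = self.criterion(self.model(inputs), targets)
    #           loss.backward()
    #           return loss.item()
    #
    #       def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
    #           with torch.no_grad():
    #               return self.criterion(self.model(inputs), targets).item()
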
def __init__(
self,
epochs: int,
dataset: MuMapDataset,
batch_size: int,
device: torch.device,
snapshot_dir: str,
        snapshot_epoch: int,
        early_stopping: Optional[int],
        logger: Optional[Logger],
):
self.epochs = epochs
self.batch_size = batch_size
self.dataset = dataset
self.device = device
self.early_stopping = early_stopping
self.snapshot_dir = snapshot_dir
self.snapshot_epoch = snapshot_epoch
self.logger = (
logger if logger is not None else get_logger(name=self.__class__.__name__)
)
self.training_params: List[TrainingParams] = []

self.data_loaders = dict(
[
(
split_name,
                    torch.utils.data.DataLoader(
                        self.dataset,  # assumption: the split-specific dataset is selected here
                        batch_size=self.batch_size,
shuffle=True,
pin_memory=True,
num_workers=1,
),
)
for split_name in ["train", "validation"]
]
)
"""
Implementation of a training run.
For each epoch:
1. Train the model
2. Evaluate the model on the validation split
3. If applicable, store a snapshot
The validation loss is also kept track of to keep a snapshot
which achieves a minimal loss.
"""
losses_val = [sys.maxsize]
for epoch in range(1, self.epochs + 1):
str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
loss_train = self._train_epoch()
self.logger.info(
f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
)
loss_val = self._eval_epoch()
self.logger.info(
f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
)
if epoch % self.snapshot_epoch == 0:
self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")
if loss_val < min(losses_val):
self.logger.info(
f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
)
self.store_snapshot("val_min")
losses_val.append(loss_val)
last_improvement = len(losses_val) - np.argmin(losses_val)
if self.early_stopping and last_improvement > self.early_stopping:
self.logger.info(
f"Stop early because the last improvement was {last_improvement} epochs ago"
)
return min(losses_val)
            for param in self.training_params:
                if param.lr_scheduler is not None:
                    param.lr_scheduler.step()

        return min(losses_val)

def _after_train_batch(self):
"""
Function called after the loss computation on a batch during training.
It is responsible for stepping all optimizers.
"""
for param in self.training_params:
param.optimizer.step()
def _train_epoch(self) -> float:
"""
Implementation of the training in a single epoch.
:return: a number representing the training loss
"""
        # put all models into training mode
for param in self.training_params:
param.model.train()
        # iterate over all batches in the training dataset
loss = 0.0
data_loader = self.data_loaders["train"]
for i, (inputs, targets) in enumerate(data_loader):
print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
inputs = inputs.to(self.device)
targets = targets.to(self.device)
for param in self.training_params:
param.optimizer.zero_grad()

loss = loss + self._train_batch(inputs, targets)

            self._after_train_batch()
        return loss / len(data_loader)

def _eval_epoch(self) -> float:
"""
Implementation of the evaluation in a single epoch.
:return: a number representing the validation loss
"""
        # put all models into evaluation mode
for param in self.training_params:
param.model.eval()
        # iterate over all batches in the validation dataset
loss = 0.0
data_loader = self.data_loaders["validation"]
for i, (inputs, targets) in enumerate(data_loader):
print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
inputs = inputs.to(self.device)
targets = targets.to(self.device)

loss = loss + self._eval_batch(inputs, targets)
return loss / len(data_loader)

    def store_snapshot(self, prefix: str):
        """
        Store snapshots of all models in `self.training_params`.

        Parameters
        ----------
        prefix: str
            prefix for all stored snapshot files
        """

        for param in self.training_params:
            snapshot_file = os.path.join(
                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
            )
self.logger.debug(f"Store snapshot at {snapshot_file}")
torch.save(param.model.state_dict(), snapshot_file)
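        # For example, calling this with prefix "010" stores one file per entry
        # in `self.training_params`, e.g. "<snapshot_dir>/010_generator.pth" for
        # a TrainingParams named "Generator" (the name is illustrative).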

    def get_param_by_name(self, name: str) -> TrainingParams:
        """
        Get a training parameter by its name.

        Parameters
        ----------
        name: str
            the name to look for (case-insensitive)

        Returns
        -------
        TrainingParams
            the training parameters whose name matches

        Raises
        ------
        ValueError
            if parameters cannot be found
        """
_param = list(
filter(
lambda training_param: training_param.name.lower() == name.lower(),
self.training_params,
)
)
if len(_param) == 0:
raise ValueError(f"Cannot find training_parameter with name {name}")
return _param[0]
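
    # Usage sketch (the parameter name "generator" and the `training` instance
    # are only examples):
    #   params = training.get_param_by_name("generator")
    #   params.optimizer.zero_grad()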
def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
"""
Implementation of training a single batch.
Parameters
----------
inputs: torch.Tensor
batch of input data
targets: torch.Tensor
batch of target data
Returns
-------
float
            a number representing the loss
        """
        return 0
def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
"""
Implementation of evaluating a single batch.
Parameters
----------
inputs: torch.Tensor
batch of input data
targets: torch.Tensor
batch of target data
Returns
-------
float
            a number representing the loss
        """
        return 0