"""
Module functioning as a library for training related code.
"""
from dataclasses import dataclass
from logging import Logger
import os
import random
from typing import List, Optional
import sys

import numpy as np
import torch
from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger
def init_random_seed(seed: Optional[int] = None) -> int:
"""
    Set the seed for all RNGs (default python, numpy and torch).

    Parameters
    ----------
    seed: int, optional
        the seed to be used; a random seed is generated if not provided

    Returns
    -------
    int
        the random seed used
"""
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
return seed
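

# Usage sketch (illustrative only, kept as comments so it is not executed on import):
# passing a seed makes a run deterministic, while calling without an argument returns
# a newly generated seed that can be logged and reused to reproduce the run.
#
#   seed = init_random_seed(42)   # deterministic run
#   seed = init_random_seed()     # random seed, reproducible via the returned value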
@dataclass
class TrainingParams:
"""
Dataclass to bundle parameters related to the optimization of
a single model. This includes a name, the model itself and an
optimizer. Optionally, a learning rate scheduler can be added.
"""
name: str
model: torch.nn.Module
optimizer: torch.optim.Optimizer
lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
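

# Example (illustrative sketch, kept as comments; `model` stands for any
# torch.nn.Module created by the caller): bundling a model with its optimizer
# so the training loop can handle several models uniformly.
#
#   optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
#   params = TrainingParams(
#       name="Generator",
#       model=model,
#       optimizer=optimizer,
#       lr_scheduler=None,
#   )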
class AbstractTraining:
"""
Abstract implementation of a training.
    An implementation needs to override the methods `_train_batch` and `_eval_batch`.
    In addition, the training parameters for all models need to be added to the
    `self.training_params` list, which is used to put models into the correct mode
    and to step their optimizers and learning rate schedulers.

    This abstract class implements a common training procedure so that
    implementations can focus on the per-batch computations instead of iterating
    over the dataset, storing snapshots, etc. A minimal example subclass is
    sketched at the end of this module.

    Parameters
    ----------
epochs: int
the number of epochs to train
dataset: MuMapDataset
the dataset to use for training
batch_size: int
the batch size used for training
device: torch.device
the device on which to perform computations (cpu or cuda)
snapshot_dir: str
the directory where snapshots are stored
    snapshot_epoch: int
        a snapshot is stored every `snapshot_epoch` epochs
early_stopping: int, optional
if defined, training is stopped if the validation loss did not improve
for this many epochs
logger: Logger, optional
optional logger to print results
"""
def __init__(
self,
epochs: int,
dataset: MuMapDataset,
batch_size: int,
device: torch.device,
snapshot_dir: str,
snapshot_epoch: int,
early_stopping: Optional[int] = None,
logger: Optional[Logger] = None,
):
self.epochs = epochs
self.batch_size = batch_size
self.dataset = dataset
self.device = device
self.early_stopping = early_stopping
self.snapshot_dir = snapshot_dir
self.snapshot_epoch = snapshot_epoch
self.logger = (
logger if logger is not None else get_logger(name=self.__class__.__name__)
)
self.training_params: List[TrainingParams] = []
        self.data_loaders = {
            split_name: torch.utils.data.DataLoader(
                dataset.copy(split_name),
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=1,
            )
            for split_name in ["train", "validation"]
        }
def run(self) -> float:
"""
Implementation of a training run.
For each epoch:
1. Train the model
2. Evaluate the model on the validation split
3. If applicable, store a snapshot
The validation loss is also kept track of to keep a snapshot
which achieves a minimal loss.
"""
        losses_val = [sys.maxsize]
for epoch in range(1, self.epochs + 1):
str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
loss_train = self._train_epoch()
self.logger.info(
f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
)
loss_val = self._eval_epoch()
self.logger.info(
f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
)
if epoch % self.snapshot_epoch == 0:
self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")
if loss_val < min(losses_val):
self.logger.info(
f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
)
self.store_snapshot("val_min")
losses_val.append(loss_val)
last_improvement = len(losses_val) - np.argmin(losses_val)
if self.early_stopping and last_improvement > self.early_stopping:
self.logger.info(
f"Stop early because the last improvement was {last_improvement} epochs ago"
)
return min(losses_val)
for param in self.training_params:
if param.lr_scheduler is not None:
param.lr_scheduler.step()
return min(losses_val)
def _after_train_batch(self):
"""
Function called after the loss computation on a batch during training.
It is responsible for stepping all optimizers.
"""
for param in self.training_params:
param.optimizer.step()
def _train_epoch(self) -> float:
"""
Implementation of the training in a single epoch.
:return: a number representing the training loss
"""
# activate gradients
torch.set_grad_enabled(True)
# set models into training mode
for param in self.training_params:
param.model.train()
        # iterate over all batches in the training dataset
loss = 0.0
data_loader = self.data_loaders["train"]
for i, (inputs, targets) in enumerate(data_loader):
print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
            # move data to the device used for computations
inputs = inputs.to(self.device)
targets = targets.to(self.device)
# zero grad optimizers
for param in self.training_params:
param.optimizer.zero_grad()
loss = loss + self._train_batch(inputs, targets)
# step optimizers
self._after_train_batch()
return loss / len(data_loader)
def _eval_epoch(self) -> float:
"""
Implementation of the evaluation in a single epoch.
:return: a number representing the validation loss
"""
# deactivate gradients
torch.set_grad_enabled(False)
# set models into evaluation mode
for param in self.training_params:
param.model.eval()
        # iterate over all batches in the validation dataset
loss = 0.0
data_loader = self.data_loaders["validation"]
for i, (inputs, targets) in enumerate(data_loader):
print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
            # move data to the device used for computations
inputs = inputs.to(self.device)
targets = targets.to(self.device)
loss = loss + self._eval_batch(inputs, targets)
return loss / len(data_loader)
def store_snapshot(self, prefix: str):
"""
Store snapshots of all models.
Parameters
----------
prefix: str
prefix for all stored snapshot files
"""
for param in self.training_params:
snapshot_file = os.path.join(
self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
)
self.logger.debug(f"Store snapshot at {snapshot_file}")
torch.save(param.model.state_dict(), snapshot_file)
def get_param_by_name(self, name: str) -> TrainingParams:
"""
        Get a training parameter by its name.

        Parameters
        ----------
        name: str
            the name of the training parameters to look up (case-insensitive)

        Returns
        -------
        TrainingParams
            the training parameters registered under the given name

        Raises
        ------
        ValueError
            if no training parameters with this name can be found
"""
_param = list(
filter(
lambda training_param: training_param.name.lower() == name.lower(),
self.training_params,
)
)
if len(_param) == 0:
raise ValueError(f"Cannot find training_parameter with name {name}")
return _param[0]
def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
"""
        Train a single batch. Meant to be overridden by implementations, which
        should perform the forward and backward pass here; the optimizer step is
        triggered afterwards by `_after_train_batch`.

        Parameters
        ----------
        inputs: torch.Tensor
            batch of input data
        targets: torch.Tensor
            batch of target data

        Returns
        -------
        float
            a number representing the loss
"""
return 0
def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
"""
        Evaluate a single batch. Meant to be overridden by implementations;
        gradients are disabled by `_eval_epoch` before this method is called.

        Parameters
        ----------
        inputs: torch.Tensor
            batch of input data
        targets: torch.Tensor
            batch of target data

        Returns
        -------
        float
            a number representing the loss
"""
return 0
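

# Minimal usage sketch (illustrative, not part of the original module): a concrete
# training registers its models in `self.training_params` and overrides
# `_train_batch`/`_eval_batch`. The class name, the Adam optimizer, the learning
# rate and the MSE criterion below are assumptions chosen for demonstration only.
class ExampleTraining(AbstractTraining):
    """
    Example of a concrete training optimizing a single model with an MSE loss.
    """

    def __init__(self, model: torch.nn.Module, **kwargs):
        super().__init__(**kwargs)
        self.model = model.to(self.device)
        self.criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
        # register the model so that AbstractTraining switches its mode, steps its
        # optimizer and stores its snapshots
        self.training_params.append(
            TrainingParams(
                name="Model", model=self.model, optimizer=optimizer, lr_scheduler=None
            )
        )

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        # forward and backward pass only; gradients were zeroed by _train_epoch
        # and the optimizer step follows in _after_train_batch
        loss = self.criterion(self.model(inputs), targets)
        loss.backward()
        return loss.item()

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        # gradients are already disabled by _eval_epoch
        return self.criterion(self.model(inputs), targets).item()


# Such a training would then be configured and run roughly like this (sketch kept
# as comments; the dataset path and hyperparameters are placeholders):
#
#   training = ExampleTraining(
#       model=my_model,
#       epochs=100,
#       dataset=MuMapDataset("data/"),
#       batch_size=16,
#       device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#       snapshot_dir="snapshots/",
#       snapshot_epoch=10,
#       early_stopping=20,
#   )
#   best_val_loss = training.run()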