""" Module functioning as a library for training related code. """
from dataclasses import dataclass
from logging import Logger
import os
import random
from typing import Dict, List, Optional
import sys

import numpy as np
import torch
from torch import Tensor

from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger


def init_random_seed(seed: Optional[int] = None) -> int:
    """
    Set the seed for all RNGs (default python, numpy and torch).

    Parameters
    ----------
    seed: int, optional
        the seed to be used which is generated if not provided

    Returns
    -------
    int
        the random seed used
    """
    # generate a seed if none is given so that it can be reported/reproduced
    seed = seed if seed is not None else random.randint(0, 2**32 - 1)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return seed


@dataclass
class TrainingParams:
    """
    Dataclass to bundle parameters related to the optimization of
    a single model. This includes a name, the model itself and an
    optimizer. Optionally, a learning rate scheduler can be added.
    """

    name: str
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    # default of None makes the scheduler truly optional, as documented;
    # existing callers that pass a scheduler explicitly are unaffected
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None


class AbstractTraining:
    """
    Abstract implementation of a training.

    An implementation needs to overwrite the methods `_train_batch` and
    `_eval_batch`. In addition, training parameters for all models need
    to be added to the `self.training_params` list as this is used to
    put models in the according mode as well as using the optimizer and
    learning rate scheduler.

    This abstract class implement a common training procedure so that
    implementations can focus on the computations per batch and not
    iterating over the dataset, storing snapshots, etc.

    Parameters
    ----------
    epochs: int
        the number of epochs to train
    dataset: MuMapDataset
        the dataset to use for training
    batch_size: int
        the batch size used for training
    device: torch.device
        the device on which to perform computations (cpu or cuda)
    snapshot_dir: str
        the directory where snapshots are stored
    snapshot_epoch: int
        at each of these epochs a snapshot is stored
    early_stopping: int, optional
        if defined, training is stopped if the validation loss did not
        improve for this many epochs
    logger: Logger, optional
        optional logger to print results
    """

    def __init__(
        self,
        epochs: int,
        dataset: MuMapDataset,
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        early_stopping: Optional[int] = None,
        logger: Optional[Logger] = None,
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.dataset = dataset
        self.device = device
        self.early_stopping = early_stopping
        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch
        # fall back to a logger named after the concrete subclass
        self.logger = (
            logger if logger is not None else get_logger(name=self.__class__.__name__)
        )

        # implementations register their models/optimizers/schedulers here
        self.training_params: List[TrainingParams] = []
        # one data loader per split; each loader works on its own copy
        # of the dataset restricted to that split
        self.data_loaders = dict(
            [
                (
                    split_name,
                    torch.utils.data.DataLoader(
                        dataset.copy(split_name),
                        batch_size=self.batch_size,
                        shuffle=True,
                        pin_memory=True,
                        num_workers=1,
                    ),
                )
                for split_name in ["train", "validation"]
            ]
        )

    def run(self) -> float:
        """
        Implementation of a training run.

        For each epoch:
          1. Train the model
          2. Evaluate the model on the validation split
          3. If applicable, store a snapshot

        The validation loss is also kept track of to keep a snapshot
        which achieves a minimal loss.

        Returns
        -------
        float
            the minimal validation loss achieved during the run
        """
        # seed the history with a sentinel so that min()/argmin() are
        # well-defined in the first epoch
        losses_val = [sys.maxsize]
        for epoch in range(1, self.epochs + 1):
            # right-align the epoch number to the width of the last epoch
            str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")

            loss_train = self._train_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
            )
            loss_val = self._eval_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
            )

            # periodic snapshot with a zero-padded epoch prefix
            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")

            # compare against the history BEFORE appending so the new
            # loss is matched against all previous epochs
            if loss_val < min(losses_val):
                self.logger.info(
                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
                )
                self.store_snapshot("val_min")
            losses_val.append(loss_val)

            # number of epochs since the best validation loss
            last_improvement = len(losses_val) - np.argmin(losses_val)
            if self.early_stopping and last_improvement > self.early_stopping:
                self.logger.info(
                    f"Stop early because the last improvement was {last_improvement} epochs ago"
                )
                return min(losses_val)

            # step learning rate schedulers once per epoch
            for param in self.training_params:
                if param.lr_scheduler is not None:
                    param.lr_scheduler.step()
        return min(losses_val)

    def _after_train_batch(self):
        """
        Function called after the loss computation on a batch during
        training. It is responsible for stepping all optimizers.
        """
        for param in self.training_params:
            param.optimizer.step()

    def _train_epoch(self) -> float:
        """
        Implementation of the training in a single epoch.

        Returns
        -------
        float
            a number representing the training loss
        """
        # activate gradients
        torch.set_grad_enabled(True)
        # set models into training mode
        for param in self.training_params:
            param.model.train()

        # iterate over all batches in the training dataset
        loss = 0.0
        data_loader = self.data_loaders["train"]
        for i, (inputs, targets) in enumerate(data_loader):
            # transient progress indicator; overwritten by the next batch
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )

            # move data to according device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # zero grad optimizers
            for param in self.training_params:
                param.optimizer.zero_grad()

            loss = loss + self._train_batch(inputs, targets)

            # step optimizers
            self._after_train_batch()
        return loss / len(data_loader)

    def _eval_epoch(self) -> float:
        """
        Implementation of the evaluation in a single epoch.

        Returns
        -------
        float
            a number representing the validation loss
        """
        # deactivate gradients
        torch.set_grad_enabled(False)
        # set models into evaluation mode
        for param in self.training_params:
            param.model.eval()

        # iterate over all batches in the validation dataset
        loss = 0.0
        data_loader = self.data_loaders["validation"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )

            # move data to according device
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            loss = loss + self._eval_batch(inputs, targets)
        return loss / len(data_loader)

    def store_snapshot(self, prefix: str):
        """
        Store snapshots of all models.

        Parameters
        ----------
        prefix: str
            prefix for all stored snapshot files
        """
        for param in self.training_params:
            snapshot_file = os.path.join(
                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
            )
            self.logger.debug(f"Store snapshot at {snapshot_file}")
            torch.save(param.model.state_dict(), snapshot_file)

    def get_param_by_name(self, name: str) -> TrainingParams:
        """
        Get a training parameter by its name.

        Parameters
        ----------
        name: str
            case-insensitive name of the parameter to look up

        Returns
        -------
        TrainingParams

        Raises
        ------
        ValueError
            if parameters cannot be found
        """
        _param = list(
            filter(
                lambda training_param: training_param.name.lower() == name.lower(),
                self.training_params,
            )
        )
        if len(_param) == 0:
            raise ValueError(f"Cannot find training_parameter with name {name}")
        return _param[0]

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """
        Implementation of training a single batch.

        Parameters
        ----------
        inputs: torch.Tensor
            batch of input data
        targets: torch.Tensor
            batch of target data

        Returns
        -------
        float
            a number representing the loss
        """
        # default no-op; subclasses compute the loss and call backward()
        return 0

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """
        Implementation of evaluating a single batch.

        Parameters
        ----------
        inputs: torch.Tensor
            batch of input data
        targets: torch.Tensor
            batch of target data

        Returns
        -------
        float
            a number representing the loss
        """
        # default no-op; subclasses compute the validation loss
        return 0