Skip to content
Snippets Groups Projects
lib.py 9.71 KiB
"""
Module functioning as a library for training-related code.
"""
from dataclasses import dataclass
from logging import Logger
import os
import random
from typing import Dict, List, Optional
import sys

import numpy as np
import torch
from torch import Tensor

from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger


def init_random_seed(seed: Optional[int] = None) -> int:
    """
    Seed the random number generators of python, numpy and torch.

    All three generators receive the same value so that a run can be
    reproduced by passing the returned seed again.

    Parameters
    ----------
    seed: int, optional
        the seed to apply; if omitted, one is drawn with
        `random.randint(0, 2**32 - 1)`

    Returns
    -------
    int
        the seed that was actually applied
    """
    if seed is None:
        # upper bound presumably chosen because numpy's legacy global seeding
        # only accepts values in [0, 2**32 - 1]
        seed = random.randint(0, 2**32 - 1)

    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)
    return seed


@dataclass
class TrainingParams:
    """
    Dataclass to bundle parameters related to the optimization of
    a single model. This includes a name, the model itself and an
    optimizer. Optionally, a learning rate scheduler can be added.

    Attributes
    ----------
    name: str
        identifier of the model (e.g. used in snapshot file names by
        `AbstractTraining`)
    model: torch.nn.Module
        the model to be optimized
    optimizer: torch.optim.Optimizer
        the optimizer updating the parameters of `model`
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
        learning rate scheduler for `optimizer`; `None` (the default)
        means no scheduler is used
    """

    name: str
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    # defaults to None so the field is optional in practice, not only in its type
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None


class AbstractTraining:
    """
    Abstract implementation of a training.

    An implementation needs to overwrite the methods `_train_batch` and
    `_eval_batch`. In addition, training parameters for all models need to be
    added to the `self.training_params` list, as this list is used to put the
    models into the according mode and to step their optimizers and learning
    rate schedulers.

    This abstract class implements a common training procedure so that
    implementations can focus on the computations per batch and not on
    iterating over the dataset, storing snapshots, etc.

    Parameters
    ----------
    epochs: int
        the number of epochs to train
    dataset: MuMapDataset
        the dataset to use for training; `dataset.copy(split_name)` is used
        to create the "train" and "validation" splits
    batch_size: int
        the batch size used for training
    device: torch.device
        the device on which to perform computations (cpu or cuda)
    snapshot_dir: str
        the directory where snapshots are stored
    snapshot_epoch: int
        a snapshot is stored at every epoch that is a multiple of this value
    early_stopping: int, optional
        if defined (and non-zero), training is stopped once the validation
        loss did not improve for this many epochs
    logger: Logger, optional
        optional logger to print results; if not provided, a logger named
        after the class is created
    """

    def __init__(
        self,
        epochs: int,
        dataset: MuMapDataset,
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        early_stopping: Optional[int] = None,
        logger: Optional[Logger] = None,
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.dataset = dataset
        self.device = device
        self.early_stopping = early_stopping

        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch

        self.logger = (
            logger if logger is not None else get_logger(name=self.__class__.__name__)
        )

        self.training_params: List[TrainingParams] = []
        # one data loader per split, keyed by the split name
        self.data_loaders = {
            split_name: torch.utils.data.DataLoader(
                dataset.copy(split_name),
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=1,
            )
            for split_name in ["train", "validation"]
        }

    def run(self) -> float:
        """
        Implementation of a training run.

        For each epoch:
          1. Train the model
          2. Evaluate the model on the validation split
          3. If applicable, store a snapshot
          4. Step the learning rate schedulers (if any)
        The validation loss is also kept track of to keep a snapshot
        (prefix "val_min") which achieves a minimal loss and to stop
        early if `early_stopping` is configured.

        Returns
        -------
        float
            the minimal validation loss achieved (`inf` if no epoch was run
            or no finite loss was observed)
        """
        width = len(str(self.epochs))
        loss_val_min = float("inf")
        # epoch of the last improvement of the validation loss; 0 = none yet
        epoch_val_min = 0

        for epoch in range(1, self.epochs + 1):
            str_epoch = f"{str(epoch):>{width}}"
            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")

            loss_train = self._train_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
            )
            loss_val = self._eval_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
            )

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(f"{epoch:0{width}d}")

            # strict comparison: ties and NaN losses never count as improvement
            if loss_val < loss_val_min:
                self.logger.info(
                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
                )
                self.store_snapshot("val_min")
                loss_val_min = loss_val
                epoch_val_min = epoch

            epochs_since_improvement = epoch - epoch_val_min
            if self.early_stopping and epochs_since_improvement >= self.early_stopping:
                self.logger.info(
                    f"Stop early because the last improvement was {epochs_since_improvement} epochs ago"
                )
                return loss_val_min

            for param in self.training_params:
                if param.lr_scheduler is not None:
                    param.lr_scheduler.step()
        return loss_val_min

    def _after_train_batch(self):
        """
        Function called after the loss computation on a batch during training.
        It is responsible for stepping all optimizers.
        """
        for param in self.training_params:
            param.optimizer.step()

    @staticmethod
    def _print_progress(batch_index: int, num_batches: int):
        """Print a carriage-return-terminated progress line for a batch."""
        print(
            f"Batch {str(batch_index):>{len(str(num_batches))}}/{num_batches}",
            end="\r",
        )

    def _train_epoch(self) -> float:
        """
        Implementation of the training in a single epoch.

        Returns
        -------
        float
            the mean of the per-batch losses returned by `_train_batch`
        """
        # set models into training mode
        for param in self.training_params:
            param.model.train()

        loss = 0.0
        data_loader = self.data_loaders["train"]
        # enable gradients only for this epoch; the previous (thread-local)
        # grad mode is restored when the context exits
        with torch.set_grad_enabled(True):
            # iterate over all batches in the training dataset
            for i, (inputs, targets) in enumerate(data_loader):
                self._print_progress(i, len(data_loader))

                # move data to according device
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                # zero grad optimizers
                for param in self.training_params:
                    param.optimizer.zero_grad()

                loss = loss + self._train_batch(inputs, targets)

                # step optimizers
                self._after_train_batch()
        return loss / len(data_loader)

    def _eval_epoch(self) -> float:
        """
        Implementation of the evaluation in a single epoch.

        Returns
        -------
        float
            the mean of the per-batch losses returned by `_eval_batch`
        """
        # set models into evaluation mode
        for param in self.training_params:
            param.model.eval()

        loss = 0.0
        data_loader = self.data_loaders["validation"]
        # disable gradients only for this epoch; the previous (thread-local)
        # grad mode is restored when the context exits
        with torch.set_grad_enabled(False):
            # iterate over all batches in the validation dataset
            for i, (inputs, targets) in enumerate(data_loader):
                self._print_progress(i, len(data_loader))

                # move data to according device
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                loss = loss + self._eval_batch(inputs, targets)
        return loss / len(data_loader)

    def store_snapshot(self, prefix: str):
        """
        Store snapshots (`state_dict`) of all models.

        Each model is saved to `<snapshot_dir>/<prefix>_<name in lower case>.pth`.
        The snapshot directory is not created here and is expected to exist.

        Parameters
        ----------
        prefix: str
            prefix for all stored snapshot files
        """
        for param in self.training_params:
            snapshot_file = os.path.join(
                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
            )
            self.logger.debug(f"Store snapshot at {snapshot_file}")
            torch.save(param.model.state_dict(), snapshot_file)

    def get_param_by_name(self, name: str) -> TrainingParams:
        """
        Get a training parameter by its name (case-insensitive).

        Parameters
        ----------
        name: str
            the name of the training parameter

        Returns
        -------
        TrainingParams
            the first training parameter whose name matches

        Raises
        ------
        ValueError
            if parameters cannot be found
        """
        name_lower = name.lower()
        for training_param in self.training_params:
            if training_param.name.lower() == name_lower:
                return training_param
        raise ValueError(f"Cannot find training_parameter with name {name}")

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """
        Implementation of training a single batch.

        Subclasses are expected to override this; the default returns 0.

        Parameters
        ----------
        inputs: torch.Tensor
            batch of input data
        targets: torch.Tensor
            batch of target data

        Returns
        -------
        float
            a number representing the loss
        """
        return 0

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """
        Implementation of evaluating a single batch.

        Subclasses are expected to override this; the default returns 0.

        Parameters
        ----------
        inputs: torch.Tensor
            batch of input data
        targets: torch.Tensor
            batch of target data

        Returns
        -------
        float
            a number representing the loss
        """
        return 0