  • """
    Module functioning as a library for training related code.
    """
    
    from dataclasses import dataclass
    from logging import Logger
    import os
    import sys
    from typing import Dict, List, Optional

    import numpy as np
    import torch
    from torch import Tensor

    from mu_map.dataset.default import MuMapDataset
    from mu_map.logging import get_logger  # get_logger is used below; module path assumed
    
    @dataclass
    class TrainingParams:
        """
        Dataclass to bundle parameters related to the optimization of
        a single model. This includes a name, the model itself and an
        optimizer. Optionally, a learning rate scheduler can be added.
        """
    
    
        name: str
        model: torch.nn.Module
        optimizer: torch.optim.Optimizer
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]


    class AbstractTraining:
        """
        Abstract implementation of a training.
        An implementation needs to override the methods `_train_batch` and `_eval_batch`.
        In addition, training parameters for all models need to be added to the
        `self.training_params` list, as this list is used to put models into the
        correct mode as well as to step their optimizers and learning rate schedulers.
        An illustrative subclass is sketched at the bottom of this module.

        This abstract class implements a common training procedure so that
        implementations can focus on the computations per batch instead of on
        iterating over the dataset, storing snapshots, etc.
    
    
        Parameters
        ----------
        epochs: int
            the number of epochs to train
        dataset: MuMapDataset
            the dataset to use for training
        batch_size: int
            the batch size used for training
        device: torch.device
            the device on which to perform computations (cpu or cuda)
        snapshot_dir: str
            the directory where snapshots are stored
        snapshot_epoch: int
            interval in epochs at which a snapshot of all models is stored
        early_stopping: int, optional
            if defined, training is stopped when the validation loss has not improved
            for this many epochs
        logger: Logger, optional
            optional logger to print results
        """

        def __init__(
            self,
            epochs: int,
            dataset: MuMapDataset,
            batch_size: int,
            device: torch.device,
            snapshot_dir: str,
            snapshot_epoch: int,
            early_stopping: Optional[int] = None,
            logger: Optional[Logger] = None,
        ):
            self.epochs = epochs
            self.batch_size = batch_size
            self.dataset = dataset
            self.device = device
    
            self.early_stopping = early_stopping

            self.snapshot_dir = snapshot_dir
            self.snapshot_epoch = snapshot_epoch

            self.logger = (
                logger if logger is not None else get_logger(name=self.__class__.__name__)
            )
    
            self.training_params: List[TrainingParams] = []
    
            # create one DataLoader per dataset split; dataset.copy(split_name) is
            # expected to yield the dataset restricted to the given split
            self.data_loaders = dict(
                [
                    (
                        split_name,
                        torch.utils.data.DataLoader(
                            dataset.copy(split_name),
                            batch_size=self.batch_size,
                            shuffle=True,
                            pin_memory=True,
                            num_workers=1,
                        ),
                    )
                    for split_name in ["train", "validation"]
                ]
            )
    
    
        def run(self) -> float:
            """
            Implementation of a training run.
    
            For each epoch:
              1. Train the model
              2. Evaluate the model on the validation split
              3. If applicable, store a snapshot
            The validation loss is tracked so that a snapshot achieving the minimal
            validation loss is always kept.

            :return: the minimal validation loss achieved during training
            """
    
            losses_val = [sys.maxsize]
    
            for epoch in range(1, self.epochs + 1):
                str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
                self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
    
                loss_train = self._train_epoch()
                self.logger.info(
                    f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
                )
                loss_val = self._eval_epoch()
                self.logger.info(
                    f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
                )
    
    
                if epoch % self.snapshot_epoch == 0:
                    self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")
    
                if loss_val < min(losses_val):
    
                    self.logger.info(
                        f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
                    )
                    self.store_snapshot("val_min")
    
                losses_val.append(loss_val)
    
                # number of epochs since the minimal validation loss; the loss of
                # the current epoch has already been appended above
                last_improvement = len(losses_val) - np.argmin(losses_val)
                if self.early_stopping and last_improvement > self.early_stopping:
                    self.logger.info(
                        f"Stop early because the last improvement was {last_improvement} epochs ago"
                    )
                    return min(losses_val)
    
    
                for param in self.training_params:
                    if param.lr_scheduler is not None:
                        param.lr_scheduler.step()
    
            return min(losses_val)
    
        def _after_train_batch(self):
            """
            Function called after the loss computation on a batch during training.
            It is responsible for stepping all optimizers.
            """
            for param in self.training_params:
                param.optimizer.step()
    
    
        def _train_epoch(self) -> float:
            """
            Implementation of the training in a single epoch.
    
            :return: a number representing the training loss
            """
            # activate gradients
            torch.set_grad_enabled(True)

            # set models into training mode
            for param in self.training_params:
                param.model.train()
    
    
            # iterate over all batches in the training dataset
            loss = 0.0
            data_loader = self.data_loaders["train"]
            for i, (inputs, targets) in enumerate(data_loader):
                print(
                    f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                    end="\r",
                )
    
    
                # move data to the device used for computations
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                # zero the gradients of all optimizers
                for param in self.training_params:
                    param.optimizer.zero_grad()
    
    
                # compute the loss of the current batch and accumulate it
                loss = loss + self._train_batch(inputs, targets)

                # step optimizers
                self._after_train_batch()
            return loss / len(data_loader)
    
    
        def _eval_epoch(self) -> float:
            """
            Implementation of the evaluation in a single epoch.
    
            :return: a number representing the validation loss
            """
            # deactivate gradients
            torch.set_grad_enabled(False)

            # set models into evaluation mode
            for param in self.training_params:
                param.model.eval()
    
            # iterate over all batches in the validation dataset
            loss = 0.0
            data_loader = self.data_loaders["validation"]
            for i, (inputs, targets) in enumerate(data_loader):
                print(
                    f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                    end="\r",
                )
    
    
                # move data to the device used for computations
                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                # compute the loss of the current batch and accumulate it
                loss = loss + self._eval_batch(inputs, targets)
            return loss / len(data_loader)
    
        def store_snapshot(self, prefix: str):
            """
            Store snapshots of all models.
    
    
            Parameters
            ----------
            prefix: str
                prefix for all stored snapshot files
            """
            for param in self.training_params:
                snapshot_file = os.path.join(
                    self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
                )
                self.logger.debug(f"Store snapshot at {snapshot_file}")
                torch.save(param.model.state_dict(), snapshot_file)
    
    
        def get_param_by_name(self, name: str) -> TrainingParams:
            """
            Get a training parameter by its name.
    
            Parameters
            ----------
            name: str
                the name of the training parameter to retrieve (case-insensitive)

            Returns
            -------
            TrainingParams
                the training parameter with the given name

            Raises
            ------
            ValueError
                if no training parameter with this name can be found
            """
    
            _param = list(
                filter(
                    lambda training_param: training_param.name.lower() == name.lower(),
                    self.training_params,
                )
            )
    
            if len(_param) == 0:
                raise ValueError(f"Cannot find training parameter with name {name}")
            return _param[0]
    
    
        def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
            """
            Implementation of training a single batch.
    
    
            Parameters
            ----------
            inputs: torch.Tensor
                batch of input data
            targets: torch.Tensor
                batch of target data
    
            Returns
            -------
            float
                a number representing the loss of this batch
            """
            return 0
    
        def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
            """
            Implementation of evaluating a single batch.
    
    
            Parameters
            ----------
            inputs: torch.Tensor
                batch of input data
            targets: torch.Tensor
                batch of target data
    
            Returns
            -------
            float
                a number representing the loss of this batch
            """
            return 0
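

    # ---------------------------------------------------------------------------
    # Illustrative sketch, not part of the original library: a minimal subclass
    # showing how AbstractTraining is intended to be used. A single model is
    # registered via TrainingParams and only the per-batch computations are
    # implemented. The MSE loss, the Adam optimizer and all hyperparameters are
    # assumptions chosen purely for illustration.
    class ExampleMSETraining(AbstractTraining):
        def __init__(self, model: torch.nn.Module, **kwargs):
            super().__init__(**kwargs)

            self.criterion = torch.nn.MSELoss()
            self.training_params.append(
                TrainingParams(
                    name="Model",
                    model=model.to(self.device),
                    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
                    lr_scheduler=None,
                )
            )

        def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
            # forward pass, loss computation and backward pass; the optimizers are
            # zeroed before and stepped after this call by AbstractTraining
            model = self.get_param_by_name("Model").model
            loss = self.criterion(model(inputs), targets)
            loss.backward()
            return loss.item()

        def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
            # forward pass and loss computation only; gradients are disabled by
            # AbstractTraining._eval_epoch
            model = self.get_param_by_name("Model").model
            return self.criterion(model(inputs), targets).item()

    # A hypothetical launch, assuming `dataset` is a MuMapDataset instance and
    # `model` a torch.nn.Module:
    #
    #   training = ExampleMSETraining(
    #       model=model, epochs=100, dataset=dataset, batch_size=8,
    #       device=torch.device("cuda"), snapshot_dir="snapshots/",
    #       snapshot_epoch=10, early_stopping=20,
    #   )
    #   training.run()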