from dataclasses import dataclass
from logging import Logger
import os
import sys
from typing import List, Optional

import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger


@dataclass
class TrainingParams:
    """
    Parameters of a single model optimized during a training:
    a name, the model itself, its optimizer and an optional
    learning rate scheduler.
    """

    name: str
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]


class AbstractTraining:
    """
    Base class for trainings. It implements the epoch loop, validation,
    logging and the storage of snapshots, while the loss computation for
    a single batch is left to subclasses via _train_batch and _eval_batch.
    """

    def __init__(
        self,
        epochs: int,
        dataset: MuMapDataset,
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger: Optional[Logger],
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.dataset = dataset
        self.device = device
        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch

        self.logger = (
            logger if logger is not None else get_logger(name=self.__class__.__name__)
        )

        # Subclasses register the models they optimize here.
        self.training_params: List[TrainingParams] = []

        # One data loader per dataset split.
        self.data_loaders = {
            split_name: torch.utils.data.DataLoader(
                dataset.split_copy(split_name),
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=1,
            )
            for split_name in ["train", "validation"]
        }

    def run(self) -> float:
        """
        Run the training and return the minimal loss achieved on the
        validation split.
        """
        loss_val_min = sys.maxsize
        for epoch in range(1, self.epochs + 1):
            str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")

            loss_train = self._train_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
            )

            loss_val = self._eval_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
            )

            if loss_val < loss_val_min:
                loss_val_min = loss_val
                self.logger.info(
                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
                )
                self.store_snapshot("val_min")

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")

            for param in self.training_params:
                if param.lr_scheduler is not None:
                    param.lr_scheduler.step()

        return loss_val_min

    def _after_train_batch(self):
        """
        Function called after the loss computation on a batch during training.
        It is responsible for stepping all optimizers.
        """
        for param in self.training_params:
            param.optimizer.step()

    def _train_epoch(self) -> float:
        """
        Train all registered models for a single epoch and return the
        average training loss per batch.
        """
        torch.set_grad_enabled(True)
        for param in self.training_params:
            param.model.train()

        loss = 0.0
        data_loader = self.data_loaders["train"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            for param in self.training_params:
                param.optimizer.zero_grad()

            loss = loss + self._train_batch(inputs, targets)
            self._after_train_batch()

        return loss / len(data_loader)

    def _eval_epoch(self) -> float:
        """
        Evaluate all registered models on the validation split and return
        the average validation loss per batch.
        """
        torch.set_grad_enabled(False)
        for param in self.training_params:
            param.model.eval()

        loss = 0.0
        data_loader = self.data_loaders["validation"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            loss = loss + self._eval_batch(inputs, targets)

        return loss / len(data_loader)

    def store_snapshot(self, prefix: str):
        """
        Store a snapshot of every registered model under the given prefix.
        """
        for param in self.training_params:
            snapshot_file = os.path.join(
                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
            )
            self.logger.debug(f"Store snapshot at {snapshot_file}")
            torch.save(param.model.state_dict(), snapshot_file)

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """
        Compute the loss of a single training batch, including the backward
        pass, and return it as a float. To be implemented by subclasses.
        """
        return 0

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """
        Compute the loss of a single validation batch and return it as a
        float. To be implemented by subclasses.
        """
        return 0