from dataclasses import dataclass
import os
import sys
from typing import Dict, Optional

import torch
from torch import Tensor

from mu_map.dataset.default import MuMapDataset


@dataclass
class TrainingParams:
    """Bundle of the objects trained/updated for a single model.

    Attributes:
        name: identifier used in snapshot filenames.
        model: the network being optimized.
        optimizer: optimizer updating the model's parameters.
        lr_scheduler: optional LR scheduler, stepped once per epoch.
    """

    name: str
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]


class AbstractTraining:
    """Skeleton of a training loop over a MuMapDataset.

    Subclasses register one `TrainingParams` per trained model in
    `self.training_params` and implement `_train_batch` / `_eval_batch`.
    """

    def __init__(
        self,
        epochs: int,
        dataset: MuMapDataset,
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger,  # TODO make optional?
    ):
        """
        Args:
            epochs: number of epochs to run.
            dataset: dataset providing "train" and "validation" splits
                via `split_copy`.
            batch_size: batch size used by both data loaders.
            device: device that inputs/targets are moved to.
            snapshot_dir: directory where model snapshots are written.
            snapshot_epoch: store a snapshot every this many epochs.
            logger: logger used for progress messages.
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.dataset = dataset
        self.device = device
        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch
        self.logger = logger

        # Filled by subclasses: one TrainingParams per trained model.
        self.training_params = []

        self.data_loaders = {
            split_name: torch.utils.data.DataLoader(
                dataset.split_copy(split_name),
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=1,
            )
            for split_name in ["train", "validation"]
        }

    def run(self) -> float:
        """Run the full training loop.

        Returns:
            The minimal validation loss observed over all epochs.
        """
        loss_val_min = sys.maxsize
        for epoch in range(1, self.epochs + 1):
            str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")

            loss_train = self._train_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
            )

            loss_val = self._eval_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
            )

            if loss_val < loss_val_min:
                loss_val_min = loss_val
                self.logger.info(
                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
                )
                self.store_snapshot("val_min")

            if epoch % self.snapshot_epoch == 0:
                # Zero-pad the epoch number to the width of the final epoch,
                # e.g. epoch 7 of 100 -> "007".
                # BUGFIX: this previously called the non-existent
                # `self._store_snapshot` and used the invalid format spec
                # "0d{width}" (ValueError); the correct spec is "0{width}d".
                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")

            for param in self.training_params:
                if param.lr_scheduler is not None:
                    param.lr_scheduler.step()
        return loss_val_min

    def _train_epoch(self) -> float:
        """Train all registered models for one epoch.

        Returns:
            The training loss accumulated by `_train_batch`, averaged
            over batches.
        """
        torch.set_grad_enabled(True)
        for param in self.training_params:
            param.model.train()

        loss = 0.0
        data_loader = self.data_loaders["train"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            for param in self.training_params:
                param.optimizer.zero_grad()

            loss = loss + self._train_batch(inputs, targets)

            for param in self.training_params:
                param.optimizer.step()
        return loss / len(data_loader)

    def _eval_epoch(self, phase: str = "validation") -> float:
        """Evaluate all registered models for one epoch.

        Args:
            phase: which data-loader split to evaluate. Defaults to
                "validation" — BUGFIX: the parameter previously had no
                default (making the no-arg call in `run()` a TypeError)
                and was never used to select the loader.

        Returns:
            The loss accumulated by `_eval_batch`, averaged over batches.
        """
        torch.set_grad_enabled(False)
        # BUGFIX: previously iterated `self.models`, an attribute that is
        # never created; models are held in `self.training_params`.
        for param in self.training_params:
            param.model.eval()

        loss = 0.0
        data_loader = self.data_loaders[phase]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            loss = loss + self._eval_batch(inputs, targets)
        return loss / len(data_loader)

    def store_snapshot(self, prefix: str):
        """Save the state dict of every registered model.

        Args:
            prefix: filename prefix; each file is named
                "<prefix>_<param.name>.pth" inside `self.snapshot_dir`.
        """
        for param in self.training_params:
            snapshot_file = os.path.join(
                self.snapshot_dir, f"{prefix}_{param.name}.pth"
            )
            self.logger.debug(f"Store snapshot at {snapshot_file}")
            torch.save(param.model.state_dict(), snapshot_file)

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """Train on a single batch; subclasses override.

        Returns:
            The batch's training loss (stub returns 0).
        """
        return 0

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """Evaluate a single batch; subclasses override.

        Returns:
            The batch's evaluation loss (stub returns 0).
        """
        return 0