From 7bf09ddd13a8e4ad6db15eaa5b1e8ae01d83285e Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 15 Dec 2022 10:11:02 +0100
Subject: [PATCH] add missing training lib module

---
 mu_map/training/lib.py | 145 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 145 insertions(+)
 create mode 100644 mu_map/training/lib.py

diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
new file mode 100644
index 0000000..00811cd
--- /dev/null
+++ b/mu_map/training/lib.py
@@ -0,0 +1,145 @@
+from dataclasses import dataclass
+import logging
+import os
+from typing import Optional
+
+import torch
+from torch.utils.data import DataLoader
+
+from mu_map.dataset.default import MuMapDataset
+
+
+@dataclass
+class TrainingParams:
+    name: str
+    model: torch.nn.Module
+    optimizer: torch.optim.Optimizer
+    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
+
+
+class AbstractTraining:
+    def __init__(
+        self,
+        epochs: int,
+        dataset: MuMapDataset,
+        batch_size: int,
+        device: torch.device,
+        snapshot_dir: str,
+        snapshot_epoch: int,
+        logger=None,
+    ):
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.dataset = dataset
+        self.device = device
+
+        self.snapshot_dir = snapshot_dir
+        self.snapshot_epoch = snapshot_epoch
+
+        # fall back to a module-level logger so the logging calls below
+        # do not fail when no logger is passed
+        self.logger = logger if logger is not None else logging.getLogger(__name__)
+
+        # models, optimizers and LR schedulers registered by subclasses
+        self.training_params = []
+
+        # NOTE: how train/validation splits are derived from MuMapDataset is
+        # not shown in this patch; both loaders wrap the given dataset here
+        self.data_loaders = {
+            split: DataLoader(
+                dataset, batch_size=batch_size, shuffle=(split == "train")
+            )
+            for split in ["train", "validation"]
+        }
+
+    def run(self) -> float:
+        loss_val_min = float("inf")
+        for epoch in range(1, self.epochs + 1):
+            str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
+            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
+
+            loss_train = self._train_epoch()
+            self.logger.info(
+                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
+            )
+            loss_val = self._eval_epoch()
+            self.logger.info(
+                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
+            )
+
+            if loss_val < loss_val_min:
+                loss_val_min = loss_val
+                self.logger.info(
+                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
+                )
+                self.store_snapshot("val_min")
+
+            if epoch % self.snapshot_epoch == 0:
+                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")
+
+            for param in self.training_params:
+                if param.lr_scheduler is not None:
+                    param.lr_scheduler.step()
+        return loss_val_min
+
+    def _train_epoch(self) -> float:
+        torch.set_grad_enabled(True)
+        for param in self.training_params:
+            param.model.train()
+
+        loss = 0.0
+        data_loader = self.data_loaders["train"]
+        for i, (inputs, targets) in enumerate(data_loader):
+            print(
+                f"Batch {str(i + 1):>{len(str(len(data_loader)))}}/{len(data_loader)}",
+                end="\r",
+            )
+
+            inputs = inputs.to(self.device)
+            targets = targets.to(self.device)
+
+            for param in self.training_params:
+                param.optimizer.zero_grad()
+
+            loss = loss + self._train_batch(inputs, targets)
+
+            for param in self.training_params:
+                param.optimizer.step()
+        return loss / len(data_loader)
+
+    def _eval_epoch(self) -> float:
+        torch.set_grad_enabled(False)
+        for param in self.training_params:
+            param.model.eval()
+
+        loss = 0.0
+        data_loader = self.data_loaders["validation"]
+        for i, (inputs, targets) in enumerate(data_loader):
+            print(
+                f"Batch {str(i + 1):>{len(str(len(data_loader)))}}/{len(data_loader)}",
+                end="\r",
+            )
+
+            inputs = inputs.to(self.device)
+            targets = targets.to(self.device)
+
+            loss = loss + self._eval_batch(inputs, targets)
+        return loss / len(data_loader)
+
+    def store_snapshot(self, prefix: str):
+        for param in self.training_params:
+            snapshot_file = os.path.join(
+                self.snapshot_dir, f"{prefix}_{param.name}.pth"
+            )
+            self.logger.debug(f"Store snapshot at {snapshot_file}")
+            torch.save(param.model.state_dict(), snapshot_file)
+
+    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
+        # implemented by subclasses: compute the loss for a batch,
+        # backpropagate it and return it as a float
+        return 0
+
+    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
+        # implemented by subclasses: compute and return the validation
+        # loss for a batch as a float
+        return 0
--
GitLab
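
A minimal usage sketch for the module above (not part of the patch itself): a subclass registers its model through TrainingParams so that AbstractTraining can step the optimizer and store snapshots, and implements the two batch hooks. The subclass name, loss function and optimizer settings are hypothetical and only illustrate the contract that _train_batch/_eval_batch return the batch loss as a float.

    import torch

    from mu_map.training.lib import AbstractTraining, TrainingParams


    class ExampleTraining(AbstractTraining):
        def __init__(self, model: torch.nn.Module, **kwargs):
            super().__init__(**kwargs)
            self.model = model.to(self.device)
            # assumed loss function, chosen only for illustration
            self.criterion = torch.nn.MSELoss()
            optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3)
            # registering the model lets the base class call zero_grad/step
            # and include the model in snapshots
            self.training_params.append(
                TrainingParams(
                    name="model",
                    model=self.model,
                    optimizer=optimizer,
                    lr_scheduler=None,
                )
            )

        def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
            # zero_grad and step are driven by _train_epoch; only the forward
            # pass, loss computation and backward pass happen here
            loss = self.criterion(self.model(inputs), targets)
            loss.backward()
            return loss.item()

        def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
            return self.criterion(self.model(inputs), targets).item()

Keeping models in a list of TrainingParams rather than as a single attribute is what allows the same loop to drive multi-network setups, e.g. a generator and a discriminator with separate optimizers and schedulers.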