From 68e49293753c0c9108a37c33ba056613b7b50152 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Wed, 4 Jan 2023 11:17:33 +0100
Subject: [PATCH] small changes to abstract training: optim.step can be
 overridden and logger is truly optional

---
 mu_map/training/lib.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/mu_map/training/lib.py b/mu_map/training/lib.py
index e820c07..eee9a09 100644
--- a/mu_map/training/lib.py
+++ b/mu_map/training/lib.py
@@ -1,4 +1,5 @@
 from dataclasses import dataclass
+from logging import Logger
 import os
 from typing import Dict, Optional
 import sys
@@ -7,6 +8,7 @@ import torch
 from torch import Tensor
 
 from mu_map.dataset.default import MuMapDataset
+from mu_map.logging import get_logger
 
 
 @dataclass
@@ -26,7 +28,7 @@ class AbstractTraining:
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
-        logger,  # TODO make optional?
+        logger: Optional[Logger],
     ):
         self.epochs = epochs
         self.batch_size = batch_size
@@ -36,7 +38,7 @@ class AbstractTraining:
         self.snapshot_dir = snapshot_dir
         self.snapshot_epoch = snapshot_epoch
 
-        self.logger = logger
+        self.logger = logger if logger is not None else get_logger(name=self.__class__.__name__)
 
         self.training_params = []
         self.data_loaders = dict(
@@ -85,6 +87,14 @@ class AbstractTraining:
                 param.lr_scheduler.step()
         return loss_val_min
 
+    def _after_train_batch(self):
+        """
+        Function called after the loss computation on a batch during training.
+        It is responsible for stepping all optimizers.
+        """
+        for param in self.training_params:
+            param.optimizer.step()
+
     def _train_epoch(self):
         torch.set_grad_enabled(True)
         for param in self.training_params:
@@ -106,8 +116,7 @@ class AbstractTraining:
 
             loss = loss + self._train_batch(inputs, targets)
 
-            for param in self.training_params:
-                param.optimizer.step()
+            self._after_train_batch()
         return loss / len(data_loader)
 
     def _eval_epoch(self):
@@ -132,7 +141,7 @@ class AbstractTraining:
     def store_snapshot(self, prefix: str):
         for param in self.training_params:
             snapshot_file = os.path.join(
-                self.snapshot_dir, f"{prefix}_{param.name}.pth"
+                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
             )
             self.logger.debug(f"Store snapshot at {snapshot_file}")
             torch.save(param.model.state_dict(), snapshot_file)
-- 
GitLab
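
The new _after_train_batch hook is the extension point the subject line refers to: a subclass can replace or extend how optimizers are stepped after each training batch instead of getting an unconditional optimizer.step(). A minimal sketch, assuming a hypothetical subclass that clips gradients before delegating to the default stepping behaviour; the class name ClippedTraining and the max_norm value are illustrative and not part of the repository:

# Illustrative only: overriding the _after_train_batch hook added by this patch.
# ClippedTraining and max_norm=1.0 are assumptions for demonstration.
import torch

from mu_map.training.lib import AbstractTraining


class ClippedTraining(AbstractTraining):
    def _after_train_batch(self):
        # Clip gradients of every registered model before stepping,
        # then delegate to the default behaviour (optimizer.step()).
        for param in self.training_params:
            torch.nn.utils.clip_grad_norm_(param.model.parameters(), max_norm=1.0)
        super()._after_train_batch()

Because __init__ now falls back to get_logger(name=self.__class__.__name__) when logger is None, such a subclass can be instantiated without wiring up a logger; an explicitly passed Logger still takes precedence.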