Commit 68e49293 authored by Tamino Huxohl

small changes to abstract training: optim.step can be overwritten and the logger is truly optional

parent e3597787
 from dataclasses import dataclass
+from logging import Logger
 import os
 from typing import Dict, Optional
 import sys
@@ -7,6 +8,7 @@ import torch
 from torch import Tensor

 from mu_map.dataset.default import MuMapDataset
+from mu_map.logging import get_logger


 @dataclass
@@ -26,7 +28,7 @@ class AbstractTraining:
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
-        logger,  # TODO make optional?
+        logger: Optional[Logger],
     ):
         self.epochs = epochs
         self.batch_size = batch_size
@@ -36,7 +38,7 @@ class AbstractTraining:
         self.snapshot_dir = snapshot_dir
         self.snapshot_epoch = snapshot_epoch
-        self.logger = logger
+        self.logger = logger if logger is not None else get_logger(name=self.__class__.__name__)

         self.training_params = []
         self.data_loaders = dict(
@@ -85,6 +87,14 @@ class AbstractTraining:
                 param.lr_scheduler.step()
         return loss_val_min

+    def _after_train_batch(self):
+        """
+        Function called after the loss computation on a batch during training.
+        It is responsible for stepping all optimizers.
+        """
+        for param in self.training_params:
+            param.optimizer.step()
+
     def _train_epoch(self):
         torch.set_grad_enabled(True)
         for param in self.training_params:
@@ -106,8 +116,7 @@ class AbstractTraining:
             loss = loss + self._train_batch(inputs, targets)

-            for param in self.training_params:
-                param.optimizer.step()
+            self._after_train_batch()

         return loss / len(data_loader)

     def _eval_epoch(self):
@@ -132,7 +141,7 @@ class AbstractTraining:
     def store_snapshot(self, prefix: str):
         for param in self.training_params:
             snapshot_file = os.path.join(
-                self.snapshot_dir, f"{prefix}_{param.name}.pth"
+                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
             )
             self.logger.debug(f"Store snapshot at {snapshot_file}")
             torch.save(param.model.state_dict(), snapshot_file)
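
With this change, a subclass can customize how optimizers are stepped by overriding the new _after_train_batch hook instead of re-implementing _train_epoch, and the logger argument may be passed as None. The following is a minimal sketch of both points and is not part of the commit: the subclass name GradientClippingTraining, the clipping value, and the import path are hypothetical; only AbstractTraining, training_params, and _after_train_batch come from the code above.

import torch

# hypothetical import path; the commit diff does not show the module's file name
from mu_map.training.lib import AbstractTraining


class GradientClippingTraining(AbstractTraining):
    """Hypothetical subclass that clips gradients before the optimizers step."""

    def _after_train_batch(self):
        # clip the gradients of every registered model ...
        for param in self.training_params:
            torch.nn.utils.clip_grad_norm_(param.model.parameters(), max_norm=1.0)
        # ... then fall back to the default behaviour, which calls
        # optimizer.step() for every entry in self.training_params
        super()._after_train_batch()

Since logger is now Optional, such a training can also be constructed with logger=None, in which case AbstractTraining falls back to get_logger(name=self.__class__.__name__).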