Skip to content
Snippets Groups Projects
Commit 68e49293 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

small changes to abstract training: optim.step can be overwritten and logger is truly optional

parent e3597787
No related branches found
No related tags found
No related merge requests found
from dataclasses import dataclass
from logging import Logger
import os
from typing import Dict, Optional
import sys
......@@ -7,6 +8,7 @@ import torch
from torch import Tensor
from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger
@dataclass
......@@ -26,7 +28,7 @@ class AbstractTraining:
device: torch.device,
snapshot_dir: str,
snapshot_epoch: int,
logger, # TODO make optional?
logger: Optional[Logger],
):
self.epochs = epochs
self.batch_size = batch_size
......@@ -36,7 +38,7 @@ class AbstractTraining:
self.snapshot_dir = snapshot_dir
self.snapshot_epoch = snapshot_epoch
self.logger = logger
self.logger = logger if logger is not None else get_logger(name=self.__class__.__name__)
self.training_params = []
self.data_loaders = dict(
......@@ -85,6 +87,14 @@ class AbstractTraining:
param.lr_scheduler.step()
return loss_val_min
def _after_train_batch(self):
"""
Function called after the loss computation on a batch during training.
It is responsible for stepping all optimizers.
"""
for param in self.training_params:
param.optimizer.step()
def _train_epoch(self):
torch.set_grad_enabled(True)
for param in self.training_params:
......@@ -106,8 +116,7 @@ class AbstractTraining:
loss = loss + self._train_batch(inputs, targets)
for param in self.training_params:
param.optimizer.step()
self._after_train_batch()
return loss / len(data_loader)
def _eval_epoch(self):
......@@ -132,7 +141,7 @@ class AbstractTraining:
def store_snapshot(self, prefix: str):
    """
    Store a snapshot of the model weights of all training parameters.

    One ``.pth`` file per training parameter is written into the snapshot
    directory, named ``<prefix>_<param.name lowercased>.pth``. Only the
    model's ``state_dict`` is saved, not the optimizer state.

    :param prefix: prefix for each snapshot filename (e.g. a zero-padded
        epoch number — presumably; confirm against callers)
    """
    for param in self.training_params:
        # lower-case the parameter name so snapshot filenames are uniform
        snapshot_file = os.path.join(
            self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
        )
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(param.model.state_dict(), snapshot_file)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment