Commit 6029718e authored by Tamino Huxohl

implement early stopping in abstract training

parent 37279fe9
@@ -7,6 +7,7 @@ import os
 from typing import Dict, List, Optional
 import sys
+import numpy as np
 import torch
 from torch import Tensor
@@ -55,6 +56,7 @@ class AbstractTraining:
         dataset: MuMapDataset,
         batch_size: int,
         device: torch.device,
+        early_stopping: Optional[int],
         snapshot_dir: str,
         snapshot_epoch: int,
         logger: Optional[Logger],
@@ -63,6 +65,7 @@
         self.batch_size = batch_size
         self.dataset = dataset
         self.device = device
+        self.early_stopping = early_stopping
         self.snapshot_dir = snapshot_dir
         self.snapshot_epoch = snapshot_epoch
@@ -100,6 +103,8 @@
         which achieves a minimal loss.
         """
-        loss_val_min = sys.maxsize
+        losses_val = [sys.maxsize]
         for epoch in range(1, self.epochs + 1):
             str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
             self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
@@ -113,20 +118,27 @@
                 f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
             )
-            if loss_val < loss_val_min:
-                loss_val_min = loss_val
-            if epoch % self.snapshot_epoch == 0:
-                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")
+            if loss_val < min(losses_val):
+                self.logger.info(
+                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
+                )
+                self.store_snapshot("val_min")
+            losses_val.append(loss_val)
+            if epoch % self.snapshot_epoch == 0:
+                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")
+            last_improvement = len(losses_val) - np.argmin(losses_val)
+            if self.early_stopping and last_improvement > self.early_stopping:
+                self.logger.info(
+                    f"Stop early because the last improvement was {last_improvement} epochs ago"
+                )
+                return min(losses_val)
             for param in self.training_params:
                 if param.lr_scheduler is not None:
                     param.lr_scheduler.step()
-        return loss_val_min
+        return min(losses_val)

     def _after_train_batch(self):
         """
@@ -229,7 +241,12 @@
         ValueError
             if parameters cannot be found
         """
-        _param = list(filter(lambda training_param: training_param.name.lower() == name.lower(), self.training_params))
+        _param = list(
+            filter(
+                lambda training_param: training_param.name.lower() == name.lower(),
+                self.training_params,
+            )
+        )
         if len(_param) == 0:
             raise ValueError(f"Cannot find training_parameter with name {name}")
         return _param[0]
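
The last hunk only reformats the lookup; behavior is unchanged: self.training_params is filtered by a case-insensitive name match, and a ValueError is raised when nothing matches. A self-contained sketch of the same pattern, using a hypothetical TrainingParams stand-in, an illustrative get_param wrapper, and made-up parameter names:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class TrainingParams:  # hypothetical stand-in: only the name attribute matters here
        name: str

    training_params: List[TrainingParams] = [
        TrainingParams("Generator"),  # made-up names for illustration
        TrainingParams("Discriminator"),
    ]

    def get_param(name: str) -> TrainingParams:
        # Case-insensitive, name-based lookup, mirroring the reformatted filter.
        _param = list(
            filter(
                lambda training_param: training_param.name.lower() == name.lower(),
                training_params,
            )
        )
        if len(_param) == 0:
            raise ValueError(f"Cannot find training_parameter with name {name}")
        return _param[0]

    print(get_param("GENERATOR").name)  # matching ignores case -> Generator

Because the comparison is on name.lower(), get_param("generator") and get_param("GENERATOR") return the same entry; only a name that matches no parameter raises.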