import os
import sys
from dataclasses import dataclass
from logging import Logger
from typing import List, Optional

import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger

@dataclass
class TrainingParams:
    """
    Parameters updated during a training: a named model together with its
    optimizer and an optional learning rate scheduler.
    """

    name: str
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]

class AbstractTraining:
    """
    Abstract base class for trainings. Subclasses register their models and
    optimizers as TrainingParams and implement _train_batch and _eval_batch.
    """

    def __init__(
        self,
        epochs: int,
        dataset: MuMapDataset,
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger: Optional[Logger],
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.dataset = dataset
        self.device = device
        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch
        self.logger = (
            logger if logger is not None else get_logger(name=self.__class__.__name__)
        )

        # Subclasses register the models they train here.
        self.training_params: List[TrainingParams] = []

        # One DataLoader per dataset split used by the training loop.
        self.data_loaders = {
            split_name: torch.utils.data.DataLoader(
                dataset.split_copy(split_name),
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=1,
            )
            for split_name in ["train", "validation"]
        }

    def run(self) -> float:
        """
        Run the training for the configured number of epochs.

        :return: the minimal validation loss over all epochs
        """
        loss_val_min = sys.maxsize
        for epoch in range(1, self.epochs + 1):
            str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")

            loss_train = self._train_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
            )

            loss_val = self._eval_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
            )

            # Keep a snapshot of the models with the lowest validation loss so far.
            if loss_val < loss_val_min:
                loss_val_min = loss_val
                self.logger.info(
                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
                )
                self.store_snapshot("val_min")

            # Store a regular snapshot every snapshot_epoch epochs.
            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")

            for param in self.training_params:
                if param.lr_scheduler is not None:
                    param.lr_scheduler.step()

        return loss_val_min

    def _after_train_batch(self):
        """
        Function called after the loss computation on a batch during training.
        It is responsible for stepping all optimizers.
        """
        for param in self.training_params:
            param.optimizer.step()

    def _train_epoch(self) -> float:
        """
        Train all registered models for a single epoch.

        :return: the training loss averaged over all batches
        """
        # Enable gradient computation and put all models into training mode.
        torch.set_grad_enabled(True)
        for param in self.training_params:
            param.model.train()

        loss = 0.0
        data_loader = self.data_loaders["train"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Reset all gradients before computing the loss of this batch.
            for param in self.training_params:
                param.optimizer.zero_grad()

            loss = loss + self._train_batch(inputs, targets)
            self._after_train_batch()
        return loss / len(data_loader)

    def _eval_epoch(self) -> float:
        """
        Evaluate all registered models on the validation split.

        :return: the validation loss averaged over all batches
        """
        # Disable gradient computation and put all models into evaluation mode.
        torch.set_grad_enabled(False)
        for param in self.training_params:
            param.model.eval()

        loss = 0.0
        data_loader = self.data_loaders["validation"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            loss = loss + self._eval_batch(inputs, targets)
        return loss / len(data_loader)

    def store_snapshot(self, prefix: str):
        """
        Store a snapshot of every registered model.

        :param prefix: prefix of the snapshot file name, e.g. "val_min" or a zero-padded epoch
        """
        for param in self.training_params:
            snapshot_file = os.path.join(
                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
            )
            self.logger.debug(f"Store snapshot at {snapshot_file}")
            torch.save(param.model.state_dict(), snapshot_file)

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        # Implemented by subclasses: compute and backpropagate the loss of a training batch.
        return 0

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        # Implemented by subclasses: compute the loss of a validation batch.
        return 0
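

# Hypothetical usage sketch: a minimal subclass assuming a single network trained
# with an MSE loss. ExampleMSETraining and all parameter values below are
# illustrative assumptions, not part of the mu_map library.
class ExampleMSETraining(AbstractTraining):
    def __init__(self, model: torch.nn.Module, **kwargs):
        super().__init__(**kwargs)
        self.criterion = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        self.training_params.append(
            TrainingParams(
                name="Model",
                model=model.to(self.device),
                optimizer=optimizer,
                lr_scheduler=None,
            )
        )

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        outputs = self.training_params[0].model(inputs)
        loss = self.criterion(outputs, targets)
        loss.backward()
        return loss.item()

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        outputs = self.training_params[0].model(inputs)
        return self.criterion(outputs, targets).item()


# Typical call site (assuming a prepared MuMapDataset `dataset` and a model `unet`):
#   training = ExampleMSETraining(
#       model=unet, epochs=100, dataset=dataset, batch_size=8,
#       device=torch.device("cuda"), snapshot_dir="snapshots",
#       snapshot_epoch=10, logger=None,
#   )
#   loss_val_min = training.run()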