Commit 7bf09ddd authored by Tamino Huxohl

add missing training lib module

parent e0d463db
from dataclasses import dataclass
import os
import sys
from typing import Dict, Optional

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from mu_map.dataset.default import MuMapDataset


@dataclass
class TrainingParams:
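    """
    Parameters for training a single model: a name used for snapshot files, the
    model itself, its optimizer, and an optional learning rate scheduler.
    """
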
    name: str
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]


class AbstractTraining:
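    """
    Base class implementing the training loop: it iterates over epochs, computes
    training and validation losses, stores model snapshots, and steps learning
    rate schedulers. Subclasses register their models via training_params and
    implement the per-batch logic in _train_batch and _eval_batch.
    """
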
    def __init__(
        self,
        epochs: int,
        dataset: MuMapDataset,
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger=None,
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.dataset = dataset
        self.device = device
        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch
        self.logger = logger
        self.training_params = []

        # NOTE: the commit does not show how the data loaders are built; as an
        # assumption, one loader per split is created here from the given dataset
        # (the original presumably uses separate train/validation subsets)
        self.data_loaders: Dict[str, DataLoader] = {
            split: DataLoader(
                self.dataset, batch_size=self.batch_size, shuffle=(split == "train")
            )
            for split in ["train", "validation"]
        }

    def run(self) -> float:
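        """
        Run training for the configured number of epochs and return the minimal
        validation loss encountered.
        """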
        loss_val_min = sys.maxsize
        for epoch in range(1, self.epochs + 1):
            str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")

            loss_train = self._train_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
            )

            loss_val = self._eval_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
            )

            if loss_val < loss_val_min:
                loss_val_min = loss_val
                self.logger.info(
                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
                )
                self.store_snapshot("val_min")

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")

            for param in self.training_params:
                if param.lr_scheduler is not None:
                    param.lr_scheduler.step()
        return loss_val_min

    def _train_epoch(self) -> float:
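        """
        Train all registered models for a single epoch and return the average
        loss per batch.
        """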
        torch.set_grad_enabled(True)
        for param in self.training_params:
            param.model.train()

        loss = 0.0
        data_loader = self.data_loaders["train"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            for param in self.training_params:
                param.optimizer.zero_grad()
            loss = loss + self._train_batch(inputs, targets)
            for param in self.training_params:
                param.optimizer.step()
        return loss / len(data_loader)

    def _eval_epoch(self) -> float:
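        """
        Evaluate all registered models on the validation split and return the
        average loss per batch.
        """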
        torch.set_grad_enabled(False)
        for param in self.training_params:
            param.model.eval()

        loss = 0.0
        data_loader = self.data_loaders["validation"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)
            loss = loss + self._eval_batch(inputs, targets)
        return loss / len(data_loader)

    def store_snapshot(self, prefix: str):
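        """
        Store the state dict of every registered model as
        <snapshot_dir>/<prefix>_<name>.pth.
        """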
        for param in self.training_params:
            snapshot_file = os.path.join(
                self.snapshot_dir, f"{prefix}_{param.name}.pth"
            )
            self.logger.debug(f"Store snapshot at {snapshot_file}")
            torch.save(param.model.state_dict(), snapshot_file)

    def _train_batch(self, inputs: Tensor, targets: Tensor) -> float:
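        """
        Perform the forward and backward pass for a single batch and return the
        batch loss. To be implemented by subclasses.
        """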
        return 0

    def _eval_batch(self, inputs: Tensor, targets: Tensor) -> float:
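        """
        Compute the loss for a single validation batch and return it. To be
        implemented by subclasses.
        """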
        return 0
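
A minimal usage sketch (not part of this commit): a hypothetical subclass, ExampleTraining, registers a single model via TrainingParams and implements the two batch hooks with an MSE loss. The import path mu_map.training.lib and the MuMapDataset constructor argument are assumptions, not taken from the repository.

import logging

import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.training.lib import AbstractTraining, TrainingParams  # assumed module path


class ExampleTraining(AbstractTraining):
    """Hypothetical training of a single model with an MSE loss."""

    def __init__(self, model: torch.nn.Module, **kwargs):
        super().__init__(**kwargs)
        self.criterion = torch.nn.MSELoss()
        self.training_params.append(
            TrainingParams(
                name="model",
                model=model.to(self.device),
                optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
                lr_scheduler=None,
            )
        )

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        # gradients are zeroed and the optimizer is stepped by the base class
        outputs = self.training_params[0].model(inputs)
        loss = self.criterion(outputs, targets)
        loss.backward()
        return loss.item()

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        # gradient computation is disabled by the base class during evaluation
        outputs = self.training_params[0].model(inputs)
        return self.criterion(outputs, targets).item()


if __name__ == "__main__":
    training = ExampleTraining(
        model=torch.nn.Conv3d(1, 1, kernel_size=3, padding=1),
        epochs=10,
        dataset=MuMapDataset("data/"),  # assumed constructor argument
        batch_size=4,
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        snapshot_dir="snapshots",
        snapshot_epoch=5,
        logger=logging.getLogger(__name__),
    )
    training.run()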