class Training:
    """Epoch-based training loop around a model, optimizer and LR scheduler.

    Only ``epochs`` is set by the constructor. The other collaborators are
    read from the instance and must be assigned before :meth:`run` is called
    (presumably by the caller or a subclass — TODO confirm):

    - ``model``: switched between ``train()`` / ``eval()`` and called on each
      input batch.
    - ``data_loader``: mapping with ``"train"`` and ``"validation"`` entries,
      each an iterable of ``(inputs, labels)`` batches.
    - ``device``: passed to ``.to()`` for inputs and labels.
    - ``loss``: criterion called as ``loss(outputs, labels)``.
    - ``optimizer``: zeroed and stepped once per training batch.
    - ``lr_scheduler``: stepped once per epoch.
    - ``snapshot_epoch``: snapshot interval in epochs; a falsy value (0/None)
      disables snapshots.
    """

    def __init__(self, epochs):
        """Store the number of epochs that :meth:`run` will execute."""
        self.epochs = epochs

    def run(self):
        """Train for ``self.epochs`` epochs (numbered from 1).

        Each epoch runs one training pass over the training data, then
        evaluates the training and validation sets with gradients disabled,
        steps the LR scheduler, and stores a snapshot on every epoch that is
        a multiple of ``snapshot_epoch``.
        """
        for epoch in range(1, self.epochs + 1):
            self.run_epoch(self.data_loader["train"], phase="train")
            # Re-measure the training loss in eval mode so it is computed the
            # same way as the validation loss.
            loss_training = self.run_epoch(self.data_loader["train"], phase="eval")
            loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval")
            # TODO: log loss_training / loss_validation and epoch timing.
            self.lr_scheduler.step()
            # Guard against a disabled (0/None) interval before taking the modulo.
            if self.snapshot_epoch and epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

    def run_epoch(self, data_loader, phase):
        """Run one pass over ``data_loader`` and return the accumulated loss.

        Args:
            data_loader: iterable of ``(inputs, labels)`` batches; both are
                moved to ``self.device`` before use.
            phase: ``"train"`` puts the model in train mode, enables gradients
                and updates the parameters via ``self.optimizer``; any other
                value puts the model in eval mode and only computes the loss.

        Returns:
            Sum over batches of ``loss.item() / batch_size``; 0 for an empty
            loader.
        """
        training = phase == "train"
        if training:
            self.model.train()
        else:
            self.model.eval()

        epoch_loss = 0
        for inputs, labels in data_loader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
            if training:
                # Clear gradients left over from the previous batch.
                self.optimizer.zero_grad()
            with torch.set_grad_enabled(training):
                outputs = self.model(inputs)
                loss = self.loss(outputs, labels)
                if training:
                    loss.backward()
                    self.optimizer.step()
            # NOTE(review): if self.loss already averages over the batch, this
            # divides by the batch size a second time. Confirm the intended
            # normalization, e.g. sum loss.item() * batch_size and divide by
            # the total sample count.
            epoch_loss += loss.item() / inputs.size(0)
        return epoch_loss

    def store_snapshot(self, epoch):
        """Persist a snapshot for ``epoch``; currently a no-op placeholder."""
        pass