Skip to content
Snippets Groups Projects
default.py 1.26 KiB
Newer Older
  • Learn to ignore specific revisions
  • Tamino Huxohl's avatar
    Tamino Huxohl committed
    
    
    class Training:
        """Generic epoch-based training loop.

        Only ``epochs`` is set by ``__init__``. The other attributes read below are
        not assigned anywhere in this class and must be provided before ``run`` is
        called (presumably by a subclass or the caller -- TODO confirm):

        - ``model``: callable with ``train()`` / ``eval()`` methods.
        - ``optimizer``: object with ``zero_grad()`` and ``step()`` methods.
        - ``loss``: callable ``loss(outputs, labels)``; ``.backward()`` and
          ``.item()`` are called on its result.
        - ``device``: target passed to ``inputs.to(...)`` / ``labels.to(...)``.
        - ``data_loader``: mapping with ``"train"`` and ``"validation"`` entries,
          each an iterable of ``(inputs, labels)`` batches.
        - ``lr_scheduler``: object with a ``step()`` method, stepped once per epoch.
        - ``snapshot_epoch``: snapshot interval in epochs; a falsy value (``None``
          or ``0``) disables snapshots.
        """

        def __init__(self, epochs):
            """Create a training loop.

            Args:
                epochs: number of epochs ``run`` performs (epochs 1..epochs).
            """
            self.epochs = epochs

        def run(self):
            """Train for ``self.epochs`` epochs.

            Each epoch performs one optimization pass over the training data,
            re-evaluates the training and validation data in eval mode, steps the
            learning-rate scheduler and stores a snapshot on every multiple of
            ``self.snapshot_epoch``.
            """
            for epoch in range(1, self.epochs + 1):
                self.run_epoch(self.data_loader["train"], phase="train")
                # Re-measure the training loss in eval mode, the same way the
                # validation loss is measured.
                loss_training = self.run_epoch(self.data_loader["train"], phase="eval")
                loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval")

                # TODO: log outputs (loss_training, loss_validation) and time.
                self.lr_scheduler.step()

                # Snapshot on every multiple of the interval; a falsy interval
                # disables snapshots instead of raising ZeroDivisionError.
                if self.snapshot_epoch and epoch % self.snapshot_epoch == 0:
                    self.store_snapshot(epoch)

        def run_epoch(self, data_loader, phase):
            """Run one pass over ``data_loader`` and return the accumulated loss.

            Args:
                data_loader: iterable of ``(inputs, labels)`` batches; both elements
                    must support ``.to(device)`` and ``inputs`` must support
                    ``.size(0)``.
                phase: ``"train"`` puts the model in train mode and updates it with
                    the optimizer; any other value puts it in eval mode and runs
                    without gradient tracking.

            Returns:
                Sum over batches of ``loss.item() / batch_size``; ``0`` when
                ``data_loader`` yields no batches.
            """
            import torch  # the file has no module-level import of torch

            is_training = phase == "train"
            if is_training:
                self.model.train()
            else:
                self.model.eval()

            epoch_loss = 0
            for inputs, labels in data_loader:
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                if is_training:
                    self.optimizer.zero_grad()
                with torch.set_grad_enabled(is_training):
                    outputs = self.model(inputs)
                    loss = self.loss(outputs, labels)

                    if is_training:
                        loss.backward()
                        self.optimizer.step()

                # NOTE(review): if ``self.loss`` already averages over the batch,
                # this divides by the batch size a second time -- confirm the
                # intended normalisation.
                epoch_loss += loss.item() / inputs.size(0)
            return epoch_loss

        def store_snapshot(self, epoch):
            """Store a snapshot for ``epoch``.

            Placeholder: currently does nothing.
            """
            pass