import torch


class Training:

    def __init__(self, model, loss, optimizer, lr_scheduler, data_loader,
                 device, epochs, snapshot_epoch):
        # data_loader is expected to be a dict mapping the phase names
        # "train" and "validation" to their DataLoader instances
        self.model = model.to(device)
        self.loss = loss
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.data_loader = data_loader
        self.device = device
        self.epochs = epochs
        self.snapshot_epoch = snapshot_epoch

    def run(self):
        for epoch in range(1, self.epochs + 1):
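            # one optimisation pass over the training data, then two
            # gradient-free passes to measure the post-update losses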
            self.run_epoch(self.data_loader["train"], phase="train")
            loss_training = self.run_epoch(self.data_loader["train"], phase="eval")
            loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval")

            # ToDo: log outputs and time
            self.lr_scheduler.step()

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

    def run_epoch(self, data_loader, phase):
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()

        epoch_loss = 0
        for inputs, labels in data_loader:
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
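            # track gradients only during the training phase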
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(inputs)
                loss = self.loss(outputs, labels)

                if phase == "train":
                    loss.backward()
                    self.optimizer.step()

            # accumulate the batch loss weighted by the batch size, assuming
            # the criterion returns the mean loss over the batch
            epoch_loss += loss.item() * inputs.size(0)
        return epoch_loss / len(data_loader.dataset)

    def store_snapshot(self, epoch):
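        # placeholder: persist a checkpoint here, e.g. the model and optimizer
        # state_dicts via torch.save, keyed by the epoch number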
        pass
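

# Usage sketch (not part of the class above): wires the trainer to a tiny
# synthetic regression task. The model, criterion, optimizer, scheduler, and
# snapshot interval below are illustrative assumptions, not fixed by the
# class itself; any nn.Module / loss / optimizer combination works the same way.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # random data split into a training and a validation set
    features, targets = torch.randn(256, 10), torch.randn(256, 1)
    train_set = TensorDataset(features[:200], targets[:200])
    val_set = TensorDataset(features[200:], targets[200:])
    data_loader = {
        "train": DataLoader(train_set, batch_size=32, shuffle=True),
        "validation": DataLoader(val_set, batch_size=32),
    }

    model = nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)

    training = Training(
        model=model,
        loss=nn.MSELoss(),
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        data_loader=data_loader,
        device=device,
        epochs=10,
        snapshot_epoch=5,
    )
    training.run()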