import os

import torch

class Training:
    """Train a model with Adam and a step-decayed learning rate.

    Each epoch:
      1. runs one optimisation pass over ``data_loaders["train"]``,
      2. evaluates the loss (gradients disabled) on ``data_loaders["train"]``
         and ``data_loaders["validation"]`` and logs both,
      3. steps the learning-rate scheduler,
      4. every ``snapshot_epoch`` epochs, saves the model ``state_dict`` to
         ``snapshot_dir``.
    """

    def __init__(self, model, data_loaders, epochs, logger, *, device="cpu",
                 snapshot_dir="tmp", snapshot_epoch=5, lr=0.1,
                 lr_decay_factor=0.5, lr_decay_epoch=1):
        """Set up optimizer, scheduler and loss for training ``model``.

        Args:
            model: ``torch.nn.Module`` to train; it is moved to ``device``.
            data_loaders: mapping with keys ``"train"`` and ``"validation"``.
                Each value must support ``len()`` and yield ``(inputs, labels)``
                pairs that have a ``.to(device)`` method (e.g. tensors).
            epochs: number of epochs to run.
            logger: logger used for progress messages (``debug``/``info``).
            device: device (string or ``torch.device``) for model and batches.
            snapshot_dir: directory for model snapshots; created on demand.
            snapshot_epoch: store a snapshot every ``snapshot_epoch`` epochs.
            lr: initial learning rate of the Adam optimizer.
            lr_decay_factor: multiplicative decay (``gamma``) of ``StepLR``.
            lr_decay_epoch: number of epochs between two decays (``step_size``).
        """
        self.device = torch.device(device)
        # Move the model before building the optimizer so it tracks the
        # parameters that actually live on the target device.
        self.model = model.to(self.device)
        self.data_loaders = data_loaders
        self.epochs = epochs
        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch
        self.loss_func = torch.nn.MSELoss()

        # Previously tried: lr=1e-3 with lr_decay_factor=0.99.
        self.lr = lr
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_epoch = lr_decay_epoch

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )

        self.logger = logger

    def run(self):
        """Run all epochs: train, evaluate, decay the LR and store snapshots."""
        width = len(str(self.epochs))
        for epoch in range(1, self.epochs + 1):
            progress = f"{str(epoch):>{width}}/{self.epochs}"
            self.logger.debug(f"Run epoch {progress} ...")
            self._run_epoch(self.data_loaders["train"], phase="train")

            loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
            self.logger.info(f"Epoch {progress} - Loss TRAIN: {loss_training:.4f}")
            loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
            self.logger.info(f"Epoch {progress} - Loss VAL: {loss_validation:.4f}")

            # TODO: log outputs and time
            previous_lr = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            # Scientific notation: a fixed 4-decimal format collapses to 0.0000
            # after a few decays.
            self.logger.debug(
                f"Update learning rate from {previous_lr:.4e} to {self.lr_scheduler.get_last_lr()[0]:.4e}"
            )

            # Snapshot on every multiple of snapshot_epoch.
            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(f"Finished epoch {progress}")

    def _run_epoch(self, data_loader, phase):
        """Run one pass over ``data_loader`` and return the mean per-sample loss.

        Args:
            data_loader: iterable supporting ``len()`` that yields
                ``(inputs, labels)`` pairs.
            phase: ``"train"`` updates the model weights; any other value only
                evaluates, with gradients disabled.

        Returns:
            The loss averaged over all samples of the pass, or ``0.0`` if the
            loader yields no samples.
        """
        self.logger.debug(f"Run epoch in phase {phase}")
        is_training = phase == "train"
        # Module.train(False) is equivalent to Module.eval().
        self.model.train(is_training)

        n_batches = len(data_loader)
        width = len(str(n_batches))
        loss_sum = 0.0
        n_samples = 0
        for i, (inputs, labels) in enumerate(data_loader):
            print(f"Batch {str(i + 1):>{width}}/{n_batches}", end="\r")
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            with torch.set_grad_enabled(is_training):
                outputs = self.model(inputs)
                loss = self.loss_func(outputs, labels)

                if is_training:
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

            # MSELoss already averages over the batch (default reduction
            # "mean"); weight by batch size so the epoch value is a per-sample
            # mean that is independent of how samples are split into batches.
            batch_size = inputs.shape[0]
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

        if n_samples == 0:
            return 0.0
        return loss_sum / n_samples

    def store_snapshot(self, epoch):
        """Save the model ``state_dict`` as ``<epoch zero-padded>.pth`` in ``snapshot_dir``."""
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        # torch.save does not create missing parent directories.
        os.makedirs(self.snapshot_dir, exist_ok=True)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(self.model.state_dict(), snapshot_file)


if __name__ == "__main__":
    # Demo run: train a small UNet on the mock dataset for 10 epochs.
    from mu_map.data.mock import MuMapMockDataset
    from mu_map.logging import get_logger
    from mu_map.models.unet import UNet

    # Kept as a module-level global; check whether other code in this module
    # reads the global `logger` before moving this into a function.
    logger = get_logger(logfile="train.log", loglevel="DEBUG")

    model = UNet(in_channels=1, features=[8, 16])
    print(model)
    dataset = MuMapMockDataset()
    # Pinned host memory only speeds up copies to a CUDA device.
    pin_memory = torch.cuda.is_available()
    data_loader_train = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=pin_memory, num_workers=1)
    # Evaluation does not benefit from shuffling; keep its order deterministic.
    data_loader_val = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=False, pin_memory=pin_memory, num_workers=1)
    data_loaders = {"train": data_loader_train, "validation": data_loader_val}

    training = Training(model, data_loaders, 10, logger)
    training.run()