import os
from typing import Dict

import torch

from mu_map.logging import get_logger


class Training:
    """Standard supervised training loop for a PyTorch model.

    Trains ``model`` with Adam + MSE loss on the ``"train"`` data loader,
    evaluates on the ``"validation"`` loader after every epoch, decays the
    learning rate with a step schedule, and periodically stores model
    snapshots.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        data_loaders: Dict[str, torch.utils.data.DataLoader],
        epochs: int,
        device: torch.device,
        lr: float,
        lr_decay_factor: float,
        lr_decay_epoch: int,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger=None,
    ):
        """
        :param model: the model to be trained
        :param data_loaders: mapping with keys ``"train"`` and ``"validation"``
        :param epochs: total number of epochs to train
        :param device: device on which training is performed
        :param lr: initial learning rate for the Adam optimizer
        :param lr_decay_factor: multiplicative decay factor for the learning rate
        :param lr_decay_epoch: epoch interval at which the learning rate is decayed
        :param snapshot_dir: existing directory in which snapshots are written
        :param snapshot_epoch: epoch interval at which snapshots are stored
        :param logger: logger instance; a default logger is created if omitted
        """
        # Fix: move the model to the target device once, up front. Batches are
        # moved in _run_epoch, but previously the model itself never was, which
        # crashes with a device mismatch as soon as device != cpu.
        self.model = model.to(device)
        self.data_loaders = data_loaders
        self.epochs = epochs
        self.device = device

        self.lr = lr
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_epoch = lr_decay_epoch

        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch

        self.logger = logger if logger is not None else get_logger()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )
        self.loss_func = torch.nn.MSELoss(reduction="mean")

    def run(self):
        """Run the full training: all epochs, LR decay and snapshotting."""
        for epoch in range(1, self.epochs + 1):
            # Fix: use the injected self.logger; the previous bare `logger`
            # only resolved via a module-global created by the __main__ block
            # and raised NameError when this class was used as a library.
            self.logger.debug(
                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
            )
            self._run_epoch(self.data_loaders["train"], phase="train")

            # Re-evaluate on the training set in eval mode so the reported
            # training loss reflects the post-update weights.
            loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
            )
            loss_validation = self._run_epoch(
                self.data_loaders["validation"], phase="val"
            )
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
            )

            # ToDo: log outputs and time
            _previous = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            self.logger.debug(
                f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
            )

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(
                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
            )

    def _run_epoch(self, data_loader, phase):
        """Iterate once over ``data_loader``.

        In phase ``"train"`` gradients are computed and the optimizer is
        stepped per batch; in any other phase the model runs in eval mode
        without gradients.

        :return: the epoch loss averaged over the number of batches
        """
        self.logger.debug(f"Run epoch in phase {phase}")
        self.model.train() if phase == "train" else self.model.eval()

        epoch_loss = 0
        loss_updates = 0
        for i, (inputs, labels) in enumerate(data_loader):
            # Lightweight in-place progress line (carriage return, no newline).
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(inputs)
                loss = self.loss_func(outputs, labels)

                if phase == "train":
                    loss.backward()
                    self.optimizer.step()

            epoch_loss += loss.item()
            loss_updates += 1
        return epoch_loss / loss_updates

    def store_snapshot(self, epoch):
        """Save the model's state dict as ``<epoch>.pth`` (zero-padded) in snapshot_dir."""
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(self.model.state_dict(), snapshot_file)


if __name__ == "__main__":
    import argparse

    from mu_map.dataset.mock import MuMapMockDataset
    from mu_map.logging import add_logging_args, get_logger_by_args
    from mu_map.models.unet import UNet

    parser = argparse.ArgumentParser(
        description="Train a UNet model to predict μ-maps from reconstructed scatter images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model Args
    parser.add_argument(
        "--features",
        type=int,
        nargs="+",
        default=[8, 16],
        help="number of features in the layers of the UNet structure",
    )

    # Dataset Args
    # TODO: add dataset arguments once the mock dataset is replaced by the real one

    # Training Args
    parser.add_argument(
        "--output_dir",
        type=str,
        default="train_data",
        help="directory in which results (snapshots and logs) of this training are saved",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="the number of epochs for which the model is trained",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="the device (cpu or gpu) with which the training is performed",
    )
    parser.add_argument(
        "--lr", type=float, default=0.1, help="the initial learning rate for training"
    )
    parser.add_argument(
        "--lr_decay_factor",
        type=float,
        default=0.99,
        help="decay factor for the learning rate",
    )
    parser.add_argument(
        "--lr_decay_epoch",
        type=int,
        default=1,
        help="frequency in epochs at which the learning rate is decayed",
    )
    parser.add_argument(
        "--snapshot_dir",
        type=str,
        default="snapshots",
        help="directory under --output_dir where snapshots are stored",
    )
    parser.add_argument(
        "--snapshot_epoch",
        type=int,
        default=10,
        help="frequency in epochs at which snapshots are stored",
    )

    # Logging Args
    add_logging_args(parser, defaults={"--logfile": "train.log"})

    args = parser.parse_args()

    # makedirs(exist_ok=True) also creates intermediate directories and does
    # not race/fail when the directory already exists (os.mkdir did both).
    os.makedirs(args.output_dir, exist_ok=True)

    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
    os.makedirs(args.snapshot_dir, exist_ok=True)

    args.logfile = os.path.join(args.output_dir, args.logfile)

    device = torch.device(args.device)
    logger = get_logger_by_args(args)

    model = UNet(in_channels=1, features=args.features)
    # Fix: move the model to the selected device; previously only the batches
    # were moved inside Training, which crashes on the default cuda:0 device.
    model = model.to(device)
    dataset = MuMapMockDataset(logger=logger)
    data_loader_train = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
    )
    data_loaders = {"train": data_loader_train, "validation": data_loader_val}

    training = Training(
        model=model,
        data_loaders=data_loaders,
        epochs=args.epochs,
        device=device,
        lr=args.lr,
        lr_decay_factor=args.lr_decay_factor,
        lr_decay_epoch=args.lr_decay_epoch,
        snapshot_dir=args.snapshot_dir,
        snapshot_epoch=args.snapshot_epoch,
        logger=logger,
    )
    training.run()