from dataclasses import dataclass
from logging import Logger
import os
from typing import Dict, Optional
import sys

import torch
from torch import Tensor

from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger


@dataclass
class TrainingParams:
    """
    Bundle of a model together with the objects used to optimize it.

    :param name: name of the model; its lower-cased form is used in snapshot
        file names by ``AbstractTraining.store_snapshot``
    :param model: the model to train
    :param optimizer: optimizer stepped after every training batch
    :param lr_scheduler: optional learning rate scheduler stepped once per
        epoch; defaults to None (no scheduling)
    """

    name: str
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None


class AbstractTraining:
    """
    Base class for a training loop over one or more models.

    The base class leaves ``training_params`` empty and its batch hooks
    (``_train_batch`` / ``_eval_batch``) return 0, so subclasses are expected
    to register ``TrainingParams`` entries and implement the batch hooks.
    ``run`` then alternates a training and a validation epoch, stores
    snapshots and steps learning rate schedulers.
    """

    def __init__(
        self,
        epochs: int,
        dataset: "MuMapDataset",
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger: Optional[Logger],
    ):
        """
        :param epochs: number of epochs to train
        :param dataset: dataset whose "train" and "validation" splits (via
            ``split_copy``) are wrapped in data loaders
        :param batch_size: batch size of both data loaders
        :param device: device inputs and targets are moved to
        :param snapshot_dir: directory in which snapshot files are written
        :param snapshot_epoch: store a numbered snapshot every this many
            epochs; 0 disables periodic snapshots
        :param logger: logger to use; if None, a logger named after the
            concrete class is created
        """
        self.epochs = epochs
        self.batch_size = batch_size
        self.dataset = dataset
        self.device = device

        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch

        self.logger = (
            logger if logger is not None else get_logger(name=self.__class__.__name__)
        )

        # list of TrainingParams; left empty here (see class docstring)
        self.training_params = []
        self.data_loaders = {
            split_name: torch.utils.data.DataLoader(
                dataset.split_copy(split_name),
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=1,
            )
            for split_name in ["train", "validation"]
        }

    def run(self) -> float:
        """
        Train for ``self.epochs`` epochs and return the minimal validation loss.

        After every epoch a "val_min" snapshot is stored whenever the
        validation loss improves, a numbered snapshot is stored every
        ``snapshot_epoch`` epochs (if ``snapshot_epoch`` > 0), and all
        learning rate schedulers are stepped.

        :return: the minimal validation loss, or ``inf`` if no epoch was run
        """
        # float sentinel so that any finite loss is an improvement and the
        # return type is always float
        loss_val_min = float("inf")
        epoch_width = len(str(self.epochs))
        for epoch in range(1, self.epochs + 1):
            str_epoch = f"{str(epoch):>{epoch_width}}"
            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")

            loss_train = self._train_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
            )
            loss_val = self._eval_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
            )

            if loss_val < loss_val_min:
                loss_val_min = loss_val
                self.logger.info(
                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
                )
                self.store_snapshot("val_min")

            # guard against ZeroDivisionError: 0 disables periodic snapshots
            if self.snapshot_epoch > 0 and epoch % self.snapshot_epoch == 0:
                self.store_snapshot(f"{epoch:0{epoch_width}d}")

            for param in self.training_params:
                if param.lr_scheduler is not None:
                    param.lr_scheduler.step()
        return loss_val_min

    def _after_train_batch(self):
        """
        Function called after the loss computation on a batch during training.
        It is responsible for stepping all optimizers.
        """
        for param in self.training_params:
            param.optimizer.step()

    @staticmethod
    def _print_progress(index: int, total: int):
        """Print a 1-based, carriage-return terminated batch counter."""
        width = len(str(total))
        print(f"Batch {str(index + 1):>{width}}/{total}", end="\r")

    def _train_epoch(self) -> float:
        """
        Run one training epoch over the "train" data loader.

        Gradients are enabled only for the duration of the epoch; the grad
        mode that was active before is restored afterwards.

        :return: sum of the ``_train_batch`` results divided by the number of
            batches
        """
        for param in self.training_params:
            param.model.train()

        loss = 0.0
        data_loader = self.data_loaders["train"]
        # scoped instead of torch.set_grad_enabled(True) so the global grad
        # mode is not changed for the caller
        with torch.enable_grad():
            for i, (inputs, targets) in enumerate(data_loader):
                self._print_progress(i, len(data_loader))

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                for param in self.training_params:
                    param.optimizer.zero_grad()

                loss = loss + self._train_batch(inputs, targets)

                self._after_train_batch()
        return loss / len(data_loader)

    def _eval_epoch(self) -> float:
        """
        Run one validation epoch over the "validation" data loader.

        Gradients are disabled only for the duration of the epoch; the grad
        mode that was active before is restored afterwards.

        :return: sum of the ``_eval_batch`` results divided by the number of
            batches
        """
        for param in self.training_params:
            param.model.eval()

        loss = 0.0
        data_loader = self.data_loaders["validation"]
        # scoped instead of torch.set_grad_enabled(False), which previously
        # left autograd disabled globally after run() returned
        with torch.no_grad():
            for i, (inputs, targets) in enumerate(data_loader):
                self._print_progress(i, len(data_loader))

                inputs = inputs.to(self.device)
                targets = targets.to(self.device)

                loss = loss + self._eval_batch(inputs, targets)
        return loss / len(data_loader)

    def store_snapshot(self, prefix: str):
        """
        Save the ``state_dict`` of every registered model.

        Each model is written to ``<snapshot_dir>/<prefix>_<name.lower()>.pth``.

        :param prefix: file name prefix, e.g. "val_min" or a zero-padded epoch
        """
        for param in self.training_params:
            snapshot_file = os.path.join(
                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
            )
            self.logger.debug(f"Store snapshot at {snapshot_file}")
            torch.save(param.model.state_dict(), snapshot_file)

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """Training hook for one batch; the base implementation returns 0."""
        return 0

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """Validation hook for one batch; the base implementation returns 0."""
        return 0