from logging import Logger
from typing import Optional

import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.training.lib import TrainingParams, AbstractTraining
from mu_map.training.loss import WeightedLoss


# Establish convention for real and fake labels during training.
# These values are used as the targets that the discriminator output is
# compared against: LABEL_REAL for real mu maps, LABEL_FAKE for generated ones.
LABEL_REAL = 1.0
LABEL_FAKE = 0.0


class DiscriminatorParams(TrainingParams):
    """
    Training parameters of the discriminator.

    Passes model, optimizer and learning rate scheduler on to TrainingParams
    and fixes the name to 'Discriminator'.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    ):
        # only the name is fixed here, everything else is passed through
        forwarded = {
            "model": model,
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
        }
        super().__init__(name="Discriminator", **forwarded)

class GeneratorParams(TrainingParams):
    """
    Training parameters of the generator.

    Passes model, optimizer and learning rate scheduler on to TrainingParams
    and fixes the name to 'Generator'.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    ):
        # only the name is fixed here, everything else is passed through
        forwarded = {
            "model": model,
            "optimizer": optimizer,
            "lr_scheduler": lr_scheduler,
        }
        super().__init__(name="Generator", **forwarded)



class cGANTraining(AbstractTraining):
    """
    Implementation of a conditional generative adversarial network training.

    The generator predicts a mu map from a reconstruction. The discriminator
    receives a reconstruction concatenated (along dim 1) with either the real
    or the generated mu map, and its output is compared to LABEL_REAL or
    LABEL_FAKE with a mean squared error. The generator is trained with a
    weighted sum of this adversarial loss and a distance loss between the
    generated and the real mu map.
    """

    def __init__(
        self,
        epochs: int,
        dataset: MuMapDataset,
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        params_generator: GeneratorParams,
        params_discriminator: DiscriminatorParams,
        loss_func_dist: WeightedLoss,
        weight_criterion_dist: float,
        weight_criterion_adv: float,
        logger: Optional[Logger] = None,
    ):
        """
        Create a new cGAN training.

        Parameters
        ----------
        epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
            passed on unchanged to the constructor of AbstractTraining
        params_generator: GeneratorParams
            model and optimizer (and optional lr scheduler) of the generator
        params_discriminator: DiscriminatorParams
            model and optimizer (and optional lr scheduler) of the discriminator
        loss_func_dist: WeightedLoss
            distance loss between generated and real mu maps
        weight_criterion_dist: float
            factor applied to the distance loss in the generator loss
        weight_criterion_adv: float
            factor applied to the adversarial loss in the generator loss
        """
        super().__init__(
            epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
        )
        # register both parameter sets with the base class
        self.training_params.append(params_generator)
        self.training_params.append(params_discriminator)

        self.generator = params_generator.model
        self.discriminator = params_discriminator.model

        self.optim_g = params_generator.optimizer
        self.optim_d = params_discriminator.optimizer

        self.weight_criterion_dist = weight_criterion_dist
        self.weight_criterion_adv = weight_criterion_adv

        self.criterion_adv = torch.nn.MSELoss(reduction="mean")
        self.criterion_dist = loss_func_dist

    def _after_train_batch(self):
        """
        Overwrite calling step on all optimizers as this needs to be done
        separately for the generator and discriminator during the training of
        a batch.
        """
        pass

    def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
        """
        Train the discriminator and then the generator on a single batch.

        Parameters
        ----------
        recons: torch.Tensor
            batch of reconstructions which are the input of the generator
        mu_maps: torch.Tensor
            batch of real mu maps belonging to the reconstructions

        Returns
        -------
        float
            the weighted generator loss (adversarial + distance) of this batch
        """
        mu_maps_real = mu_maps  # rename real mu maps for clarification
        # compute fake mu maps with generator
        mu_maps_fake = self.generator(recons)

        # note: the batch size may differ for the last batch which is why self.batch_size is not reliable
        batch_size = recons.shape[0]
        labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device)
        labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device)

        # prepare inputs for the discriminator
        inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
        inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)

        # ======================= Discriminator =====================================
        # Reset the discriminator gradients explicitly: the generator update
        # below back-propagates through the discriminator and therefore leaves
        # gradients on its parameters which must not enter the next
        # discriminator step. If the caller has already reset the gradients,
        # this call changes nothing.
        self.optim_d.zero_grad()

        # compute discriminator loss for fake mu maps
        # detach is called so that gradients are not computed for the generator
        outputs_d_fake = self.discriminator(inputs_d_fake.detach())
        loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake)

        # compute discriminator loss for real mu maps
        outputs_d_real = self.discriminator(inputs_d_real)
        loss_d_real = self.criterion_adv(outputs_d_real, labels_real)

        # update discriminator
        loss_d = 0.5 * (loss_d_fake + loss_d_real)
        loss_d.backward()  # compute gradients
        self.optim_d.step()
        # ===========================================================================

        # ======================= Generator =========================================
        # start the generator update from clean gradients as well
        self.optim_g.zero_grad()

        outputs_d_fake = self.discriminator(inputs_d_fake)  # this time no detach
        # the generator is rewarded if the discriminator rates fakes as real
        loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real)
        loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real)
        loss_g = (
            self.weight_criterion_adv * loss_g_adv
            + self.weight_criterion_dist * loss_g_dist
        )
        loss_g.backward()
        self.optim_g.step()
        # ===========================================================================

        return loss_g.item()

    def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
        """
        Evaluate the generator on a single batch.

        Parameters
        ----------
        recons: torch.Tensor
            batch of reconstructions which are the input of the generator
        mu_maps: torch.Tensor
            batch of real mu maps belonging to the reconstructions

        Returns
        -------
        float
            the L1 loss between the generated and the real mu maps
        """
        mu_maps_fake = self.generator(recons)
        loss = torch.nn.functional.l1_loss(mu_maps_fake, mu_maps)
        return loss.item()


if __name__ == "__main__":
    # Script entry point: parse the command line, prepare the output
    # directories, the dataset and both models, and run the cGAN training.
    import argparse
    import os
    import random
    import sys

    import numpy as np

    from mu_map.dataset.patches import MuMapPatchDataset
    from mu_map.dataset.normalization import (
        MeanNormTransform,
        MaxNormTransform,
        GaussianNormTransform,
    )
    from mu_map.dataset.transform import PadCropTranform, SequenceTransform
    from mu_map.logging import add_logging_args, get_logger_by_args
    from mu_map.models.unet import UNet
    from mu_map.models.discriminator import Discriminator, PatchDiscriminator

    parser = argparse.ArgumentParser(
        description="Train a UNet model to predict μ-maps from reconstructed images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model Args
    parser.add_argument(
        "--features",
        type=int,
        nargs="+",
        default=[64, 128, 256, 512],
        help="number of features in the layers of the UNet structure",
    )

    # Dataset Args
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/second/",
        help="the directory where the dataset for training is found",
    )
    parser.add_argument(
        "--input_norm",
        type=str,
        choices=["none", "mean", "max", "gaussian"],
        default="mean",
        help="type of normalization applied to the reconstructions",
    )
    parser.add_argument(
        "--patch_size",
        type=int,
        default=32,
        help="the size of patches extracted for each reconstruction",
    )
    parser.add_argument(
        "--patch_offset",
        type=int,
        default=20,
        help="offset to ignore the border of the image",
    )
    parser.add_argument(
        "--number_of_patches",
        type=int,
        default=100,
        help="number of patches extracted for each image",
    )
    parser.add_argument(
        "--no_shuffle",
        action="store_true",
        help="do not shuffle patches in the dataset",
    )
    parser.add_argument(
        "--scatter_correction",
        action="store_true",
        help="use the scatter corrected reconstructions in the dataset",
    )

    # Training Args
    parser.add_argument(
        "--seed",
        type=int,
        help="seed used for random number generation",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="the batch size used for training",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="train_data",
        help="directory in which results (snapshots and logs) of this training are saved",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="the number of epochs for which the model is trained",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="the device (cpu or gpu) with which the training is performed",
    )
    parser.add_argument(
        "--dist_loss_func",
        type=str,
        default="l1",
        help="define the loss function used as the distance loss of the generator , e.g. 0.75*l2+0.25*gdl",
    )
    parser.add_argument(
        "--dist_loss_weight",
        type=float,
        default=100.0,
        help="weight for the distance loss of the generator",
    )
    parser.add_argument(
        "--adv_loss_weight",
        type=float,
        default=1.0,
        help="weight for the Adversarial-Loss of the generator",
    )
    parser.add_argument(
        "--lr", type=float, default=0.001, help="the initial learning rate for training"
    )
    parser.add_argument(
        "--decay_lr",
        action="store_true",
        help="decay the learning rate",
    )
    parser.add_argument(
        "--lr_decay_factor",
        type=float,
        default=0.99,
        help="decay factor for the learning rate",
    )
    parser.add_argument(
        "--lr_decay_epoch",
        type=int,
        default=1,
        help="frequency in epochs at which the learning rate is decayed",
    )
    parser.add_argument(
        "--snapshot_dir",
        type=str,
        default="snapshots",
        help="directory under --output_dir where snapshots are stored",
    )
    parser.add_argument(
        "--snapshot_epoch",
        type=int,
        default=10,
        help="frequency in epochs at which snapshots are stored",
    )

    # Logging Args
    add_logging_args(parser, defaults={"--logfile": "train.log"})

    args = parser.parse_args()

    # makedirs also creates missing parent directories of --output_dir
    os.makedirs(args.output_dir, exist_ok=True)

    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
    os.makedirs(args.snapshot_dir, exist_ok=True)
    # refuse to write into a snapshot directory which already has content
    if len(os.listdir(args.snapshot_dir)) > 0:
        print(
            f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
        )
        print("           Exit so that data is not accidentally overwritten!")
        sys.exit(1)

    args.logfile = os.path.join(args.output_dir, args.logfile)

    logger = get_logger_by_args(args)
    logger.info(args)

    device = torch.device(args.device)

    # draw a seed if none is given and log it so that the run can be repeated
    args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
    logger.info(f"Seed: {args.seed}")
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # "none" has no entry here and thus adds no normalization transform
    norm_transforms = {
        "mean": MeanNormTransform,
        "max": MaxNormTransform,
        "gaussian": GaussianNormTransform,
    }
    # A None entry is deliberately not handed to SequenceTransform, which is
    # assumed to call every entry of the list - TODO confirm.
    transforms = []
    if args.input_norm in norm_transforms:
        transforms.append(norm_transforms[args.input_norm]())
    transforms.append(PadCropTranform(dim=3, size=32))
    transform_normalization = SequenceTransform(transforms)

    dataset = MuMapPatchDataset(
        args.dataset_dir,
        patches_per_image=args.number_of_patches,
        patch_size=args.patch_size,
        patch_offset=args.patch_offset,
        shuffle=not args.no_shuffle,
        transform_normalization=transform_normalization,
        scatter_correction=args.scatter_correction,
        logger=logger,
    )

    def _create_optimization(model):
        # Build the Adam optimizer and the optional StepLR scheduler of a model.
        # Both models share this code so that their settings cannot diverge;
        # step_size is the decay frequency in epochs (--lr_decay_epoch) and
        # gamma is the decay factor (--lr_decay_factor).
        optimizer = torch.optim.Adam(
            model.parameters(), lr=args.lr, betas=(0.5, 0.999)
        )
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
            )
            if args.decay_lr
            else None
        )
        return optimizer, lr_scheduler

    discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
    discriminator = discriminator.to(device)
    optimizer, lr_scheduler = _create_optimization(discriminator)
    params_d = DiscriminatorParams(
        model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
    )

    generator = UNet(in_channels=1, features=args.features)
    generator = generator.to(device)
    optimizer, lr_scheduler = _create_optimization(generator)
    params_g = GeneratorParams(
        model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
    )

    dist_criterion = WeightedLoss.from_str(args.dist_loss_func)
    logger.debug(f"Use distance criterion: {dist_criterion}")

    training = cGANTraining(
        epochs=args.epochs,
        dataset=dataset,
        batch_size=args.batch_size,
        device=device,
        snapshot_dir=args.snapshot_dir,
        snapshot_epoch=args.snapshot_epoch,
        params_generator=params_g,
        params_discriminator=params_d,
        loss_func_dist=dist_criterion,
        weight_criterion_dist=args.dist_loss_weight,
        weight_criterion_adv=args.adv_loss_weight,
        logger=logger,
    )
    training.run()