from typing import Any, List

import torch
import torch.nn as nn


class GradientDifferenceLoss(nn.Module):
    """
    Gradient Difference Loss (GDL) inspired by https://github.com/mmany/pytorch-GDL/blob/main/custom_loss_functions.py.
    It is modified to deal with 5D tensors (batch_size, channels, z, y, x).
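
    Example (illustrative; assumes 5D inputs of shape (batch, channel, z, y, x)):
        >>> gdl = GradientDifferenceLoss()
        >>> prediction = torch.rand(2, 1, 8, 16, 16)
        >>> target = torch.rand(2, 1, 8, 16, 16)
        >>> loss = gdl(prediction, target)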
    """

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Squared differences of the first-order finite differences along the
        # three spatial dimensions (z, y, x) of the 5D tensors.
        gradient_diff_z = (inputs.diff(dim=2) - targets.diff(dim=2)).pow(2).sum()
        gradient_diff_y = (inputs.diff(dim=3) - targets.diff(dim=3)).pow(2).sum()
        gradient_diff_x = (inputs.diff(dim=4) - targets.diff(dim=4)).pow(2).sum()

        gradient_diff = gradient_diff_x + gradient_diff_y + gradient_diff_z
        return gradient_diff / inputs.numel()


def loss_by_string(loss_str: str) -> nn.Module:
    """
    Retrieve a loss function defined by a string.
    E.g., "l1" returns the torch nn.L1Loss module.

    :param loss_str: loss function defined as a string
    :returns: an executable loss function
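
    Example (illustrative):
        >>> criterion = loss_by_string("l2")
        >>> isinstance(criterion, nn.MSELoss)
        True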
    """
    loss_str = loss_str.lower()
    if "l1" in loss_str:
        return nn.L1Loss(reduction="mean")
    elif "l2" in loss_str or "mse" in loss_str:
        return nn.MSELoss(reduction="mean")
    elif "gdl" in loss_str or "gradientdifferenceloss" in loss_str:
        return GradientDifferenceLoss()
    else:
        raise ValueError(f"Unknown loss function: {loss_str}")


class WeightedLoss(nn.Module):
    """
    Definition of a weighted loss that sums several loss functions,
    each scaled by a corresponding weight.

    :param losses: the losses to be summed and weighted
    :param weights: weights for each loss function
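
    Example (illustrative; an equal mix of MSE and GDL):
        >>> criterion = WeightedLoss([nn.MSELoss(), GradientDifferenceLoss()], [0.5, 0.5])
        >>> loss = criterion(torch.rand(2, 1, 8, 16, 16), torch.rand(2, 1, 8, 16, 16))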
    """

    def __init__(self, losses: List[nn.Module], weights: List[float]):
        super().__init__()

        assert len(losses) == len(
            weights
        ), f"Number of losses ({len(losses)}) does not match number of weights ({len(weights)})"

        # Use an nn.ModuleList so the individual losses are registered as submodules.
        self.losses = nn.ModuleList(losses)
        self.weights = weights

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Weighted sum of the individual losses.
        loss = 0.0
        for loss_func, weight in zip(self.losses, self.weights):
            loss += weight * loss_func(outputs, targets)
        return loss

    def __repr__(self):
        return " + ".join(
            f"{weight:.3f} * {loss}" for weight, loss in zip(self.weights, self.losses)
        )

    def __eq__(self, other: Any) -> bool:
        """
        Implementation of the comparison operator.

        Two weighted losses are considered equal if they consist of the
        same loss classes with the same weights. Note that the comparison
        is order-sensitive: for example, `L2+GDL` is not equal to `GDL+L2`.
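
        Example (illustrative; the comparison is order-sensitive):
            >>> WeightedLoss.from_str("l2+gdl") == WeightedLoss.from_str("gdl+l2")
            False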

        :param other: the object to compare against
        :returns: True if both weighted losses consist of the same loss
            classes with the same weights, False otherwise
        """
        if self.__class__ != other.__class__:
            return False

        if len(self.losses) != len(other.losses):
            return False

        loss_classes_self = tuple(map(lambda loss: loss.__class__, self.losses))
        loss_classes_other = tuple(map(lambda loss: loss.__class__, other.losses))
        return loss_classes_self == loss_classes_other and self.weights == other.weights

    @classmethod
    def from_str(cls, loss_func_str: str):
        """
        Parse a weighted loss function from a string.
        E.g.: 0.1*gdl+0.9*l2
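
        Example (illustrative):
            >>> criterion = WeightedLoss.from_str("0.1*gdl+0.9*l2")
            >>> criterion.weights
            [0.1, 0.9]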
        """
        addends = loss_func_str.split("+")

        losses, weights = [], []
        for addend in addends:
            factors = addend.strip().split("*")

            if len(factors) == 1:
                weights.append(1.0)
                losses.append(loss_by_string(factors[0]))
            else:
                weights.append(float(factors[0]))
                losses.append(loss_by_string(factors[1]))

        return cls(losses, weights)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Test building a loss function from a string"
    )
    parser.add_argument(
        "loss", type=str, help="description of a loss function, e.g., 0.75*L1+0.25*GDL"
    )
    args = parser.parse_args()

    criterion = WeightedLoss.from_str(args.loss)

    # Build a 5D test tensor of shape (batch, channel, z, y, x) = (1, 1, 1, 2, 2).
    inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    inputs = inputs.unsqueeze(dim=0).unsqueeze(dim=0).unsqueeze(dim=0)
    targets = torch.ones_like(inputs)

    loss = criterion(inputs, targets)

    print(f"Inputs:\n{inputs}")
    print(f"Targets:\n{targets}")
    print(f"Criterion: {criterion}")
    print(f"Loss: {loss.item():.6f}")