"""
Module containing normalization methods either as functions
or transforms.
"""
from typing import Any, Callable, Optional, Tuple

from torch import Tensor

from mu_map.dataset.transform import Transform


def norm_max(tensor: Tensor) -> Tensor:
    """
    Rescale a tensor linearly into the value range [0, 1].

    The global minimum of the tensor is mapped to 0 and the global
    maximum is mapped to 1 (min-max scaling).

    Note that for a tensor in which all values are equal the value
    range is zero, so the normalization divides by zero.

    Parameters
    ----------
    tensor: Tensor
        the tensor to normalize

    Returns
    -------
    Tensor
        a new tensor with the rescaled values
    """
    lowest = tensor.min()
    value_range = tensor.max() - lowest
    return (tensor - lowest) / value_range


def norm_mean(tensor: Tensor) -> Tensor:
    """
    Perform mean normalization on a tensor.

    This means that the tensor is divided by its mean, so that the
    mean of the result is 1.

    Note that a tensor with a mean of zero leads to a division by
    zero.

    Parameters
    ----------
    tensor: Tensor
        the tensor to normalize

    Returns
    -------
    Tensor
        a new tensor containing the input values divided by their mean
    """
    return tensor / tensor.mean()


def norm_gaussian(tensor: Tensor) -> Tensor:
    """
    Perform Gaussian normalization (standardization) on a tensor.

    The mean of the tensor is subtracted and the result is divided
    by the standard deviation as computed by `tensor.std()`. The
    resulting values have mean 0 and standard deviation 1. The shape
    of the value distribution itself is not changed.

    Note that a tensor with a standard deviation of zero (all values
    equal) leads to a division by zero.

    Parameters
    ----------
    tensor: Tensor
        the tensor to normalize

    Returns
    -------
    Tensor
        a new tensor with the standardized values
    """
    return (tensor - tensor.mean()) / tensor.std()


class NormTransform(Transform):
    """
    Base class for all normalization transformers.

    Note that a normalization is only applied to the first tensor
    input into the __call__ function. This is because usually only
    the input of a neural network is normalized while the output
    stays as it is. But both are returned by a dataset and other
    transformers such as cropping or padding need to be applied.
    """

    def __init__(self, norm_func: Callable[[Tensor], Tensor]):
        """
        Create a new normalization transformer.

        Parameters
        ----------
        norm_func: Callable[[Tensor], Tensor]
            the normalization function applied to the first tensor
        """
        super().__init__()
        self.norm_func = norm_func

    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
        """
        Normalize the first input tensor. All others remain as they
        are.

        Parameters
        ----------
        *tensors: Tensor
            the tensors to transform, at least one is required

        Returns
        -------
        Tuple[Tensor, ...]
            the normalized first tensor followed by all remaining
            tensors in their original order

        Raises
        ------
        IndexError
            if no tensor is given
        """
        return (self.norm_func(tensors[0]), *tensors[1:])

    def __eq__(self, other: Any) -> bool:
        """
        Implementation of the comparison operator.

        This implementation just checks that self and other are of
        the same class.

        Parameters
        ----------
        other: Any
            the object to compare with

        Returns
        -------
        bool
            true if both objects have exactly the same class
        """
        return self.__class__ == other.__class__

    def __hash__(self) -> int:
        """
        Hash consistent with `__eq__`.

        Defining `__eq__` alone would make instances unhashable. As
        equality only depends on the class, so does the hash.

        Returns
        -------
        int
        """
        return hash(self.__class__)


class MaxNormTransform(NormTransform):
    """
    Normalization transformer which uses `norm_max` as its
    normalization function.
    """

    def __init__(self) -> None:
        """
        Create a new maximum normalization transformer.
        """
        super().__init__(norm_func=norm_max)


class MeanNormTransform(NormTransform):
    """
    Normalization transformer which uses `norm_mean` as its
    normalization function.
    """

    def __init__(self) -> None:
        """
        Create a new mean normalization transformer.
        """
        super().__init__(norm_func=norm_mean)


class GaussianNormTransform(NormTransform):
    """
    Normalization transformer which uses `norm_gaussian` as its
    normalization function.
    """

    def __init__(self) -> None:
        """
        Create a new Gaussian normalization transformer.
        """
        super().__init__(norm_func=norm_gaussian)


# Strings naming all available normalization methods. They can be
# used to initialize the choices of command line interfaces.
norm_choices = [
    "max",
    "mean",
    "gaussian",
]


def norm_by_str(norm: Optional[str]) -> Optional[NormTransform]:
    """
    Look up a normalization transformer by its name.

    This is useful for command line interfaces. The lookup is
    case-insensitive and matches by substring: the first keyword
    (checked in the order "mean", "max", "gaussian") contained in
    the given string selects the transformer.

    Parameters
    ----------
    norm: str, optional
        a string defining the normalization transformer (see `norm_choices`)

    Returns
    -------
    NormTransform
        a new normalization transformer or none if the input is none

    Raises
    ------
    ValueError
        if the string contains none of the known keywords
    """
    if norm is None:
        return None

    norm = norm.lower()
    # The order of this table defines the precedence of the keywords.
    keyword_table = (
        ("mean", MeanNormTransform),
        ("max", MaxNormTransform),
        ("gaussian", GaussianNormTransform),
    )
    for keyword, transform_cls in keyword_table:
        if keyword in norm:
            return transform_cls()

    raise ValueError(f"Unknown normalization {norm}")


# Public API of this module. Besides the normalization functions and
# their transformers, this includes the base class and the helpers
# intended for command line interfaces.
__all__ = [
    "norm_max",
    "norm_mean",
    "norm_gaussian",
    "NormTransform",
    "MaxNormTransform",
    "MeanNormTransform",
    "GaussianNormTransform",
    "norm_choices",
    "norm_by_str",
]