from typing import Optional, Tuple

from torch import Tensor

from mu_map.dataset.transform import Transform


def norm_max(tensor: Tensor) -> Tensor:
    """Min-max normalize a tensor so its values span [0, 1].

    Note: produces NaN/inf if the tensor is constant (max == min).

    :param tensor: tensor to normalize
    :return: tensor rescaled to the range [0, 1]
    """
    low = tensor.min()
    high = tensor.max()
    return (tensor - low) / (high - low)


class MaxNormTransform(Transform):
    """Normalize tensors by maximum values.

    If ``max_vals`` is given, inputs and expected outputs are divided by
    the respective fixed maximum. Otherwise, the inputs are min-max
    normalized per call and the expected outputs are passed through
    unchanged.
    """

    def __init__(self, max_vals: Optional[Tuple[float, float]] = None):
        """
        :param max_vals: optional pair (input max, output max) used as fixed
            divisors; when None, inputs are min-max normalized instead
        """
        # Fixed annotation: the default is None, so the parameter must be Optional.
        self.max_vals = max_vals

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        if self.max_vals:
            return inputs / self.max_vals[0], outputs_expected / self.max_vals[1]
        return norm_max(inputs), outputs_expected


def norm_mean(tensor: Tensor) -> Tensor:
    """Normalize a tensor by dividing it by its mean.

    :param tensor: tensor to normalize
    :return: tensor scaled so that its mean is 1
    """
    mean_val = tensor.mean()
    return tensor / mean_val


class MeanNormTransform(Transform):
    """Transform that mean-normalizes the inputs.

    Expected outputs are passed through unchanged.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        normalized = norm_mean(inputs)
        return normalized, outputs_expected


def norm_gaussian(tensor: Tensor) -> Tensor:
    """Standardize a tensor to zero mean and unit standard deviation.

    :param tensor: tensor to normalize
    :return: centered tensor divided by its standard deviation
    """
    centered = tensor - tensor.mean()
    return centered / tensor.std()


class GaussianNormTransform(Transform):
    """Transform that standardizes the inputs (zero mean, unit std).

    Expected outputs are passed through unchanged.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        standardized = norm_gaussian(inputs)
        return standardized, outputs_expected


# Names accepted by norm_by_str, e.g. for CLI argument choices.
norm_choices = ["max", "mean", "gaussian"]


def norm_by_str(norm: Optional[str]) -> Optional["Transform"]:
    """Instantiate a normalization transform from its string name.

    :param norm: one of ``norm_choices`` ("max", "mean", "gaussian"),
        or None for no normalization
    :return: the matching transform instance, or None if *norm* is None
    :raises ValueError: if *norm* is not a known normalization name
    """
    if norm is None:
        return None

    # Branches ordered to match norm_choices.
    if norm == "max":
        return MaxNormTransform()
    elif norm == "mean":
        return MeanNormTransform()
    elif norm == "gaussian":
        return GaussianNormTransform()

    raise ValueError(f"Unknown normalization {norm}")


# Public API of this module; includes norm_by_str/norm_choices so that
# star-imports expose the string-based factory as well.
__all__ = [
    norm_max.__name__,
    norm_mean.__name__,
    norm_gaussian.__name__,
    norm_by_str.__name__,
    "norm_choices",
    MaxNormTransform.__name__,
    MeanNormTransform.__name__,
    GaussianNormTransform.__name__,
]