Skip to content
Snippets Groups Projects
normalization.py 1.05 KiB
from typing import Tuple

from torch import Tensor

from mu_map.dataset.transform import Transform


def norm_max(tensor: Tensor) -> Tensor:
    """Min-max normalize ``tensor`` into the range [0, 1].

    Computes ``(tensor - min) / (max - min)`` over all elements. A constant
    tensor (``max == min``) has no range to scale by. Instead of producing
    ``0 / 0 = NaN`` everywhere, it is mapped to all zeros.

    :param tensor: tensor to normalize
    :return: normalized tensor; as with true division, integer input yields
        a floating-point result
    """
    minimum = tensor.min()
    shifted = tensor - minimum
    value_range = tensor.max() - minimum
    if value_range == 0:
        # Constant input: ``shifted`` is already all zeros. Divide by one
        # instead of zero so the result dtype matches the general case.
        value_range = value_range.new_ones(())
    return shifted / value_range


class MaxNormTransform(Transform):
    """Transform that min-max normalizes the inputs with ``norm_max``.

    The expected outputs are passed through unchanged.
    """

    def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
        """Normalize ``inputs`` and return them with ``outputs_expected``.

        :param inputs: tensor to normalize
        :param outputs_expected: tensor returned as-is
        :return: ``(norm_max(inputs), outputs_expected)``
        """
        return norm_max(inputs), outputs_expected


def norm_mean(tensor: Tensor) -> Tensor:
    """Scale ``tensor`` so that its mean becomes 1.

    Divides every element by the mean over all elements. A tensor whose
    mean is exactly zero (e.g. an all-zero tensor) cannot be rescaled this
    way. Rather than returning ``NaN``/``inf`` values, its values are
    returned unchanged, as a new tensor.

    :param tensor: floating-point tensor to normalize (``Tensor.mean`` does
        not accept integer dtypes)
    :return: mean-normalized tensor
    """
    mean = tensor.mean()
    if mean == 0:
        # Avoid 0 / 0 (NaN) and x / 0 (inf); dividing by one keeps values.
        mean = mean.new_ones(())
    return tensor / mean


class MeanNormTransform(Transform):
    """Transform that mean-normalizes the inputs with ``norm_mean``.

    The expected outputs are passed through unchanged.
    """

    def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
        """Normalize ``inputs`` and return them with ``outputs_expected``.

        :param inputs: tensor to normalize
        :param outputs_expected: tensor returned as-is
        :return: ``(norm_mean(inputs), outputs_expected)``
        """
        return norm_mean(inputs), outputs_expected


def norm_gaussian(tensor: Tensor) -> Tensor:
    """Standardize ``tensor`` to zero mean and unit standard deviation.

    Uses ``Tensor.std()`` with its default (Bessel-corrected) estimator.
    The standard deviation can be zero (constant tensor) or undefined (NaN,
    e.g. a single-element tensor). In that case the tensor is only centered,
    instead of being divided by zero/NaN.

    :param tensor: floating-point tensor to normalize
    :return: standardized tensor
    """
    centered = tensor - tensor.mean()
    std = tensor.std()
    # ``not std > 0`` is True both for std == 0 and for std == NaN.
    if not std > 0:
        std = std.new_ones(())
    return centered / std


class GaussianNormTransform(Transform):
    """Transform that standardizes the inputs with ``norm_gaussian``.

    The expected outputs are passed through unchanged.
    """

    def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
        """Standardize ``inputs`` and return them with ``outputs_expected``.

        :param inputs: tensor to standardize
        :param outputs_expected: tensor returned as-is
        :return: ``(norm_gaussian(inputs), outputs_expected)``
        """
        return norm_gaussian(inputs), outputs_expected


# Public API of this module. The names are taken from the objects
# themselves, so renaming a definition cannot leave a stale string behind.
__all__ = [
    exported.__name__
    for exported in (
        norm_max,
        norm_mean,
        norm_gaussian,
        MaxNormTransform,
        MeanNormTransform,
        GaussianNormTransform,
    )
]