Skip to content
Snippets Groups Projects
normalization.py 1.31 KiB
from typing import Optional, Tuple

from torch import Tensor

from mu_map.dataset.transform import Transform


def norm_max(tensor: Tensor) -> Tensor:
    """Min-max normalize ``tensor`` into the range ``[0, 1]``.

    Despite the name, this is min-max (not max-only) normalization: the
    minimum and maximum are taken over all elements of the tensor.

    Args:
        tensor: tensor to normalize. Integer tensors are accepted; the result
            is floating point because ``/`` performs true division.

    Returns:
        ``(tensor - min) / (max - min)``. For a constant tensor
        (``max == min``) the result is all zeros instead of the ``NaN``
        produced by ``0 / 0``.
    """
    minimum = tensor.min()
    shifted = tensor - minimum
    value_range = tensor.max() - minimum
    if value_range == 0:
        # Avoid 0 / 0 -> NaN. `shifted` is all zeros here, so dividing by 1
        # yields zeros while keeping the same dtype promotion and device as
        # the regular path.
        value_range = value_range + 1
    return shifted / value_range


class MaxNormTransform(Transform):
    """Normalize inputs by fixed maxima or by min-max normalization.

    With ``max_vals`` set, ``inputs`` is divided by ``max_vals[0]`` and
    ``outputs_expected`` by ``max_vals[1]``. Without it, ``inputs`` is
    min-max normalized via ``norm_max`` and ``outputs_expected`` is
    returned unchanged.
    """

    def __init__(self, max_vals: Optional[Tuple[float, float]] = None):
        """
        Args:
            max_vals: optional divisors ``(inputs_max, outputs_max)``. When
                omitted, only the inputs are normalized (min-max).
        """
        self.max_vals = max_vals

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Normalize an ``(inputs, outputs_expected)`` pair.

        Returns:
            The normalized ``(inputs, outputs_expected)`` tuple.
        """
        # Compare against None explicitly: truth-testing would raise for a
        # multi-element tensor/array passed as ``max_vals``.
        if self.max_vals is not None:
            input_max, output_max = self.max_vals[0], self.max_vals[1]
            return inputs / input_max, outputs_expected / output_max
        return norm_max(inputs), outputs_expected


def norm_mean(tensor: Tensor) -> Tensor:
    """Scale ``tensor`` so that its mean becomes 1.

    Args:
        tensor: tensor to normalize; the mean is taken over all elements.
            Presumably floating point, since torch's ``mean`` rejects
            integer dtypes.

    Returns:
        ``tensor / tensor.mean()``. A zero mean is not guarded against, so
        such input yields ``inf``/``NaN`` values.
    """
    return tensor / tensor.mean()


class MeanNormTransform(Transform):
    """Rescale the inputs to unit mean with ``norm_mean``.

    The expected outputs are passed through untouched.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Return ``(norm_mean(inputs), outputs_expected)``."""
        scaled_inputs = norm_mean(inputs)
        return scaled_inputs, outputs_expected


def norm_gaussian(tensor: Tensor) -> Tensor:
    """Standardize ``tensor`` to zero mean and unit standard deviation.

    Mean and sample (Bessel-corrected, torch's default) standard deviation
    are computed over all elements.

    Args:
        tensor: floating-point tensor to standardize.

    Returns:
        ``(tensor - mean) / std``. When the standard deviation is zero
        (constant tensor) or undefined (fewer than two elements), the
        centered tensor ``tensor - mean`` is returned instead of ``NaN``.
    """
    centered = tensor - tensor.mean()
    if tensor.numel() < 2:
        # The sample std needs at least two elements; torch would return NaN
        # (with a warning), turning the whole result into NaN.
        return centered
    std = tensor.std()
    if std == 0:
        # Constant tensor: `centered` is already all zeros; avoid 0 / 0.
        return centered
    return centered / std


class GaussianNormTransform(Transform):
    """Standardize the inputs with ``norm_gaussian``.

    The expected outputs are passed through untouched.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Return ``(norm_gaussian(inputs), outputs_expected)``."""
        standardized_inputs = norm_gaussian(inputs)
        return standardized_inputs, outputs_expected


# Public API. Names are derived from the objects themselves so a rename
# cannot leave a stale string behind.
__all__ = [
    exported.__name__
    for exported in (
        norm_max,
        norm_mean,
        norm_gaussian,
        MaxNormTransform,
        MeanNormTransform,
        GaussianNormTransform,
    )
]