-
Tamino Huxohl authored
normalization.py 1.05 KiB
from typing import Tuple

from torch import Tensor

from mu_map.dataset.transform import Transform
def norm_max(tensor: Tensor) -> Tensor:
    """Min-max normalize a tensor so that its values span [0, 1].

    The minimum element is mapped to 0 and the maximum element to 1.

    Args:
        tensor: the tensor to normalize.

    Returns:
        ``(tensor - tensor.min()) / (tensor.max() - tensor.min())``. If every
        element is equal (``max == min``), the range is replaced by 1 so the
        result is all zeros instead of the NaN values ``0 / 0`` would produce.
    """
    minimum = tensor.min()
    value_range = tensor.max() - minimum
    if value_range == 0:
        # Constant input: divide by a 1 of the same dtype/device instead of 0,
        # which turns the (all-zero) shifted tensor into zeros rather than NaN.
        value_range = value_range.new_ones(())
    return (tensor - minimum) / value_range
class MaxNormTransform(Transform):
    """Transform that min-max normalizes the inputs with ``norm_max``.

    The expected outputs are passed through unchanged.
    """

    def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(norm_max(inputs), outputs_expected)``."""
        normalized_inputs = norm_max(inputs)
        return (normalized_inputs, outputs_expected)
def norm_mean(tensor: Tensor) -> Tensor:
    """Scale a tensor by the inverse of its mean, so the result has mean 1.

    Args:
        tensor: the tensor to normalize.

    Returns:
        ``tensor / tensor.mean()``. If the mean is exactly zero (e.g. an
        all-zero tensor), the tensor is divided by 1 instead, so the result
        stays finite rather than becoming inf/NaN.
    """
    mean = tensor.mean()
    if mean == 0:
        # Avoid division by zero; a 1 of the same dtype/device leaves values unchanged.
        mean = mean.new_ones(())
    return tensor / mean
class MeanNormTransform(Transform):
    """Transform that mean-normalizes the inputs with ``norm_mean``.

    The expected outputs are passed through unchanged.
    """

    def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(norm_mean(inputs), outputs_expected)``."""
        normalized_inputs = norm_mean(inputs)
        return (normalized_inputs, outputs_expected)
def norm_gaussian(tensor: Tensor) -> Tensor:
    """Standardize a tensor to zero mean and unit standard deviation.

    The standard deviation is ``tensor.std()`` called with its default
    arguments.

    Args:
        tensor: the tensor to normalize.

    Returns:
        ``(tensor - tensor.mean()) / tensor.std()``. If the standard deviation
        is zero (constant input) or NaN (e.g. a single element), the centered
        tensor is returned without scaling. For constant, finite input this is
        all zeros rather than NaN/inf.
    """
    centered = tensor - tensor.mean()
    std = tensor.std()
    if not std > 0:
        # `not std > 0` is True for both 0 and NaN; dividing by 1 keeps the
        # centered values instead of producing inf/NaN.
        std = std.new_ones(())
    return centered / std
class GaussianNormTransform(Transform):
    """Transform that standardizes the inputs with ``norm_gaussian``.

    The expected outputs are passed through unchanged.
    """

    def __call__(self, inputs: Tensor, outputs_expected: Tensor) -> Tuple[Tensor, Tensor]:
        """Return ``(norm_gaussian(inputs), outputs_expected)``."""
        standardized_inputs = norm_gaussian(inputs)
        return (standardized_inputs, outputs_expected)
# Public API of this module. Names are taken from the objects themselves so a
# rename cannot leave a stale string entry behind.
__all__ = [
    exported.__name__
    for exported in (
        norm_max,
        norm_mean,
        norm_gaussian,
        MaxNormTransform,
        MeanNormTransform,
        GaussianNormTransform,
    )
]