Skip to content
Snippets Groups Projects
normalization.py 1.31 KiB
Newer Older
  • Learn to ignore specific revisions
  • Tamino Huxohl's avatar
    Tamino Huxohl committed
    from typing import Optional, Tuple

    from torch import Tensor

    from mu_map.dataset.transform import Transform
    
    
    def norm_max(tensor: Tensor) -> Tensor:
        """Min-max normalize ``tensor`` into the range ``[0, 1]``.

        Despite the name, the tensor is first shifted by its minimum and then
        divided by its value range (max - min). The smallest element therefore
        maps to 0 and the largest to 1. Min and max are taken over the whole
        tensor, i.e. all dimensions together.

        A constant tensor (max == min) has a zero value range. Dividing would
        produce ``0 / 0 = NaN`` everywhere, so every element is mapped to 0
        instead.

        :param tensor: tensor to normalize; an empty tensor makes
            ``tensor.min()`` raise
        :return: the normalized tensor (true division, so integer inputs
            yield a floating-point result)
        """
        min_val = tensor.min()
        value_range = tensor.max() - min_val
        if value_range == 0:
            # Keep the division so the result dtype/device matches the normal
            # path, but divide by 1 so the all-zero numerator stays zero.
            value_range = value_range + 1
        return (tensor - min_val) / value_range
    
    
    class MaxNormTransform(Transform):
        """Normalize an ``(inputs, outputs_expected)`` pair by maximum values.

        Two modes are supported:

        * ``max_vals`` given: ``inputs`` are divided by ``max_vals[0]`` and
          ``outputs_expected`` by ``max_vals[1]``.
        * ``max_vals`` is ``None``: ``inputs`` are min-max normalized with
          ``norm_max`` and ``outputs_expected`` is returned unchanged.
        """

        def __init__(self, max_vals: Optional[Tuple[float, float]] = None):
            """
            :param max_vals: optional pair of divisors, the first applied to the
                inputs and the second to the expected outputs (presumably the
                dataset-wide maxima -- TODO confirm with callers); when ``None``
                the inputs are min-max normalized on every call instead
            """
            self.max_vals = max_vals

        def __call__(
            self, inputs: Tensor, outputs_expected: Tensor
        ) -> Tuple[Tensor, Tensor]:
            """Normalize a pair of tensors.

            :param inputs: input tensor
            :param outputs_expected: expected-output (target) tensor
            :return: the normalized ``(inputs, outputs_expected)`` pair
            """
            # Explicit ``is not None``: a truthiness test would raise for an
            # array/tensor-valued ``max_vals`` ("ambiguous truth value").
            if self.max_vals is not None:
                return inputs / self.max_vals[0], outputs_expected / self.max_vals[1]

            return norm_max(inputs), outputs_expected
    
    
    def norm_mean(tensor: Tensor):
        """Scale ``tensor`` so that its mean becomes 1.

        Every element is divided by the mean taken over the whole tensor.
        A tensor whose mean is zero yields ``inf``/``NaN`` values.

        :param tensor: floating-point tensor to normalize
        :return: the scaled tensor
        """
        mean_value = tensor.mean()
        return tensor / mean_value
    
    
    class MeanNormTransform(Transform):
        """Mean-normalize the inputs with ``norm_mean``; targets pass through."""

        def __call__(
            self, inputs: Tensor, outputs_expected: Tensor
        ) -> Tuple[Tensor, Tensor]:
            """Return ``(norm_mean(inputs), outputs_expected)``.

            :param inputs: input tensor, divided by its own mean
            :param outputs_expected: expected-output tensor, returned as-is
            :return: the transformed pair
            """
            scaled_inputs = norm_mean(inputs)
            return scaled_inputs, outputs_expected
    
    
    def norm_gaussian(tensor: Tensor) -> Tensor:
        """Standardize ``tensor`` to zero mean and unit standard deviation.

        Mean and standard deviation are computed over the whole tensor. The
        standard deviation uses ``Tensor.std`` defaults, i.e. the unbiased
        estimator.

        If the standard deviation is not positive, the mean-centered tensor is
        returned undivided. This covers a zero std (e.g. a constant tensor) and
        a NaN std (e.g. a single-element tensor under the unbiased estimator).
        In both cases dividing would fill the result with NaN.

        :param tensor: floating-point tensor to normalize
        :return: the standardized tensor
        """
        centered = tensor - tensor.mean()
        std = tensor.std()
        # ``not std > 0`` is True for both 0 and NaN.
        if not std > 0:
            return centered
        return centered / std
    
    
    class GaussianNormTransform(Transform):
        """Standardize the inputs with ``norm_gaussian``; targets pass through."""

        def __call__(
            self, inputs: Tensor, outputs_expected: Tensor
        ) -> Tuple[Tensor, Tensor]:
            """Return ``(norm_gaussian(inputs), outputs_expected)``.

            :param inputs: input tensor, standardized over all its elements
            :param outputs_expected: expected-output tensor, returned as-is
            :return: the transformed pair
            """
            standardized_inputs = norm_gaussian(inputs)
            return standardized_inputs, outputs_expected
    
    
    # Public API, derived from the objects themselves so a rename cannot leave
    # a stale string behind.
    __all__ = [
        exported.__name__
        for exported in (
            norm_max,
            norm_mean,
            norm_gaussian,
            MaxNormTransform,
            MeanNormTransform,
            GaussianNormTransform,
        )
    ]