Skip to content
Snippets Groups Projects
normalization.py 3.8 KiB
Newer Older
  • Learn to ignore specific revisions
    """
    Module containing normalization methods, either as functions
    or as transforms.
    """
    
    from typing import Any, Callable, Optional, Tuple
    
    from torch import Tensor
    
    from mu_map.dataset.transform import Transform
    
    
    def norm_max(tensor: Tensor) -> Tensor:
        """
        Perform min-max normalization on a tensor.

        The tensor is shifted by its minimum and divided by its value
        range (maximum minus minimum), so that its values are linearly
        mapped into the range [0, 1].

        If all values are equal (value range of zero), the division would
        be 0 / 0 and produce NaNs; a tensor of zeros is returned instead.

        Parameters
        ----------
        tensor: Tensor
            the tensor to normalize

        Returns
        -------
        Tensor
            the normalized tensor
        """
        shifted = tensor - tensor.min()
        # max of the shifted tensor equals max - min of the input
        value_range = shifted.max()
        if value_range == 0:
            # constant tensor: divide by one instead of zero so the result
            # is all zeros (with the same dtype promotion as the normal path)
            value_range = 1
        return shifted / value_range
    
    
    
    def norm_mean(tensor: Tensor) -> Tensor:
        """
        Perform mean normalization on a tensor.

        This means that the tensor is divided by its mean, so that (for a
        non-zero mean) the result has a mean of 1. A tensor with a mean of
        zero yields infinite or NaN values.

        Parameters
        ----------
        tensor: Tensor
            the tensor to normalize

        Returns
        -------
        Tensor
            the normalized tensor
        """
        return tensor / tensor.mean()
    
    def norm_gaussian(tensor: Tensor) -> Tensor:
        """
        Perform Gaussian (z-score) normalization on a tensor.

        The tensor is shifted by its mean and divided by its standard
        deviation (torch's default ``std``, i.e. the sample standard
        deviation), so that the result has mean 0 and standard deviation 1.
        Note that this only standardizes the values; it does not change
        the shape of their distribution.

        If the standard deviation is zero (constant tensor), the division
        would be 0 / 0 and produce NaNs; the mean-centered tensor (all
        zeros) is returned instead.

        Parameters
        ----------
        tensor: Tensor
            the tensor to normalize

        Returns
        -------
        Tensor
            the normalized tensor
        """
        centered = tensor - tensor.mean()
        std = tensor.std()
        if std == 0:
            # constant tensor: avoid 0 / 0 = NaN
            return centered
        return centered / std
    
    class NormTransform(Transform):
        """
        Base class for all normalization transformers.

        Note that a normalization is only applied to the first tensor
        passed to the __call__ function; all further tensors are returned
        unchanged. This is because usually only the input of a neural
        network is normalized while the output stays as it is. A dataset
        nevertheless returns both, since other transformers such as
        cropping or padding need to be applied to all of them.
        """

        def __init__(self, norm_func: Callable[[Tensor], Tensor]):
            """
            Create a new normalization transformer.

            Parameters
            ----------
            norm_func: Callable[[Tensor], Tensor]
                the normalization function applied to the first tensor
            """
            super().__init__()
            self.norm_func = norm_func

        def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
            """
            Normalize the first input tensor. All others remain as they
            are.

            Parameters
            ----------
            *tensors: Tensor
                the tensors to transform; at least one is expected
                (indexing the first one raises IndexError otherwise)

            Returns
            -------
            Tuple[Tensor, ...]
                the normalized first tensor followed by the remaining
                tensors unchanged
            """
            return (self.norm_func(tensors[0]), *tensors[1:])

        def __eq__(self, other: Any) -> bool:
            """
            Implementation of the comparison operator.

            Two normalization transformers are equal if they are of the
            same class and apply the same normalization function. For the
            concrete subclasses the function is fixed by the class, so this
            reduces to a class comparison; for direct instances of this
            class it prevents transformers with different functions from
            comparing equal.

            Parameters
            ----------
            other: Any

            Returns
            -------
            bool
            """
            return self.__class__ == other.__class__ and getattr(
                self, "norm_func", None
            ) == getattr(other, "norm_func", None)

        def __hash__(self) -> int:
            """
            Hash consistent with __eq__, so that transformers can be used
            in sets and as dictionary keys.

            Returns
            -------
            int
            """
            return hash((self.__class__, getattr(self, "norm_func", None)))
    
    
    class MaxNormTransform(NormTransform):
        """
        Transformer that applies ``norm_max`` to the first tensor it
        receives.
        """

        def __init__(self):
            # the normalization function is fixed for this transformer
            super(MaxNormTransform, self).__init__(norm_func=norm_max)
    
    
    class MeanNormTransform(NormTransform):
        """
        Transformer that applies ``norm_mean`` to the first tensor it
        receives.
        """

        def __init__(self):
            # the normalization function is fixed for this transformer
            super(MeanNormTransform, self).__init__(norm_func=norm_mean)
    
    
    class GaussianNormTransform(NormTransform):
        """
        Transformer that applies ``norm_gaussian`` to the first tensor it
        receives.
        """

        def __init__(self):
            # the normalization function is fixed for this transformer
            super(GaussianNormTransform, self).__init__(norm_func=norm_gaussian)
    
    
    #: Strings defining all normalization methods which can be used
    #: to initialize choices for CLIs (for use with ``norm_by_str``).
    norm_choices = ["max", "mean", "gaussian"]
    
    
    
    def norm_by_str(norm: Optional[str]) -> Optional[NormTransform]:
        """
        Get a normalization transformer by a string.

        This is useful for command line interfaces. The lookup is
        case-insensitive.

        Parameters
        ----------
        norm: str, optional
            a string defining the normalization transformer (see
            `norm_choices`): "max", "mean" or "gaussian"

        Returns
        -------
        NormTransform, optional
            a new normalization transform or None if the input is None

        Raises
        ------
        ValueError
            if the string does not name a known normalization
        """
        if norm is None:
            return None

        # accepted (lower-case) names mapped to their transformer classes;
        # a fresh instance is created on every call
        transforms = {
            "max": MaxNormTransform,
            "mean": MeanNormTransform,
            "gaussian": GaussianNormTransform,
        }

        norm = norm.lower()
        if norm not in transforms:
            raise ValueError(f"Unknown normalization {norm}")
        return transforms[norm]()
    
    
    
    # Public API of this module (the names exported by ``import *``).
    __all__ = [
        norm_max.__name__,
        norm_mean.__name__,
        norm_gaussian.__name__,
        NormTransform.__name__,
        MaxNormTransform.__name__,
        MeanNormTransform.__name__,
        GaussianNormTransform.__name__,
        "norm_choices",
        norm_by_str.__name__,
    ]