"""
Module containing normalization methods either as functions or transforms.
"""
from typing import Any, Callable, Optional, Tuple

from torch import Tensor

from mu_map.dataset.transform import Transform


def norm_max(tensor: Tensor) -> Tensor:
    """
    Perform maximum normalization on a tensor.

    This means that the tensor is linearly normalized into the
    value range [0, 1].

    Note that for a constant tensor the value range (max - min) is zero,
    so the division yields NaN values.

    Parameters
    ----------
    tensor: Tensor
        the tensor to normalize

    Returns
    -------
    Tensor
        the tensor min-max scaled into [0, 1]
    """
    return (tensor - tensor.min()) / (tensor.max() - tensor.min())


def norm_mean(tensor: Tensor) -> Tensor:
    """
    Perform mean normalization on a tensor.

    This means that the tensor is divided by its mean.

    Note that if the mean of the tensor is zero, the division produces
    infinite or NaN values.

    Parameters
    ----------
    tensor: Tensor
        the tensor to normalize

    Returns
    -------
    Tensor
        the tensor divided by its mean
    """
    return tensor / tensor.mean()


def norm_gaussian(tensor: Tensor) -> Tensor:
    """
    Perform Gaussian normalization on a tensor.

    This means the tensor is normalized in a way that its values are
    distributed like a normal distribution with mean 0 and standard
    deviation 1.

    The standard deviation is computed with `Tensor.std()` using its
    default settings. Note that for a constant tensor the standard
    deviation is zero, so the division yields NaN values.

    Parameters
    ----------
    tensor: Tensor
        the tensor to normalize

    Returns
    -------
    Tensor
        the tensor shifted by its mean and scaled by its standard deviation
    """
    return (tensor - tensor.mean()) / tensor.std()


class NormTransform(Transform):
    """
    Abstract class for all normalization transformers.

    Note that a normalization is only applied to the first tensor
    input into the __call__ function. This is because usually only
    the input of a neural network is normalized while the output
    stays as it is. But both are returned by a dataset and other
    transformers such as cropping or padding need to be applied.
    """

    def __init__(self, norm_func: Callable[[Tensor], Tensor]):
        """
        Create a new normalization transformer.

        Parameters
        ----------
        norm_func: Callable[[Tensor], Tensor]
            the normalization function applied to the first tensor
        """
        super().__init__()
        self.norm_func = norm_func

    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
        """
        Normalize the first input tensor. All others remain as they are.

        Parameters
        ----------
        *tensors: Tensor
            the tensors to transform; at least one tensor is required
            (calling without arguments raises an IndexError)

        Returns
        -------
        Tuple[Tensor, ...]
            the normalized first tensor followed by all remaining
            tensors unchanged
        """
        return (self.norm_func(tensors[0]), *tensors[1:])

    def __eq__(self, other: Any) -> bool:
        """
        Implementation of the comparison operator.

        This implementation just checks that self and other are of the
        same class.

        Parameters
        ----------
        other: Any

        Returns
        -------
        bool
        """
        return self.__class__ == other.__class__

    def __hash__(self) -> int:
        """
        Hash consistent with `__eq__`: instances of the same class hash equal.

        Defining `__eq__` alone implicitly sets `__hash__` to None, which
        would make normalization transforms unusable as set members or
        dictionary keys.

        Returns
        -------
        int
        """
        return hash(self.__class__)


class MaxNormTransform(NormTransform):
    """
    Maximum normalization as a transformer.
    """

    def __init__(self):
        super().__init__(norm_max)


class MeanNormTransform(NormTransform):
    """
    Mean normalization as a transformer.
    """

    def __init__(self):
        super().__init__(norm_mean)


class GaussianNormTransform(NormTransform):
    """
    Gaussian normalization as a transformer.
    """

    def __init__(self):
        super().__init__(norm_gaussian)


# Strings defining all normalization methods which can be used to
# initialize choices for CLIs.
norm_choices = ["max", "mean", "gaussian"]


def norm_by_str(norm: Optional[str]) -> Optional[NormTransform]:
    """
    Get a normalization transformer by a string.

    This is useful for command line interfaces. The string is matched
    case-insensitively and by substring, checking "mean", then "max",
    then "gaussian" (e.g. "MaxNorm" selects the maximum normalization).

    Parameters
    ----------
    norm: str, optional
        a string defining the normalization transformer (see `norm_choices`)

    Returns
    -------
    NormTransform, optional
        a normalization transform or None if the input is None

    Raises
    ------
    ValueError
        if the string does not match any known normalization
    """
    if norm is None:
        return None

    norm = norm.lower()
    if "mean" in norm:
        return MeanNormTransform()
    elif "max" in norm:
        return MaxNormTransform()
    elif "gaussian" in norm:
        return GaussianNormTransform()

    raise ValueError(f"Unknown normalization {norm}")


__all__ = [
    norm_max.__name__,
    norm_mean.__name__,
    norm_gaussian.__name__,
    NormTransform.__name__,
    MaxNormTransform.__name__,
    MeanNormTransform.__name__,
    GaussianNormTransform.__name__,
    "norm_choices",
    norm_by_str.__name__,
]