"""
Module containing normalization methods either as functions
or transformers.
"""

from typing import Any, Callable, Optional, Tuple

from torch import Tensor

from mu_map.dataset.transform import Transform


def norm_max(tensor: Tensor) -> Tensor:
    """
    Perform maximum normalization on a tensor.
    This means that the tensor is linearly normalized into the
    value range [0, 1].
    """
    return (tensor - tensor.min()) / (tensor.max() - tensor.min())
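# Illustrative example (comment only, not original code; assumes `import torch`):
# norm_max maps the minimum of a tensor to 0 and the maximum to 1, e.g.
#   norm_max(torch.tensor([2.0, 4.0, 6.0])) -> tensor([0.0, 0.5, 1.0])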


def norm_mean(tensor: Tensor) -> Tensor:
    """
    Perform mean normalization on a tensor.
    This means that the tensor is divided by its mean.
    """
    return tensor / tensor.mean()
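# Illustrative example (comment only, not original code): after norm_mean the
# tensor has mean 1, e.g.
#   norm_mean(torch.tensor([1.0, 3.0])) -> tensor([0.5, 1.5])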


def norm_gaussian(tensor: Tensor) -> Tensor:
    """
    Perform Gaussian normalization (standardization) on a tensor.

    This means the tensor is shifted and scaled so that its values
    have mean 0 and standard deviation 1 (a z-score transform).
    """
    return (tensor - tensor.mean()) / tensor.std()
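# Illustrative example (comment only, not original code): with torch's default
# unbiased standard deviation,
#   norm_gaussian(torch.tensor([1.0, 2.0, 3.0])) -> tensor([-1.0, 0.0, 1.0])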


class NormTransform(Transform):
    """
    Abstract class for all normalization transformers.

    Note that a normalization is only applied to the first tensor
    passed into the __call__ function. This is because usually only
    the input of a neural network is normalized while the output
    stays as it is, yet both are returned by a dataset, and other
    transformers, such as cropping or padding, need to be applied
    to both.
    """

    def __init__(self, norm_func: Callable[[Tensor], Tensor]):
        """
        Create a new normalization transformer.

        Parameters
        ----------
        norm_func: Callable[[Tensor], Tensor]
            the normalization function to apply
        """
        super().__init__()
        self.norm_func = norm_func

    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
        """
        Normalize the first input tensor. All others remain as
        they are.
        """
        return (self.norm_func(tensors[0]), *tensors[1:])

    def __eq__(self, other: Any) -> bool:
        """
        Implementation of the equality operator.
        This implementation just checks that self and other are of
        the same class.

        Parameters
        ----------
        other: Any
            the object to compare with

        Returns
        -------
        bool
            true if other is of the same class as self
        """
        return self.__class__ == other.__class__
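# Illustrative sketch (not part of the original module): a transform only
# normalizes the first tensor, so for an (input, target) pair from a dataset
# (assumed names) one would use it like
#
#   transform = MaxNormTransform()
#   input_norm, target = transform(input_tensor, target_tensor)
#
# where input_norm is scaled into [0, 1] while target is returned unchanged.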


class MaxNormTransform(NormTransform):
    """
    Maximum normalization as a transformer.
    """

    def __init__(self):
        super().__init__(norm_max)


class MeanNormTransform(NormTransform):
    """
    Mean normalization as a transformer.
    """

    def __init__(self):
        super().__init__(norm_mean)


class GaussianNormTransform(NormTransform):
    """
    Gaussian normalization as a transformer.
    """

    def __init__(self):
        super().__init__(norm_gaussian)
"""
Strings defining all normalization methods which can be used
to initialize choices for CLIs.
"""
norm_choices = ["max", "mean", "gaussian"]


def norm_by_str(norm: Optional[str]) -> Optional[NormTransform]:
    """
    Get a normalization transformer by a string.
    This is useful for command line interfaces.

    Parameters
    ----------
    norm: str, optional
        a string defining the normalization transformer (see `norm_choices`)

    Returns
    -------
    Optional[NormTransform]
        a normalization transformer or None if the input is None
    """
    if norm is None:
        return None

    if "mean" in norm:
        return MeanNormTransform()
    elif "max" in norm:
        return MaxNormTransform()
    elif "gaussian" in norm:
        return GaussianNormTransform()
    raise ValueError(f"Unknown normalization {norm}")
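# Illustrative example (comment only, not original code): norm_by_str maps CLI
# strings to transformers, e.g. norm_by_str("max") -> MaxNormTransform() and
# norm_by_str(None) -> None; unknown strings raise a ValueError.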


__all__ = [
    norm_max.__name__,
    norm_mean.__name__,
    norm_gaussian.__name__,
    MaxNormTransform.__name__,
    MeanNormTransform.__name__,
    GaussianNormTransform.__name__,
]
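

# Minimal runnable sketch (not part of the original module); the tensor
# values are arbitrary examples chosen for demonstration.
if __name__ == "__main__":
    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    y = torch.tensor([10.0, 20.0, 30.0, 40.0])

    print(norm_max(x))       # linearly scaled into [0, 1]
    print(norm_mean(x))      # divided by the mean, so the result has mean 1
    print(norm_gaussian(x))  # shifted/scaled to mean 0, standard deviation 1

    # Transformers normalize only the first tensor; y passes through unchanged.
    transform = norm_by_str("max")
    x_norm, y_same = transform(x, y)
    print(x_norm, y_same)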