"""
Module containing normalization methods either as functions
or transformers.
"""
from typing import Any, Callable, Optional, Tuple

from torch import Tensor
from mu_map.dataset.transform import Transform
def norm_max(tensor: Tensor) -> Tensor:
    """
    Perform maximum normalization on a tensor.

    This means that the tensor is linearly normalized into the
    value range [0, 1].

    Parameters
    ----------
    tensor: Tensor
        the tensor to normalize

    Returns
    -------
    Tensor
        ``(tensor - min) / (max - min)``; a constant tensor (zero value
        range) is mapped to all zeros instead of NaN
    """
    min_value = tensor.min()
    value_range = tensor.max() - min_value
    if value_range == 0:
        # A constant tensor has no spread: dividing by zero would give
        # 0 / 0 = NaN everywhere, so divide by 1 to obtain zeros instead.
        value_range = 1
    return (tensor - min_value) / value_range


def norm_mean(tensor: Tensor) -> Tensor:
    """
    Perform mean normalization on a tensor.

    This means that the tensor is divided by its mean.

    Parameters
    ----------
    tensor: Tensor
        the tensor to normalize

    Returns
    -------
    Tensor
        ``tensor / tensor.mean()``; no guard is applied, so a tensor
        with a mean of zero yields non-finite values (inf/NaN)
    """
    return tensor / tensor.mean()


def norm_gaussian(tensor: Tensor) -> Tensor:
    """
    Perform Gaussian normalization on a tensor.

    This means the tensor is normalized in a way that its values
    are distributed like a normal distribution with mean 0 and
    standard deviation 1.

    Parameters
    ----------
    tensor: Tensor
        the tensor to normalize

    Returns
    -------
    Tensor
        ``(tensor - mean) / std``; a constant tensor (zero standard
        deviation) is mapped to all zeros instead of NaN. Note that
        ``Tensor.std`` uses the unbiased estimator by default, so a
        single-element tensor still yields NaN.
    """
    std = tensor.std()
    if std == 0:
        # Zero spread would give 0 / 0 = NaN everywhere; dividing by 1
        # leaves the (all zero) centered tensor instead.
        std = 1
    return (tensor - tensor.mean()) / std


class NormTransform(Transform):
    """
    Abstract class for all normalization transformers.

    Note that a normalization is only applied to the first tensor
    input into the __call__ function. This is because usually only
    the input of a neural network is normalized while the output
    stays as it is. But both are returned by a dataset and other
    transformers such as cropping or padding need to be applied.
    """

    def __init__(self, norm_func: Callable[[Tensor], Tensor]):
        """
        Create a new normalization transformer.

        Parameters
        ----------
        norm_func: Callable[[Tensor], Tensor]
            the normalization function applied to the first tensor
        """
        super().__init__()
        self.norm_func = norm_func

    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
        """
        Normalize the first input tensor. All others remain as they
        are.

        Parameters
        ----------
        tensors: Tensor
            the tensors to transform; only ``tensors[0]`` is normalized

        Returns
        -------
        Tuple[Tensor, ...]
            the normalized first tensor followed by all remaining
            tensors unchanged

        Raises
        ------
        IndexError
            if called without any tensor
        """
        return (self.norm_func(tensors[0]), *tensors[1:])

    def __eq__(self, other: Any) -> bool:
        """
        Implementation of the comparison operator.

        This implementation just checks that self and other are of
        the same class.

        Parameters
        ----------
        other: Any
            the object to compare with

        Returns
        -------
        bool
            True if ``other`` is an instance of exactly the same class
        """
        return self.__class__ == other.__class__

    def __hash__(self) -> int:
        """
        Hash consistently with ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to None and makes
        instances unhashable. Since equality depends only on the class,
        hashing the class keeps equal objects at equal hashes.

        Returns
        -------
        int
            the hash of this instance's class
        """
        return hash(self.__class__)


class MaxNormTransform(NormTransform):
    """
    Transformer applying maximum normalization (see `norm_max`)
    to the first tensor it receives.
    """

    def __init__(self):
        # Bind the maximum normalization function to the generic
        # normalization transformer.
        super(MaxNormTransform, self).__init__(norm_func=norm_max)


class MeanNormTransform(NormTransform):
    """
    Transformer applying mean normalization (see `norm_mean`)
    to the first tensor it receives.
    """

    def __init__(self):
        # Bind the mean normalization function to the generic
        # normalization transformer.
        super(MeanNormTransform, self).__init__(norm_func=norm_mean)


class GaussianNormTransform(NormTransform):
    """
    Transformer applying Gaussian normalization (see `norm_gaussian`)
    to the first tensor it receives.
    """

    def __init__(self):
        # Bind the Gaussian normalization function to the generic
        # normalization transformer.
        super(GaussianNormTransform, self).__init__(norm_func=norm_gaussian)


"""
Strings defining all normalization methods which can be used
to initialize choices for CLIs.
"""
norm_choices = ["max", "mean", "gaussian"]
def norm_by_str(norm: Optional[str]) -> Optional[NormTransform]:
    """
    Look up a normalization transformer by name.

    Intended for command line interfaces, where the name is one of
    `norm_choices`. Matching is done by substring containment, checked
    in the order "mean", "max", "gaussian"; the first hit wins.

    Parameters
    ----------
    norm: str, optional
        the name of the normalization method

    Returns
    -------
    NormTransform, optional
        a new instance of the matching transformer, or None if
        ``norm`` is None

    Raises
    ------
    ValueError
        if ``norm`` contains none of the known method names
    """
    if norm is None:
        return None

    # Ordered lookup table: the order mirrors the precedence of the
    # substring checks.
    lookup = (
        ("mean", MeanNormTransform),
        ("max", MaxNormTransform),
        ("gaussian", GaussianNormTransform),
    )
    for keyword, transform_cls in lookup:
        if keyword in norm:
            return transform_cls()

    raise ValueError(f"Unknown normalization {norm}")


# Public API of this module. Names are taken from the objects themselves
# where possible so that renames stay in sync; `norm_choices` is a plain
# list without a `__name__`, hence the string literal.
__all__ = [
    norm_max.__name__,
    norm_mean.__name__,
    norm_gaussian.__name__,
    NormTransform.__name__,
    MaxNormTransform.__name__,
    MeanNormTransform.__name__,
    GaussianNormTransform.__name__,
    "norm_choices",
    norm_by_str.__name__,
]