Skip to content
Snippets Groups Projects
Commit 56b9f617 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

update normalization code

parent 9c8ba6bb
No related branches found
No related tags found
No related merge requests found
from typing import Tuple
"""
Module containing normalization methods either as functions
or transformers.
"""
from typing import Callable, Optional, Tuple
from torch import Tensor
......@@ -6,50 +10,120 @@ from mu_map.dataset.transform import Transform
def norm_max(tensor: Tensor) -> Tensor:
    """
    Perform maximum normalization on a tensor.

    This means that the tensor is linearly normalized into the
    value range [0, 1].

    Parameters
    ----------
    tensor: Tensor
        the tensor to be normalized

    Returns
    -------
    Tensor
        the normalized tensor; if the tensor is constant
        (max == min) a tensor of zeros is returned instead of
        dividing by zero and producing NaNs
    """
    t_min, t_max = tensor.min(), tensor.max()
    # guard against division by zero for constant tensors
    if t_max == t_min:
        return tensor - t_min
    return (tensor - t_min) / (t_max - t_min)
class MaxNormTransform(Transform):
    """
    Transformer scaling inputs and expected outputs by fixed maximum
    values, or applying maximum normalization to the inputs when no
    values are given.

    NOTE(review): this class is superseded by the `NormTransform`-based
    `MaxNormTransform` defined later in this module, which shadows it
    at import time.
    """

    def __init__(self, max_vals: Optional[Tuple[float, float]] = None):
        """
        Parameters
        ----------
        max_vals: tuple of (float, float), optional
            fixed maxima used to scale inputs and expected outputs;
            if omitted, `norm_max` is applied to the inputs instead
        """
        self.max_vals = max_vals

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        if self.max_vals:
            return inputs / self.max_vals[0], outputs_expected / self.max_vals[1]
        return norm_max(inputs), outputs_expected
def norm_mean(tensor: Tensor) -> Tensor:
    """
    Perform mean normalization on a tensor.

    This means that the tensor is divided by its mean.

    Parameters
    ----------
    tensor: Tensor
        the tensor to be normalized

    Returns
    -------
    Tensor
        the tensor scaled so that its mean is 1
    """
    return tensor / tensor.mean()
def norm_gaussian(tensor: Tensor):
    """
    Perform Gaussian normalization on a tensor.

    The tensor is shifted and scaled so that its values are
    distributed like a normal distribution with mean 0 and
    standard deviation 1.
    """
    centered = tensor - tensor.mean()
    return centered / tensor.std()
class MeanNormTransform(Transform):
    """
    Transformer applying mean normalization (see `norm_mean`) to the
    network inputs while leaving the expected outputs untouched.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        normed = norm_mean(inputs)
        return normed, outputs_expected
class NormTransform(Transform):
    """
    Abstract class for all normalization transformers.

    Note that a normalization is only applied to the first tensor
    input into the __call__ function. This is because usually only
    the input of a neural network is normalized while the output
    stays as it is. But both are returned by a dataset and other
    transformers such as cropping or padding need to be applied.
    """

    def __init__(self, norm_func: Callable[[Tensor], Tensor]):
        """
        Create a new normalization transformer.

        Parameters
        ----------
        norm_func: Callable[[Tensor], Tensor]
            the normalization function applied
        """
        super().__init__()
        self.norm_func = norm_func

    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
        """
        Normalize the first input tensor. All others remain as they
        are.
        """
        return (self.norm_func(tensors[0]), *tensors[1:])
class GaussianNormTransform(Transform):
    """
    Transformer applying Gaussian normalization (see `norm_gaussian`)
    to the network inputs while leaving the expected outputs as-is.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        normed = norm_gaussian(inputs)
        return normed, outputs_expected
class MaxNormTransform(NormTransform):
    """
    Maximum normalization as a transformer.

    Wraps `norm_max` in the generic `NormTransform` interface.
    """

    def __init__(self):
        """Create a transformer that applies `norm_max`."""
        super().__init__(norm_max)
class MeanNormTransform(NormTransform):
    """
    Mean normalization as a transformer.

    Wraps `norm_mean` in the generic `NormTransform` interface.
    """

    def __init__(self):
        """Create a transformer that applies `norm_mean`."""
        super().__init__(norm_mean)
class GaussianNormTransform(NormTransform):
    """
    Gaussian normalization as a transformer.

    Wraps `norm_gaussian` in the generic `NormTransform` interface.
    """

    def __init__(self):
        """Create a transformer that applies `norm_gaussian`."""
        super().__init__(norm_gaussian)
"""
Strings defining all normalization methods which can be used
to initialize choices for CLIs.
"""
norm_choices = ["max", "mean", "gaussian"]
def norm_by_str(norm: str):
def norm_by_str(norm: Optional[str]) -> Optional[NormTransform]:
"""
Get a normalization transformer by a string.
This is useful for command line interfaces.
Parameters
----------
norm: str, optional
a string defining the normalization transformer (see `norm_choices`)
Returns
-------
NormTransform
a normalization transform or none if the input is none
"""
if norm is None:
return None
norm = norm.lower()
if norm == "mean":
return MeanNormTransform()
elif norm == "max":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment