"""Whole-tensor normalization helpers and their callable wrappers."""

import torch


def norm_max(tensor: torch.Tensor) -> torch.Tensor:
    """Rescale *tensor* by its global minimum and maximum.

    Computes ``(tensor - min) / (max - min)`` where ``min`` and ``max`` are
    taken over every element of the tensor.

    Note:
        When all elements are equal the denominator is zero, so the result
        is not finite (``0 / 0`` yields NaN).
    """
    lowest = tensor.min()
    highest = tensor.max()
    return (tensor - lowest) / (highest - lowest)


class MaxNorm:
    """Callable object that applies :func:`norm_max` to its argument."""

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``norm_max(tensor)``."""
        return norm_max(tensor)


def norm_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Divide *tensor* by the mean of all its elements.

    Note:
        A mean of zero makes the result non-finite (inf or NaN).
    """
    return tensor / tensor.mean()


class MeanNorm:
    """Callable object that applies :func:`norm_mean` to its argument."""

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``norm_mean(tensor)``."""
        return norm_mean(tensor)


def norm_gaussian(tensor: torch.Tensor) -> torch.Tensor:
    """Standardize *tensor* using its global mean and standard deviation.

    Computes ``(tensor - mean) / std`` where both statistics are taken over
    every element, using ``Tensor.std()`` with its default arguments.

    Note:
        A standard deviation of zero (e.g. constant input) makes the result
        non-finite.
    """
    centered = tensor - tensor.mean()
    return centered / tensor.std()


class GaussianNorm:
    """Callable object that applies :func:`norm_gaussian` to its argument."""

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        """Return ``norm_gaussian(tensor)``."""
        return norm_gaussian(tensor)


# Literal names (rather than ``obj.__name__`` lookups) keep the public API
# readable by static tooling; the resulting list is the same.
__all__ = [
    "norm_max",
    "norm_mean",
    "norm_gaussian",
    "MaxNorm",
    "MeanNorm",
    "GaussianNorm",
]