"""Tensor normalization functions and the dataset transforms that apply them."""
from typing import Tuple

from torch import Tensor

from mu_map.dataset.transform import Transform


def norm_max(tensor: Tensor) -> Tensor:
    """Rescale a tensor by its minimum and maximum (min-max normalization).

    The minimum and maximum are taken over all elements of the tensor.

    :param tensor: the tensor to normalize
    :return: ``(tensor - min) / (max - min)``
    """
    # NOTE: if all elements are equal, ``max - min`` is zero and the result
    # is not finite; callers must handle constant tensors themselves.
    return (tensor - tensor.min()) / (tensor.max() - tensor.min())


class MaxNormTransform(Transform):
    """Transform applying :func:`norm_max` to the inputs only."""

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Normalize ``inputs`` and pass ``outputs_expected`` through unchanged.

        :param inputs: the tensor to normalize
        :param outputs_expected: returned as-is
        :return: tuple of normalized inputs and the untouched expected outputs
        """
        return norm_max(inputs), outputs_expected


def norm_mean(tensor: Tensor) -> Tensor:
    """Divide a tensor by its mean.

    The mean is taken over all elements of the tensor.

    :param tensor: the tensor to normalize
    :return: ``tensor / mean``
    """
    # NOTE: a mean of zero leads to a division by zero (non-finite result).
    return tensor / tensor.mean()


class MeanNormTransform(Transform):
    """Transform applying :func:`norm_mean` to the inputs only."""

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Normalize ``inputs`` and pass ``outputs_expected`` through unchanged.

        :param inputs: the tensor to normalize
        :param outputs_expected: returned as-is
        :return: tuple of normalized inputs and the untouched expected outputs
        """
        return norm_mean(inputs), outputs_expected


def norm_gaussian(tensor: Tensor) -> Tensor:
    """Standardize a tensor by subtracting its mean and dividing by its std.

    Mean and standard deviation are taken over all elements of the tensor,
    using the defaults of ``Tensor.mean()`` and ``Tensor.std()``.

    :param tensor: the tensor to normalize
    :return: ``(tensor - mean) / std``
    """
    # NOTE: a standard deviation of zero (constant tensor) leads to a
    # division by zero (non-finite result).
    return (tensor - tensor.mean()) / tensor.std()


class GaussianNormTransform(Transform):
    """Transform applying :func:`norm_gaussian` to the inputs only."""

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Normalize ``inputs`` and pass ``outputs_expected`` through unchanged.

        :param inputs: the tensor to normalize
        :param outputs_expected: returned as-is
        :return: tuple of normalized inputs and the untouched expected outputs
        """
        return norm_gaussian(inputs), outputs_expected


__all__ = [
    "norm_max",
    "norm_mean",
    "norm_gaussian",
    "MaxNormTransform",
    "MeanNormTransform",
    "GaussianNormTransform",
]