"""Normalization functions and transforms for ``(inputs, outputs_expected)`` tensor pairs."""
from typing import Optional, Tuple

import torch
from torch import Tensor

from mu_map.dataset.transform import Transform


def norm_max(tensor: Tensor) -> Tensor:
    """Min-max normalize a tensor into the range [0, 1].

    Despite the name, this both shifts by the minimum and scales by the value
    range, i.e. it computes ``(tensor - min) / (max - min)``.

    A constant tensor (``max == min``) is mapped to all zeros instead of the
    ``0 / 0 = NaN`` the plain formula would produce.

    :param tensor: tensor to normalize (``min``/``max`` require it to be non-empty)
    :return: the normalized tensor
    """
    shifted = tensor - tensor.min()
    # max(tensor - min) == max - min, so this is the value range.
    value_range = shifted.max()
    if value_range == 0:
        # Every element of ``shifted`` is zero here; divide by one so the
        # result is zeros (with the usual true-division dtype) rather than NaN.
        value_range = torch.ones_like(value_range)
    return shifted / value_range


class MaxNormTransform(Transform):
    """Normalize the inputs (and optionally the expected outputs) of a sample.

    * If ``max_vals`` is given, ``inputs`` are divided by ``max_vals[0]`` and
      ``outputs_expected`` by ``max_vals[1]`` (no shift by the minimum).
    * Otherwise, only ``inputs`` are min-max normalized with :func:`norm_max`
      and ``outputs_expected`` are returned unchanged.
    """

    def __init__(self, max_vals: Optional[Tuple[float, float]] = None):
        """
        :param max_vals: optional fixed ``(inputs_divisor, outputs_divisor)``;
            when ``None``, per-sample min-max normalization of the inputs is used
        """
        self.max_vals = max_vals

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Apply the normalization to one ``(inputs, outputs_expected)`` pair."""
        if self.max_vals is not None:
            return inputs / self.max_vals[0], outputs_expected / self.max_vals[1]
        return norm_max(inputs), outputs_expected


def norm_mean(tensor: Tensor) -> Tensor:
    """Divide a tensor by its mean.

    NOTE(review): a zero mean yields ``inf``/``NaN``; callers presumably pass
    data with a non-zero mean — confirm.

    :param tensor: tensor to normalize
    :return: ``tensor / tensor.mean()``
    """
    return tensor / tensor.mean()


class MeanNormTransform(Transform):
    """Normalize ``inputs`` with :func:`norm_mean`; ``outputs_expected`` pass through."""

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Apply the normalization to one ``(inputs, outputs_expected)`` pair."""
        return norm_mean(inputs), outputs_expected


def norm_gaussian(tensor: Tensor) -> Tensor:
    """Standardize a tensor to zero mean and unit standard deviation.

    Computes ``(tensor - mean) / std`` using ``Tensor.std`` defaults. A tensor
    whose standard deviation is exactly zero (constant tensor) is mapped to
    zeros instead of ``0 / 0 = NaN``.

    NOTE(review): with ``Tensor.std``'s default correction, a single-element
    tensor has an undefined (NaN) std, so the result is still NaN there.

    :param tensor: tensor to normalize
    :return: the standardized tensor
    """
    centered = tensor - tensor.mean()
    std = tensor.std()
    if std == 0:
        # Constant tensor: ``centered`` is all zeros, avoid 0 / 0.
        std = torch.ones_like(std)
    return centered / std


class GaussianNormTransform(Transform):
    """Normalize ``inputs`` with :func:`norm_gaussian`; ``outputs_expected`` pass through."""

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Apply the normalization to one ``(inputs, outputs_expected)`` pair."""
        return norm_gaussian(inputs), outputs_expected


# Single source of truth for the names accepted by ``norm_by_str``; keeps
# ``norm_choices`` and the lookup from drifting apart.
_NORM_TRANSFORMS = {
    "max": MaxNormTransform,
    "mean": MeanNormTransform,
    "gaussian": GaussianNormTransform,
}

# Valid values for ``norm_by_str`` (besides ``None``), in declaration order.
norm_choices = list(_NORM_TRANSFORMS)


def norm_by_str(norm: Optional[str]) -> Optional[Transform]:
    """Create a normalization transform from its name.

    :param norm: one of ``norm_choices`` or ``None``
    :return: a new transform instance, or ``None`` if ``norm`` is ``None``
    :raises ValueError: if ``norm`` is not a known normalization name
    """
    if norm is None:
        return None
    transform_cls = _NORM_TRANSFORMS.get(norm)
    if transform_cls is None:
        raise ValueError(f"Unknown normalization {norm}")
    return transform_cls()


__all__ = [
    norm_max.__name__,
    norm_mean.__name__,
    norm_gaussian.__name__,
    MaxNormTransform.__name__,
    MeanNormTransform.__name__,
    GaussianNormTransform.__name__,
]