diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py index 7ffa7ffaecac4d625a91e2625fded0e2aaa09566..9cfbfc6cbdf31c203ae30e595be956b9c41741c9 100644 --- a/mu_map/dataset/normalization.py +++ b/mu_map/dataset/normalization.py @@ -10,9 +10,14 @@ def norm_max(tensor: Tensor) -> Tensor: class MaxNormTransform(Transform): + def __init__(self, max_vals: Tuple[float, float] = None): + self.max_vals = max_vals + def __call__( self, inputs: Tensor, outputs_expected: Tensor ) -> Tuple[Tensor, Tensor]: + if self.max_vals: + return inputs / self.max_vals[0], outputs_expected / self.max_vals[1] return norm_max(inputs), outputs_expected