diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py
index 7ffa7ffaecac4d625a91e2625fded0e2aaa09566..9cfbfc6cbdf31c203ae30e595be956b9c41741c9 100644
--- a/mu_map/dataset/normalization.py
+++ b/mu_map/dataset/normalization.py
@@ -10,9 +10,14 @@ def norm_max(tensor: Tensor) -> Tensor:
 
 
 class MaxNormTransform(Transform):
+    def __init__(self, max_vals: Tuple[float, float] = None):
+        self.max_vals = max_vals
+
     def __call__(
         self, inputs: Tensor, outputs_expected: Tensor
     ) -> Tuple[Tensor, Tensor]:
+        if self.max_vals:
+            return inputs / self.max_vals[0], outputs_expected / self.max_vals[1]
         return norm_max(inputs), outputs_expected