From 63680f4b0edf22721ce6b6a7119a3297782ff8c9 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 27 Sep 2022 13:34:58 +0200
Subject: [PATCH] max normalization allows to define norm values

---
 mu_map/dataset/normalization.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py
index 7ffa7ff..9cfbfc6 100644
--- a/mu_map/dataset/normalization.py
+++ b/mu_map/dataset/normalization.py
@@ -10,9 +10,14 @@ def norm_max(tensor: Tensor) -> Tensor:
 
 
 class MaxNormTransform(Transform):
+    def __init__(self, max_vals: Tuple[float, float] = None):
+        self.max_vals = max_vals
+
     def __call__(
         self, inputs: Tensor, outputs_expected: Tensor
     ) -> Tuple[Tensor, Tensor]:
+        if self.max_vals:
+            return inputs / self.max_vals[0], outputs_expected / self.max_vals[1]
         return norm_max(inputs), outputs_expected
 
 
-- 
GitLab