From 63680f4b0edf22721ce6b6a7119a3297782ff8c9 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 27 Sep 2022 13:34:58 +0200 Subject: [PATCH] max normalization allows to define norm values --- mu_map/dataset/normalization.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py index 7ffa7ff..9cfbfc6 100644 --- a/mu_map/dataset/normalization.py +++ b/mu_map/dataset/normalization.py @@ -10,9 +10,14 @@ def norm_max(tensor: Tensor) -> Tensor: class MaxNormTransform(Transform): + def __init__(self, max_vals: Tuple[float, float] = None): + self.max_vals = max_vals + def __call__( self, inputs: Tensor, outputs_expected: Tensor ) -> Tuple[Tensor, Tensor]: + if self.max_vals: + return inputs / self.max_vals[0], outputs_expected / self.max_vals[1] return norm_max(inputs), outputs_expected -- GitLab