From e95e5b6f7c7182de13d88f75f74199445ac808b0 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 4 Oct 2022 09:44:03 +0200
Subject: [PATCH] implement a scale tranform

---
 mu_map/dataset/transform.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py
index 7b2685e..a991398 100644
--- a/mu_map/dataset/transform.py
+++ b/mu_map/dataset/transform.py
@@ -34,3 +34,27 @@ class SequenceTransform(Transform):
         for transforms in self.transforms:
             inputs, outputs_expected = transforms(inputs, outputs_expected)
         return inputs, outputs_expected
+
+
+class ScaleTransform(Transform):
+    """
+    A transformer that scales the inputs and outputs by pre-defined factors.
+    """
+
+    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
+        """
+        Initialize a scale transformer.
+
+        :param scale_inputs: the scale multiplied to the inputs
+        :param scale_outputs: the scale multiplied to the outputs
+        """
+        self.scale_inputs = scale_inputs
+        self.scale_outputs = scale_outputs
+
+    def __call__(
+        self, inputs: Tensor, outputs_expected: Tensor
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Scale the inputs and the outputs by the factors defined in the constructor.
+        """
+        return inputs * self.scale_inputs, outputs_expected * self.scale_outputs
-- 
GitLab