From e95e5b6f7c7182de13d88f75f74199445ac808b0 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 4 Oct 2022 09:44:03 +0200 Subject: [PATCH] implement a scale tranform --- mu_map/dataset/transform.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py index 7b2685e..a991398 100644 --- a/mu_map/dataset/transform.py +++ b/mu_map/dataset/transform.py @@ -34,3 +34,27 @@ class SequenceTransform(Transform): for transforms in self.transforms: inputs, outputs_expected = transforms(inputs, outputs_expected) return inputs, outputs_expected + + +class ScaleTransform(Transform): + """ + A transformer that scales the inputs and outputs by pre-defined factors. + """ + + def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0): + """ + Initialize a scale transformer. + + :param scale_inputs: the scale multiplied to the inputs + :param scale_outputs: the scale multiplied to the outputs + """ + self.scale_inputs = scale_inputs + self.scale_outputs = scale_outputs + + def __call__( + self, inputs: Tensor, outputs_expected: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Scale the inputs and the outputs by the factors defined in the constructor. + """ + return inputs * self.scale_inputs, outputs_expected * self.scale_outputs -- GitLab