diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py index 7b2685e0cebd568b611c4ac4bf22f69304906b33..a99139827280befef7541dcc4374b5c733010518 100644 --- a/mu_map/dataset/transform.py +++ b/mu_map/dataset/transform.py @@ -34,3 +34,27 @@ class SequenceTransform(Transform): for transforms in self.transforms: inputs, outputs_expected = transforms(inputs, outputs_expected) return inputs, outputs_expected + + +class ScaleTransform(Transform): + """ + A transformer that scales the inputs and outputs by pre-defined factors. + """ + + def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0): + """ + Initialize a scale transformer. + + :param scale_inputs: the scale multiplied to the inputs + :param scale_outputs: the scale multiplied to the outputs + """ + self.scale_inputs = scale_inputs + self.scale_outputs = scale_outputs + + def __call__( + self, inputs: Tensor, outputs_expected: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Scale the inputs and the outputs by the factors defined in the constructor. + """ + return inputs * self.scale_inputs, outputs_expected * self.scale_outputs