From 291180a01880a55520186d470ca5a8a2a94a2cf5 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Wed, 21 Dec 2022 14:26:51 +0100
Subject: [PATCH] remove outdated scale transform and refactor to support
 variable number of tensors

---
 mu_map/dataset/transform.py | 39 +++++++++----------------------------
 1 file changed, 9 insertions(+), 30 deletions(-)

diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py
index 9e34995..e2fad43 100644
--- a/mu_map/dataset/transform.py
+++ b/mu_map/dataset/transform.py
@@ -12,11 +12,11 @@ class Transform:
     used for normalization and data augmentation.
     """
 
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
         """
         Apply the transformer to a pair of inputs and expected outputs in a dataset.
         """
-        return inputs, targets
+        return tensors
 
 
 class SequenceTransform(Transform):
@@ -27,33 +27,12 @@ class SequenceTransform(Transform):
     def __init__(self, transforms: List[Transform]):
         self.transforms = transforms
 
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
         for transforms in self.transforms:
-            inputs, targets = transforms(inputs, targets)
-        return inputs, targets
+            tensors = transforms(*tensors)
+        return tensors
 
 
-class ScaleTransform(Transform):
-    """
-    A transformer that scales the inputs and outputs by pre-defined factors.
-    """
-
-    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
-        """
-        Initialize a scale transformer.
-
-        :param scale_inputs: the scale multiplied to the inputs
-        :param scale_outputs: the scale multiplied to the outputs
-        """
-        self.scale_inputs = scale_inputs
-        self.scale_outputs = scale_outputs
-
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
-        """
-        Scale the inputs and the outputs by the factors defined in the constructor.
-        """
-        return inputs * self.scale_inputs, targets * self.scale_outputs
-
 
 class PaddingTransform(Transform):
     """
@@ -68,12 +47,12 @@ class PaddingTransform(Transform):
         self.dim = dim
         self.size = size
 
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
         """
         Pad inputs and targets so that dimension self.dim has at
         least a size of self.size.
         """
-        return self.pad(inputs), self.pad(targets)
+        return tuple(map(lambda tensor: self.pad(tensor), tensors))
 
     def pad(self, inputs: Tensor):
         """
@@ -104,12 +83,12 @@ class CroppingTransform(Transform):
         self.dim = dim
         self.size = size
 
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
         """
         Crop inputs and targets so that dimension self.dim has at
         most a size of self.size.
         """
-        return self.crop(inputs), self.crop(targets)
+        return tuple(map(lambda tensor: self.crop(tensor), tensors))
 
     def crop(self, inputs: Tensor):
         """
-- 
GitLab