diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py
index 9e34995bdcd90ae384529b4ae576817a43b947f8..e2fad43adf45ea4788c97e7e9ae54d3f7337ef68 100644
--- a/mu_map/dataset/transform.py
+++ b/mu_map/dataset/transform.py
@@ -12,11 +12,11 @@ class Transform:
     used for normalization and data augmentation.
     """
 
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
         """
         Apply the transformer to a pair of inputs and expected outputs in a dataset.
         """
-        return inputs, targets
+        return tensors
 
 
 class SequenceTransform(Transform):
@@ -27,33 +27,12 @@ class SequenceTransform(Transform):
     def __init__(self, transforms: List[Transform]):
         self.transforms = transforms
 
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
         for transforms in self.transforms:
-            inputs, targets = transforms(inputs, targets)
-        return inputs, targets
+            tensors = transforms(*tensors)
+        return tensors
 
 
-class ScaleTransform(Transform):
-    """
-    A transformer that scales the inputs and outputs by pre-defined factors.
-    """
-
-    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
-        """
-        Initialize a scale transformer.
-
-        :param scale_inputs: the scale multiplied to the inputs
-        :param scale_outputs: the scale multiplied to the outputs
-        """
-        self.scale_inputs = scale_inputs
-        self.scale_outputs = scale_outputs
-
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
-        """
-        Scale the inputs and the outputs by the factors defined in the constructor.
-        """
-        return inputs * self.scale_inputs, targets * self.scale_outputs
-
 
 class PaddingTransform(Transform):
     """
@@ -68,12 +47,12 @@ class PaddingTransform(Transform):
         self.dim = dim
         self.size = size
 
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
         """
         Pad inputs and targets so that dimension self.dim has at
         least a size of self.size.
         """
-        return self.pad(inputs), self.pad(targets)
+        return tuple(map(lambda tensor: self.pad(tensor), tensors))
 
     def pad(self, inputs: Tensor):
         """
@@ -104,12 +83,12 @@ class CroppingTransform(Transform):
         self.dim = dim
         self.size = size
 
-    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+    def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
         """
         Crop inputs and targets so that dimension self.dim has at
         most a size of self.size.
         """
-        return self.crop(inputs), self.crop(targets)
+        return tuple(map(lambda tensor: self.crop(tensor), tensors))
 
     def crop(self, inputs: Tensor):
         """