diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py index 9e34995bdcd90ae384529b4ae576817a43b947f8..e2fad43adf45ea4788c97e7e9ae54d3f7337ef68 100644 --- a/mu_map/dataset/transform.py +++ b/mu_map/dataset/transform.py @@ -12,11 +12,11 @@ class Transform: used for normalization and data augmentation. """ - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]: """ Apply the transformer to a pair of inputs and expected outputs in a dataset. """ - return inputs, targets + return tensors class SequenceTransform(Transform): @@ -27,33 +27,12 @@ class SequenceTransform(Transform): def __init__(self, transforms: List[Transform]): self.transforms = transforms - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]: for transforms in self.transforms: - inputs, targets = transforms(inputs, targets) - return inputs, targets + tensors = transforms(*tensors) + return tensors -class ScaleTransform(Transform): - """ - A transformer that scales the inputs and outputs by pre-defined factors. - """ - - def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0): - """ - Initialize a scale transformer. - - :param scale_inputs: the scale multiplied to the inputs - :param scale_outputs: the scale multiplied to the outputs - """ - self.scale_inputs = scale_inputs - self.scale_outputs = scale_outputs - - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: - """ - Scale the inputs and the outputs by the factors defined in the constructor. - """ - return inputs * self.scale_inputs, targets * self.scale_outputs - class PaddingTransform(Transform): """ @@ -68,12 +47,12 @@ class PaddingTransform(Transform): self.dim = dim self.size = size - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]: """ Pad inputs and targets so that dimension self.dim has at least a size of self.size. """ - return self.pad(inputs), self.pad(targets) + return tuple(map(lambda tensor: self.pad(tensor), tensors)) def pad(self, inputs: Tensor): """ @@ -104,12 +83,12 @@ class CroppingTransform(Transform): self.dim = dim self.size = size - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]: """ Crop inputs and targets so that dimension self.dim has at most a size of self.size. """ - return self.crop(inputs), self.crop(targets) + return tuple(map(lambda tensor: self.crop(tensor), tensors)) def crop(self, inputs: Tensor): """