From 291180a01880a55520186d470ca5a8a2a94a2cf5 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Wed, 21 Dec 2022 14:26:51 +0100 Subject: [PATCH] remove outdated scale transform and refactor to support variable number of tensors --- mu_map/dataset/transform.py | 39 +++++++++---------------------------- 1 file changed, 9 insertions(+), 30 deletions(-) diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py index 9e34995..e2fad43 100644 --- a/mu_map/dataset/transform.py +++ b/mu_map/dataset/transform.py @@ -12,11 +12,11 @@ class Transform: used for normalization and data augmentation. """ - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]: """ Apply the transformer to a pair of inputs and expected outputs in a dataset. """ - return inputs, targets + return tensors class SequenceTransform(Transform): @@ -27,33 +27,12 @@ class SequenceTransform(Transform): def __init__(self, transforms: List[Transform]): self.transforms = transforms - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]: for transforms in self.transforms: - inputs, targets = transforms(inputs, targets) - return inputs, targets + tensors = transforms(*tensors) + return tensors -class ScaleTransform(Transform): - """ - A transformer that scales the inputs and outputs by pre-defined factors. - """ - - def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0): - """ - Initialize a scale transformer. - - :param scale_inputs: the scale multiplied to the inputs - :param scale_outputs: the scale multiplied to the outputs - """ - self.scale_inputs = scale_inputs - self.scale_outputs = scale_outputs - - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: - """ - Scale the inputs and the outputs by the factors defined in the constructor. - """ - return inputs * self.scale_inputs, targets * self.scale_outputs - class PaddingTransform(Transform): """ @@ -68,12 +47,12 @@ class PaddingTransform(Transform): self.dim = dim self.size = size - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]: """ Pad inputs and targets so that dimension self.dim has at least a size of self.size. """ - return self.pad(inputs), self.pad(targets) + return tuple(map(lambda tensor: self.pad(tensor), tensors)) def pad(self, inputs: Tensor): """ @@ -104,12 +83,12 @@ class CroppingTransform(Transform): self.dim = dim self.size = size - def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: + def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]: """ Crop inputs and targets so that dimension self.dim has at most a size of self.size. """ - return self.crop(inputs), self.crop(targets) + return tuple(map(lambda tensor: self.crop(tensor), tensors)) def crop(self, inputs: Tensor): """ -- GitLab