Skip to content
Snippets Groups Projects
Commit 291180a0 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

remove outdated scale transform and refactor to support variable number of tensors

parent fe448f7a
No related branches found
No related tags found
No related merge requests found
......@@ -12,11 +12,11 @@ class Transform:
used for normalization and data augmentation.
"""
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
"""
Apply the transformer to a pair of inputs and expected outputs in a dataset.
"""
return inputs, targets
return tensors
class SequenceTransform(Transform):
......@@ -27,33 +27,12 @@ class SequenceTransform(Transform):
def __init__(self, transforms: List[Transform]):
self.transforms = transforms
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
for transforms in self.transforms:
inputs, targets = transforms(inputs, targets)
return inputs, targets
tensors = transforms(*tensors)
return tensors
class ScaleTransform(Transform):
"""
A transformer that scales the inputs and outputs by pre-defined factors.
"""
def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
"""
Initialize a scale transformer.
:param scale_inputs: the scale multiplied to the inputs
:param scale_outputs: the scale multiplied to the outputs
"""
self.scale_inputs = scale_inputs
self.scale_outputs = scale_outputs
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
"""
Scale the inputs and the outputs by the factors defined in the constructor.
"""
return inputs * self.scale_inputs, targets * self.scale_outputs
class PaddingTransform(Transform):
"""
......@@ -68,12 +47,12 @@ class PaddingTransform(Transform):
self.dim = dim
self.size = size
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
"""
Pad inputs and targets so that dimension self.dim has at
least a size of self.size.
"""
return self.pad(inputs), self.pad(targets)
return tuple(map(lambda tensor: self.pad(tensor), tensors))
def pad(self, inputs: Tensor):
"""
......@@ -104,12 +83,12 @@ class CroppingTransform(Transform):
self.dim = dim
self.size = size
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
"""
Crop inputs and targets so that dimension self.dim has at
most a size of self.size.
"""
return self.crop(inputs), self.crop(targets)
return tuple(map(lambda tensor: self.crop(tensor), tensors))
def crop(self, inputs: Tensor):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment