Skip to content
Snippets Groups Projects
Commit 291180a0 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

remove outdated scale transform and refactor to support variable number of tensors

parent fe448f7a
No related branches found
No related tags found
No related merge requests found
...@@ -12,11 +12,11 @@ class Transform: ...@@ -12,11 +12,11 @@ class Transform:
used for normalization and data augmentation. used for normalization and data augmentation.
""" """
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
""" """
Apply the transformer to a pair of inputs and expected outputs in a dataset. Apply the transformer to a pair of inputs and expected outputs in a dataset.
""" """
return inputs, targets return tensors
class SequenceTransform(Transform): class SequenceTransform(Transform):
...@@ -27,33 +27,12 @@ class SequenceTransform(Transform): ...@@ -27,33 +27,12 @@ class SequenceTransform(Transform):
def __init__(self, transforms: List[Transform]): def __init__(self, transforms: List[Transform]):
self.transforms = transforms self.transforms = transforms
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
for transforms in self.transforms: for transforms in self.transforms:
inputs, targets = transforms(inputs, targets) tensors = transforms(*tensors)
return inputs, targets return tensors
class ScaleTransform(Transform):
"""
A transformer that scales the inputs and outputs by pre-defined factors.
"""
def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
"""
Initialize a scale transformer.
:param scale_inputs: the scale multiplied to the inputs
:param scale_outputs: the scale multiplied to the outputs
"""
self.scale_inputs = scale_inputs
self.scale_outputs = scale_outputs
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
"""
Scale the inputs and the outputs by the factors defined in the constructor.
"""
return inputs * self.scale_inputs, targets * self.scale_outputs
class PaddingTransform(Transform): class PaddingTransform(Transform):
""" """
...@@ -68,12 +47,12 @@ class PaddingTransform(Transform): ...@@ -68,12 +47,12 @@ class PaddingTransform(Transform):
self.dim = dim self.dim = dim
self.size = size self.size = size
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
""" """
Pad inputs and targets so that dimension self.dim has at Pad inputs and targets so that dimension self.dim has at
least a size of self.size. least a size of self.size.
""" """
return self.pad(inputs), self.pad(targets) return tuple(map(lambda tensor: self.pad(tensor), tensors))
def pad(self, inputs: Tensor): def pad(self, inputs: Tensor):
""" """
...@@ -104,12 +83,12 @@ class CroppingTransform(Transform): ...@@ -104,12 +83,12 @@ class CroppingTransform(Transform):
self.dim = dim self.dim = dim
self.size = size self.size = size
def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]: def __call__(self, *tensors: Tensor) -> Tuple[Tensor, ...]:
""" """
Crop inputs and targets so that dimension self.dim has at Crop inputs and targets so that dimension self.dim has at
most a size of self.size. most a size of self.size.
""" """
return self.crop(inputs), self.crop(targets) return tuple(map(lambda tensor: self.crop(tensor), tensors))
def crop(self, inputs: Tensor): def crop(self, inputs: Tensor):
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment