Newer
Older
from typing import List, Tuple
from torch import Tensor
class Transform:
    """
    Base interface for dataset transforms.

    A transform is constructed once and then called with an input tensor and
    the matching expected-output tensor, as yielded by a dataset, returning a
    new (inputs, outputs_expected) pair. Typical uses are normalization and
    data augmentation. This base implementation is the identity transform.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Transform one (inputs, outputs_expected) pair from a dataset.

        The base class leaves both values untouched and returns them as-is.
        """
        unchanged = (inputs, outputs_expected)
        return unchanged
class SequenceTransform(Transform):
    """
    A transform that chains several transforms, applying them one after another.

    Each transform receives the (inputs, outputs_expected) pair produced by the
    previous one; the pair returned by the last transform is the result.
    """

    def __init__(self, transforms: List[Transform]):
        """
        Initialize a sequence transform.

        :param transforms: the transforms to apply, in order. Any iterable is
            accepted and copied into a list, so a one-shot iterator (e.g. a
            generator) is not exhausted after the first call.
        """
        # Materialize once: iterating a generator directly on every call would
        # apply the transforms only the first time and silently no-op afterwards.
        self.transforms = list(transforms)

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply every transform in order to a pair of inputs and expected outputs.

        With an empty sequence the pair is returned unchanged.
        """
        for transform in self.transforms:
            inputs, outputs_expected = transform(inputs, outputs_expected)
        return inputs, outputs_expected
class ScaleTransform(Transform):
    """
    A transform that multiplies inputs and expected outputs by fixed factors.
    """

    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
        """
        Create a scale transform.

        :param scale_inputs: factor that the inputs are multiplied by
        :param scale_outputs: factor that the expected outputs are multiplied by
        """
        self.scale_inputs, self.scale_outputs = scale_inputs, scale_outputs

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Return the inputs and expected outputs multiplied by their configured factors.
        """
        scaled_inputs = inputs * self.scale_inputs
        scaled_outputs = outputs_expected * self.scale_outputs
        return scaled_inputs, scaled_outputs