from typing import Iterable, List, Tuple

from torch import Tensor


class Transform:
    """
    Interface of a transform.

    A transform is constructed once and then applied to an input tensor and
    the expected output tensor, as returned by a dataset. It can be used for
    normalization and data augmentation. This base implementation is the
    identity: it returns both arguments unchanged.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply the transform to a pair of inputs and expected outputs.

        :param inputs: the input tensor from the dataset
        :param outputs_expected: the expected output tensor from the dataset
        :return: the pair ``(inputs, outputs_expected)``, unchanged
        """
        return inputs, outputs_expected


class SequenceTransform(Transform):
    """
    A transform that applies a sequence of transforms one after another.
    """

    def __init__(self, transforms: Iterable[Transform]):
        """
        Initialize a sequence transform.

        :param transforms: the transforms to apply, in order. The iterable is
            copied into a new list. This means a one-shot iterable (e.g. a
            generator) is not exhausted after the first call. It also means
            later mutation of the caller's container does not alter this
            pipeline.
        """
        self.transforms: List[Transform] = list(transforms)

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply every transform in order.

        The output pair of each transform is fed into the next. With no
        transforms, the pair is returned unchanged.

        :param inputs: the input tensor from the dataset
        :param outputs_expected: the expected output tensor from the dataset
        :return: the pair after all transforms have been applied
        """
        for transform in self.transforms:
            inputs, outputs_expected = transform(inputs, outputs_expected)
        return inputs, outputs_expected

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self.transforms!r})"


class ScaleTransform(Transform):
    """
    A transform that scales the inputs and outputs by pre-defined factors.
    """

    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
        """
        Initialize a scale transform.

        :param scale_inputs: the factor the inputs are multiplied by
        :param scale_outputs: the factor the expected outputs are multiplied by
        """
        self.scale_inputs = scale_inputs
        self.scale_outputs = scale_outputs

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Scale the inputs and the outputs by the factors given at construction.

        :param inputs: the input tensor from the dataset
        :param outputs_expected: the expected output tensor from the dataset
        :return: ``(inputs * scale_inputs, outputs_expected * scale_outputs)``
        """
        return inputs * self.scale_inputs, outputs_expected * self.scale_outputs

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(scale_inputs={self.scale_inputs!r}, "
            f"scale_outputs={self.scale_outputs!r})"
        )