Skip to content
Snippets Groups Projects
transform.py 1.79 KiB
Newer Older
  • Learn to ignore specific revisions
  • from typing import List, Tuple
    
    
    from torch import Tensor
    
    
    class Transform:
        """
        Base transformer: the identity mapping on an (inputs, expected outputs) pair.

        A transformer is constructed once and then called on each input tensor and
        expected output tensor produced by a dataset, e.g. for normalization or data
        augmentation. Subclasses override ``__call__``; this base implementation hands
        both objects back unchanged.
        """

        def __call__(
            self, inputs: Tensor, outputs_expected: Tensor
        ) -> Tuple[Tensor, Tensor]:
            """
            Transform one dataset sample.

            :param inputs: the input tensor of the sample
            :param outputs_expected: the expected output tensor of the sample
            :return: the pair ``(inputs, outputs_expected)``, the very same objects
            """
            pair = (inputs, outputs_expected)
            return pair
    
    
    class SequenceTransform(Transform):
        """
        A transformer that chains several transformers, applying them in order.

        The output pair of each transformer is fed as the input pair of the next, so
        ``SequenceTransform([a, b])(x, y)`` is equivalent to ``b(*a(x, y))``. With no
        transformers the pair is returned unchanged.
        """

        def __init__(self, transforms: List[Transform]):
            """
            Initialize a sequence transformer.

            :param transforms: the transformers to apply, in application order. A list
                (or other re-iterable collection) is stored as given, so later changes
                to it are seen by this transformer. A one-shot iterator such as a
                generator is materialized into a list so that every call applies all
                transformers, not only the first call.
            """
            # iter(x) is x holds for iterators (their __iter__ returns self); storing
            # one directly would leave it exhausted after the first __call__, turning
            # every later call into a silent identity.
            if iter(transforms) is transforms:
                transforms = list(transforms)
            self.transforms = transforms

        def __call__(
            self, inputs: Tensor, outputs_expected: Tensor
        ) -> Tuple[Tensor, Tensor]:
            """
            Apply every transformer in sequence to a pair of inputs and expected outputs.

            :param inputs: the input tensor of the sample
            :param outputs_expected: the expected output tensor of the sample
            :return: the pair produced by the last transformer (the unchanged pair
                when there are no transformers)
            """
            for transform in self.transforms:
                inputs, outputs_expected = transform(inputs, outputs_expected)
            return inputs, outputs_expected
    
    
    
    class ScaleTransform(Transform):
        """
        A transformer that multiplies the inputs and the expected outputs by fixed
        factors chosen at construction time.
        """

        def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
            """
            Initialize a scale transformer.

            :param scale_inputs: the factor the inputs are multiplied by
            :param scale_outputs: the factor the expected outputs are multiplied by
            """
            self.scale_inputs, self.scale_outputs = scale_inputs, scale_outputs

        def __call__(
            self, inputs: Tensor, outputs_expected: Tensor
        ) -> Tuple[Tensor, Tensor]:
            """
            Multiply the inputs and the expected outputs by the constructor's factors.

            :param inputs: the input tensor of the sample
            :param outputs_expected: the expected output tensor of the sample
            :return: ``(inputs * scale_inputs, outputs_expected * scale_outputs)``
            """
            scaled_inputs = inputs * self.scale_inputs
            scaled_outputs = outputs_expected * self.scale_outputs
            return scaled_inputs, scaled_outputs