from typing import Iterable, List, Tuple

from torch import Tensor


class Transform:
    """
    Base interface for transformers. An instance is created once and then called
    on each pair of input tensor and expected output tensor returned by a dataset,
    producing a (possibly modified) pair. Subclasses override ``__call__`` to
    implement normalization or data augmentation.

    The base implementation is the identity: it returns both arguments unchanged.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Return the pair of inputs and expected outputs unchanged.

        :param inputs: the input tensor of a dataset sample
        :param outputs_expected: the expected output tensor of the same sample
        :return: the tuple ``(inputs, outputs_expected)``
        """
        unchanged = (inputs, outputs_expected)
        return unchanged


class SequenceTransform(Transform):
    """
    A transformer that applies a sequence of transformers in order, feeding the
    pair returned by each transformer into the next one.
    """

    def __init__(self, transforms: Iterable[Transform]):
        """
        Initialize a sequence transformer.

        :param transforms: the transformers to apply, in application order. A list
            is stored as-is, so later changes to it are seen by this transformer.
            Any other iterable (tuple, generator, ...) is copied into a list so it
            can be replayed on every call.
        """
        # A one-shot iterator would be consumed by the first __call__, after which
        # every call would silently behave like the identity transform.
        self.transforms = (
            transforms if isinstance(transforms, list) else list(transforms)
        )

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Apply every transformer in turn to a pair of inputs and expected outputs.

        An empty sequence returns the pair unchanged.

        :param inputs: the input tensor of a dataset sample
        :param outputs_expected: the expected output tensor of the same sample
        :return: the pair produced by the last transformer in the sequence
        """
        for transform in self.transforms:
            inputs, outputs_expected = transform(inputs, outputs_expected)
        return inputs, outputs_expected


class ScaleTransform(Transform):
    """
    A transformer that multiplies the inputs and the expected outputs by fixed
    factors chosen at construction time.
    """

    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
        """
        Create a scale transformer.

        :param scale_inputs: factor the inputs are multiplied by
        :param scale_outputs: factor the expected outputs are multiplied by
        """
        self.scale_inputs, self.scale_outputs = scale_inputs, scale_outputs

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """
        Multiply the inputs and the expected outputs by their configured factors.

        :param inputs: the input tensor of a dataset sample
        :param outputs_expected: the expected output tensor of the same sample
        :return: ``(inputs * scale_inputs, outputs_expected * scale_outputs)``
        """
        scaled_inputs = inputs * self.scale_inputs
        scaled_outputs = outputs_expected * self.scale_outputs
        return scaled_inputs, scaled_outputs