from typing import List, Tuple


from torch import Tensor


class Transform:
    """Base callable for paired (inputs, expected outputs) transforms.

    This base implementation is a no-op: calling an instance hands back the
    two arguments exactly as they were received. Subclasses override
    ``__call__`` to do actual work while keeping the same signature.
    """

    def __call__(
        self,
        inputs: Tensor,
        outputs_expected: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Return ``(inputs, outputs_expected)`` without modifying either.

        Args:
            inputs: First element of the pair; passed through as-is.
            outputs_expected: Second element of the pair; passed through as-is.

        Returns:
            A 2-tuple holding the very same two objects that were passed in.
        """
        unchanged_pair = (inputs, outputs_expected)
        return unchanged_pair


class SequenceTransform(Transform):
    """Chain several transforms and apply them one after another.

    Each transform is called with the pair produced by the previous one; the
    first transform receives the arguments given to this object.
    """

    def __init__(self, transforms: List[Transform]):
        """Remember the transforms to run.

        Args:
            transforms: Transforms to apply, in application order. The list
                object itself is stored (no copy is made).
        """
        self.transforms = transforms

    def __call__(
        self,
        inputs: Tensor,
        outputs_expected: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Run every stored transform in order over the pair.

        Args:
            inputs: Initial first element of the pair.
            outputs_expected: Initial second element of the pair.

        Returns:
            The pair as returned by the last transform, or the original
            arguments when ``self.transforms`` is empty.
        """
        current_inputs = inputs
        current_expected = outputs_expected
        for step in self.transforms:
            # Each step must yield exactly two values, which feed the next step.
            current_inputs, current_expected = step(
                current_inputs, current_expected
            )
        return current_inputs, current_expected