from typing import List, Tuple

from torch import Tensor


class Transform:
    """Base transform over an ``(inputs, outputs_expected)`` pair.

    The base implementation is the identity: calling it hands both
    arguments back unchanged. Subclasses override ``__call__``.
    """

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Return ``inputs`` and ``outputs_expected`` exactly as received."""
        return inputs, outputs_expected


class SequenceTransform(Transform):
    """Apply a list of transforms one after another.

    Each transform receives the pair produced by the previous one; the
    first receives the pair passed to ``__call__``.
    """

    def __init__(self, transforms: List[Transform]):
        """Keep a reference to ``transforms`` (the list is not copied)."""
        self.transforms = transforms

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Run every stored transform in list order and return the final pair.

        With no stored transforms, the arguments come back unchanged.
        """
        current_inputs, current_expected = inputs, outputs_expected
        for step in self.transforms:
            # Each step must yield exactly two values to feed the next one.
            current_inputs, current_expected = step(
                current_inputs, current_expected
            )
        return current_inputs, current_expected