Newer
Older
from typing import Iterable, List, Tuple

from torch import Tensor
class Transform:
    """Base class for transforms acting on an (inputs, expected outputs) pair.

    The base implementation is the identity: both arguments are handed back
    untouched. Subclasses override ``__call__`` to modify either value.
    """

    def __call__(
        self,
        inputs: Tensor,
        outputs_expected: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Return the pair unchanged.

        Args:
            inputs: Input value passed through as-is.
            outputs_expected: Expected-output value passed through as-is.

        Returns:
            The tuple ``(inputs, outputs_expected)``, containing the very same
            objects that were passed in.
        """
        passthrough = (inputs, outputs_expected)
        return passthrough
class SequenceTransform(Transform):
    """Chain several transforms, applying them one after another.

    Each transform receives the pair produced by the previous one, and the
    pair returned by the last transform is the result. With no transforms,
    the pair is returned unchanged.
    """

    def __init__(self, transforms: Iterable[Transform]):
        """Store the transforms to apply, in order.

        Args:
            transforms: Transforms to apply, in the order given. Any iterable
                is accepted. It is copied into a list so that one-shot
                iterables (e.g. generators) are not exhausted after the first
                call, which would turn every later call into a no-op.
        """
        self.transforms: List[Transform] = list(transforms)

    def __call__(
        self, inputs: Tensor, outputs_expected: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Run every stored transform in sequence on the pair.

        Args:
            inputs: Input value fed to the first transform.
            outputs_expected: Expected-output value fed to the first transform.

        Returns:
            The ``(inputs, outputs_expected)`` pair produced by the last
            transform, or the original pair if there are no transforms.
        """
        for transform in self.transforms:
            inputs, outputs_expected = transform(inputs, outputs_expected)
        return inputs, outputs_expected