diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..c27aaa57dfa7734a3bbf4b59a9d17b2a5b44c6c9 --- /dev/null +++ b/mu_map/dataset/transform.py @@ -0,0 +1,23 @@ +from typing import List, Tuple + + +from torch import Tensor + + +class Transform: + def __call__( + self, inputs: Tensor, outputs_expected: Tensor + ) -> Tuple[Tensor, Tensor]: + return inputs, outputs_expected + + +class SequenceTransform(Transform): + def __init__(self, transforms: List[Transform]): + self.transforms = transforms + + def __call__( + self, inputs: Tensor, outputs_expected: Tensor + ) -> Tuple[Tensor, Tensor]: + for transforms in self.transforms: + inputs, outputs_expected = transforms(inputs, outputs_expected) + return inputs, outputs_expected