diff --git a/mu_map/dataset/transform.py b/mu_map/dataset/transform.py index c27aaa57dfa7734a3bbf4b59a9d17b2a5b44c6c9..7b2685e0cebd568b611c4ac4bf22f69304906b33 100644 --- a/mu_map/dataset/transform.py +++ b/mu_map/dataset/transform.py @@ -5,13 +5,26 @@ from torch import Tensor class Transform: + """ + Interface of a transformer. A transformer can be initialized and then applied to + an input tensor and expected output tensor as returned by a dataset. It can be + used for normalization and data augmentation. + """ + def __call__( self, inputs: Tensor, outputs_expected: Tensor ) -> Tuple[Tensor, Tensor]: + """ + Apply the transformer to a pair of inputs and expected outputs in a dataset. + """ return inputs, outputs_expected class SequenceTransform(Transform): + """ + A transformer that applies a sequence of transformers sequentially. + """ + def __init__(self, transforms: List[Transform]): self.transforms = transforms