Skip to content
Snippets Groups Projects
transform.py 615 B
Newer Older
  • Learn to ignore specific revisions
  • from typing import List, Tuple
    
    
    from torch import Tensor
    
    
    class Transform:
        """Identity transform over an ``(inputs, outputs_expected)`` pair.

        Calling an instance hands both arguments back untouched. Subclasses
        (for example ``SequenceTransform``) override ``__call__`` to change the
        pair; they keep the same two-argument, two-result signature.
        """

        def __call__(
            self, inputs: Tensor, outputs_expected: Tensor
        ) -> Tuple[Tensor, Tensor]:
            """Return ``(inputs, outputs_expected)`` exactly as received."""
            pair = (inputs, outputs_expected)
            return pair
    
    
    class SequenceTransform(Transform):
        """Compose several transforms, applying them one after another.

        Each transform receives the ``(inputs, outputs_expected)`` pair returned
        by the previous one; the pair returned by the last transform is the
        result. With an empty sequence the pair is returned unchanged.
        """

        def __init__(self, transforms: List[Transform]):
            """Store the transforms to apply, in order.

            Args:
                transforms: Transforms to run, first to last. Any iterable is
                    accepted. It is copied into a new list so that a one-shot
                    iterator (e.g. a generator) is not exhausted by the first
                    call. Later changes to the caller's list also do not
                    silently alter this pipeline.
            """
            self.transforms = list(transforms)

        def __call__(
            self, inputs: Tensor, outputs_expected: Tensor
        ) -> Tuple[Tensor, Tensor]:
            """Thread the pair through every transform and return the final pair."""
            for transform in self.transforms:
                inputs, outputs_expected = transform(inputs, outputs_expected)
            return inputs, outputs_expected