Skip to content
Snippets Groups Projects
transform.py 5.32 KiB
Newer Older
  • Learn to ignore specific revisions
    import math
    from typing import List, Tuple

    import torch
    from torch import Tensor
    
    
    class Transform:
        """
        Base interface shared by all transformers.

        A transformer is constructed once and afterwards called with the pair of
        input tensor and expected output tensor returned by a dataset. Typical uses
        are normalization and data augmentation. This base implementation changes
        nothing and hands both tensors back as they are.
        """

        def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
            """
            Apply the transformer to one dataset pair.

            :param inputs: the input tensor of the pair
            :param targets: the expected output tensor of the pair
            :return: the pair (inputs, targets), unmodified in this base class
            """
            unchanged = (inputs, targets)
            return unchanged
    
    
    
    class SequenceTransform(Transform):
        """
        A transformer that chains several transformers, applying them one after
        another in the order in which they were given.
        """

        def __init__(self, transforms: List[Transform]):
            """
            Remember the transformers that make up the chain.

            :param transforms: the transformers to apply, in application order
            """
            self.transforms = transforms

        def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
            """
            Feed the pair through every transformer of the chain.

            The result of each step becomes the argument of the next one; with an
            empty chain the pair is returned as it was passed in.

            :param inputs: the input tensor of the pair
            :param targets: the expected output tensor of the pair
            :return: the pair as produced by the last transformer of the chain
            """
            for step in self.transforms:
                inputs, targets = step(inputs, targets)
            return inputs, targets
    
    
    
    class ScaleTransform(Transform):
        """
        A transformer that multiplies inputs and outputs by fixed, pre-defined factors.
        """

        def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
            """
            Store the two scaling factors.

            :param scale_inputs: the factor the inputs are multiplied with
            :param scale_outputs: the factor the outputs are multiplied with
            """
            self.scale_inputs, self.scale_outputs = scale_inputs, scale_outputs

        def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
            """
            Multiply the inputs and the outputs by the factors given to the constructor.

            :param inputs: the input tensor of the pair
            :param targets: the expected output tensor of the pair
            :return: the pair (inputs * scale_inputs, targets * scale_outputs)
            """
            scaled_inputs = inputs * self.scale_inputs
            scaled_targets = targets * self.scale_outputs
            return scaled_inputs, scaled_targets
    
    
    class PaddingTransform(Transform):
        """
        A transformer that pads a specified dimension of tensors
        so that they have at least a given size.

        The padding is filled with zeros and split around the existing data. If the
        number of missing elements is odd, the extra element is placed in front.

        :param dim: the dimension to be padded, counted from behind starting at 1
                    (dim=1 is the last dimension, see torch.nn.functional.pad)
        :param size: the size to which the dimension should be padded if it is smaller
        """

        def __init__(self, dim: int, size: int):
            self.dim = dim
            self.size = size

        def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
            """
            Pad inputs and targets so that dimension self.dim has at
            least a size of self.size.

            :param inputs: the input tensor of the pair
            :param targets: the expected output tensor of the pair
            :return: the pair of padded tensors
            :raises IndexError: if self.dim does not address a dimension of a tensor
            """
            return self.pad(inputs), self.pad(targets)

        def pad(self, inputs: Tensor) -> Tensor:
            """
            Pad a single input tensor so that dimension self.dim has at
            least a size of self.size.

            :param inputs: the tensor to pad
            :return: the tensor itself if the dimension is already large enough,
                     otherwise a zero-padded tensor
            :raises IndexError: if self.dim does not address a dimension of the tensor
            """
            n_dims = len(inputs.shape)
            # Without this check `n_dims - self.dim` can become negative and would
            # silently address a different dimension via negative indexing.
            if not 1 <= self.dim <= n_dims:
                raise IndexError(
                    f"dim={self.dim} is out of range for a tensor with {n_dims} dimensions"
                )

            current = inputs.shape[n_dims - self.dim]
            if current >= self.size:
                return inputs

            missing = self.size - current
            after = missing // 2
            before = missing - after  # equals ceil(missing / 2) without float arithmetic

            # torch.nn.functional.pad takes (before, after) pairs starting at the last
            # dimension, so the pair for self.dim follows self.dim - 1 untouched pairs.
            padding = [0, 0] * (self.dim - 1) + [before, after]
            return torch.nn.functional.pad(inputs, padding, mode="constant", value=0)
    
    
    class CroppingTransform(Transform):
        """
        A transformer that crops a specified dimension of tensors
        so that they have at most the given size.

        The crop keeps a window of the requested size around the middle index of
        the dimension. If the size is odd, the extra element lies before the middle.

        :param dim: the dimension to be cropped, counted from behind starting at 1
                    (dim=1 is the last dimension, see PaddingTransform)
        :param size: the size to which the dimension should be cropped if it is larger
        """

        def __init__(self, dim: int, size: int):
            self.dim = dim
            self.size = size

        def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
            """
            Crop inputs and targets so that dimension self.dim has at
            most a size of self.size.

            :param inputs: the input tensor of the pair
            :param targets: the expected output tensor of the pair
            :return: the pair of cropped tensors
            :raises IndexError: if self.dim does not address a dimension of a tensor
            """
            return self.crop(inputs), self.crop(targets)

        def crop(self, inputs: Tensor) -> Tensor:
            """
            Crop a single input tensor so that dimension self.dim has at
            most a size of self.size.

            :param inputs: the tensor to crop
            :return: the tensor itself if the dimension is already small enough,
                     otherwise the centered slice of the tensor
            :raises IndexError: if self.dim does not address a dimension of the tensor
            """
            n_dims = len(inputs.shape)
            # Without this check `n_dims - self.dim` can become negative and would
            # silently crop a different dimension via negative indexing.
            if not 1 <= self.dim <= n_dims:
                raise IndexError(
                    f"dim={self.dim} is out of range for a tensor with {n_dims} dimensions"
                )

            shape_idx = n_dims - self.dim
            current = inputs.shape[shape_idx]
            if current <= self.size:
                return inputs

            # window of self.size elements around the middle index, computed with
            # integer arithmetic: (size + 1) // 2 == ceil(size / 2)
            center = current // 2
            start = center - (self.size + 1) // 2
            stop = center + self.size // 2

            # Select everything in the dimensions before shape_idx and the window in
            # shape_idx itself. The index must be a tuple: indexing with a list of
            # slices is deprecated non-tuple sequence indexing.
            index = (slice(None),) * shape_idx + (slice(start, stop),)
            return inputs[index]
    
    
    class PadCropTranform(SequenceTransform):
        """
        A combination of padding and cropping that makes sure that a
        specified dimension always has a given size.

        Padding is applied first and cropping second, so tensors that are too small
        are enlarged and tensors that are too large are cut down to the same size.

        The class name is misspelled ("Tranform"); it is kept for backward
        compatibility and the correctly spelled alias PadCropTransform is provided.

        :param dim: the dimension to be padded and cropped (from behind, see PaddingTransform)
        :param size: the size to which the dimension should be padded and cropped
        """

        def __init__(self, dim: int, size: int):
            super().__init__(
                transforms=[PaddingTransform(dim, size), CroppingTransform(dim, size)]
            )


    # Correctly spelled alias; new code should prefer this name.
    PadCropTransform = PadCropTranform
    
    
    if __name__ == "__main__":
        # Self-check: dimension 3 from behind (index 2 of a 5-d tensor) must end
        # up with size 32, both when it starts out smaller and when it is larger.
        transform = PadCropTranform(dim=3, size=32)

        for depth in (29, 45):
            shape = (8, 1, depth, 128, 128)
            inputs, targets = transform(torch.rand(shape), torch.rand(shape))
            assert inputs.shape[2] == 32
            assert targets.shape[2] == 32
            print(inputs.shape)