import math
from typing import List, Tuple

import torch
from torch import Tensor


class Transform:
    """
    Interface of a transformer. A transformer can be initialized and then applied to
    an input tensor and expected output tensor as returned by a dataset. It can be
    used for normalization and data augmentation.
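
    Subclasses override ``__call__``; the base class itself acts as the identity
    transform. Minimal sketch of a custom transform (``AddNoise`` is a
    hypothetical example, not part of this module)::

        class AddNoise(Transform):
            def __call__(self, inputs, targets):
                # add small Gaussian noise to the inputs only
                return inputs + 0.01 * torch.randn_like(inputs), targets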
    """

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Apply the transformer to a pair of inputs and expected outputs in a dataset.
        """
        return inputs, targets


class SequenceTransform(Transform):
    """
    A transformer that applies a sequence of transformers sequentially.
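
    Minimal usage sketch (the 0.5 scale factor is an arbitrary choice)::

        >>> seq = SequenceTransform([ScaleTransform(scale_inputs=0.5), Transform()])
        >>> x, y = seq(torch.ones(2), torch.ones(2))
        >>> x
        tensor([0.5000, 0.5000])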
    """

    def __init__(self, transforms: List[Transform]):
        self.transforms = transforms

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        for transform in self.transforms:
            inputs, targets = transform(inputs, targets)
        return inputs, targets


class ScaleTransform(Transform):
    """
    A transformer that scales the inputs and outputs by pre-defined factors.
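
    Minimal usage sketch (the factor 2.0 is an arbitrary choice)::

        >>> t = ScaleTransform(scale_inputs=2.0)
        >>> x, y = t(torch.ones(3), torch.ones(3))
        >>> x
        tensor([2., 2., 2.])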
    """

    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
        """
        Initialize a scale transformer.

        :param scale_inputs: the factor by which the inputs are multiplied
        :param scale_outputs: the factor by which the expected outputs are multiplied
        """
        self.scale_inputs = scale_inputs
        self.scale_outputs = scale_outputs

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Scale the inputs and the outputs by the factors defined in the constructor.
        """
        return inputs * self.scale_inputs, targets * self.scale_outputs


class PaddingTransform(Transform):
    """
    A transformer that pads a specified dimension of tensors
    so that they have at least a given size.

    :param dim: the dimension to pad, counted from the end
        (see torch.nn.functional.pad)
    :param size: the size to which the dimension should be padded if it is smaller
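
    Minimal usage sketch (pads the last dimension from 3 up to 5)::

        >>> pad = PaddingTransform(dim=1, size=5)
        >>> x, y = pad(torch.ones(2, 3), torch.ones(2, 3))
        >>> x.shape
        torch.Size([2, 5])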
    """

    def __init__(self, dim: int, size: int):
        self.dim = dim
        self.size = size

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Pad inputs and targets so that dimension self.dim has at
        least a size of self.size.
        """
        return self.pad(inputs), self.pad(targets)

    def pad(self, inputs: Tensor) -> Tensor:
        """
        Pad a single input tensor so that dimension self.dim has at
        least a size of self.size.
        """
        # index of the target dimension, with self.dim counted from the end
        shape_idx = len(inputs.shape) - self.dim
        if inputs.shape[shape_idx] >= self.size:
            return inputs

        diff_half = (self.size - inputs.shape[shape_idx]) / 2
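        # torch.nn.functional.pad expects (before, after) pairs starting from the
        # last dimension, so entries -2 and -1 of this list pad dimension -self.dim;
        # the larger half is placed in front when the size difference is odd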
        padding = [0] * 2 * self.dim
        padding[-2] = math.ceil(diff_half)
        padding[-1] = math.floor(diff_half)
        return torch.nn.functional.pad(inputs, padding, mode="constant", value=0)


class CroppingTransform(Transform):
    """
    A transformer that crops a specified dimension of tensors
    so that they have at most the given size.

    :param dim: the dimension to crop, counted from the end (see PaddingTransform)
    :param size: the size to which the dimension should be cropped if it is larger
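
    Minimal usage sketch (center-crops the last dimension from 7 down to 5)::

        >>> crop = CroppingTransform(dim=1, size=5)
        >>> x, y = crop(torch.ones(2, 7), torch.ones(2, 7))
        >>> x.shape
        torch.Size([2, 5])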
    """

    def __init__(self, dim: int, size: int):
        self.dim = dim
        self.size = size

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Crop inputs and targets so that dimension self.dim has at
        most a size of self.size.
        """
        return self.crop(inputs), self.crop(targets)

    def crop(self, inputs: Tensor) -> Tensor:
        """
        Crop a single input tensor so that dimension self.dim has at
        most a size of self.size.
        """
        shape_idx = len(inputs.shape) - self.dim
        if inputs.shape[shape_idx] <= self.size:
            return inputs

        # create slices selecting everything up to shape_idx
        slices = [slice(None)] * shape_idx
        # add slice which performs the crop on the specified dimension
        center = inputs.shape[shape_idx] // 2
        size_half = self.size / 2
        slices.append(
            slice(center - math.ceil(size_half), center + math.floor(size_half))
        )
        # advanced indexing expects a tuple of slices rather than a list
        return inputs[tuple(slices)]


class PadCropTransform(SequenceTransform):
    """
    A combination of padding and cropping that makes sure that a
    specified dimension always has a given size.

    :param dim: the dimension to pad and crop, counted from the end
        (see PaddingTransform)
    :param size: the size to which the dimension should be padded and cropped
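
    Minimal usage sketch (forces the last dimension to exactly 4)::

        >>> fix = PadCropTransform(dim=1, size=4)
        >>> shorter, _ = fix(torch.ones(2), torch.ones(2))
        >>> longer, _ = fix(torch.ones(6), torch.ones(6))
        >>> shorter.shape, longer.shape
        (torch.Size([4]), torch.Size([4]))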
    """

    def __init__(self, dim: int, size: int):
        super().__init__(
            transforms=[PaddingTransform(dim, size), CroppingTransform(dim, size)]
        )


if __name__ == "__main__":
    transform = PadCropTransform(dim=3, size=32)

    shape = (8, 1, 29, 128, 128)
    inputs = torch.rand(shape)
    targets = torch.rand(shape)
    inputs, targets = transform(inputs, targets)
    assert inputs.shape[2] == 32
    assert targets.shape[2] == 32
    print(inputs.shape)

    shape = (8, 1, 45, 128, 128)
    inputs = torch.rand(shape)
    targets = torch.rand(shape)
    inputs, targets = transform(inputs, targets)
    assert inputs.shape[2] == 32
    assert targets.shape[2] == 32
    print(inputs.shape)
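
    # Illustrative composition (the 0.5 scale factor is an arbitrary choice,
    # not part of the original example): normalization followed by padding
    # and cropping, chained via SequenceTransform.
    transform = SequenceTransform(
        [
            ScaleTransform(scale_inputs=0.5, scale_outputs=0.5),
            PadCropTransform(dim=3, size=32),
        ]
    )
    shape = (8, 1, 40, 128, 128)
    inputs, targets = transform(torch.rand(shape), torch.rand(shape))
    assert inputs.shape[2] == 32
    assert targets.shape[2] == 32
    print(inputs.shape)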