import math
from typing import List, Tuple

import torch
from torch import Tensor


class Transform:
    """
    Interface of a transformer.

    A transformer can be initialized and then applied to an input tensor and the
    expected output tensor as returned by a dataset. It can be used for
    normalization and data augmentation. This base implementation is the identity:
    it returns both tensors unchanged.
    """

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Apply the transformer to a pair of inputs and expected outputs in a dataset.

        :param inputs: the input tensor
        :param targets: the expected output tensor
        :return: the transformed ``(inputs, targets)`` pair
        """
        return inputs, targets


class SequenceTransform(Transform):
    """
    A transformer that applies a sequence of transformers sequentially.

    The ``(inputs, targets)`` pair returned by each transformer is passed on to the
    next one; the pair returned by the last transformer is the result.
    """

    def __init__(self, transforms: List[Transform]):
        """
        Initialize a sequence transformer.

        :param transforms: the transformers to apply, in order
        """
        self.transforms = transforms

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Apply all transformers in order to the pair of inputs and expected outputs.
        """
        for transform in self.transforms:
            inputs, targets = transform(inputs, targets)
        return inputs, targets


class ScaleTransform(Transform):
    """
    A transformer that scales the inputs and outputs by pre-defined factors.
    """

    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
        """
        Initialize a scale transformer.

        :param scale_inputs: the factor by which the inputs are multiplied
        :param scale_outputs: the factor by which the outputs are multiplied
        """
        self.scale_inputs = scale_inputs
        self.scale_outputs = scale_outputs

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Scale the inputs and the outputs by the factors defined in the constructor.
        """
        return inputs * self.scale_inputs, targets * self.scale_outputs


def _dim_index(tensor: Tensor, dim: int) -> int:
    """
    Convert a dimension counted from behind (1 = last dimension) into a regular index.

    :param tensor: the tensor whose dimensions are addressed
    :param dim: the dimension counted from behind, as used by the transforms below
    :return: the corresponding index into ``tensor.shape``
    :raises IndexError: if ``dim`` does not address a dimension of ``tensor``
    """
    ndim = tensor.dim()
    # Without this check an out-of-range dim yields a negative index, which silently
    # addresses the wrong dimension instead of failing.
    if not 1 <= dim <= ndim:
        raise IndexError(
            f"dim (counted from behind) must be in [1, {ndim}] for a "
            f"{ndim}-dimensional tensor, got {dim}"
        )
    return ndim - dim


class PaddingTransform(Transform):
    """
    A transformer that pads a specified dimension of tensors so that they have at least a given size.

    Padding uses the constant value 0. When the amount of padding is odd, the extra
    element is added in front.

    :param dim: the dimension to be padded (from behind, see torch.nn.functional.pad)
    :param size: the size to which the dimension should be padded if it is smaller
    """

    def __init__(self, dim: int, size: int):
        self.dim = dim
        self.size = size

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Pad inputs and targets so that dimension self.dim has at least a size of self.size.
        """
        return self.pad(inputs), self.pad(targets)

    def pad(self, inputs: Tensor) -> Tensor:
        """
        Pad a single input tensor so that dimension self.dim has at least a size of self.size.

        :param inputs: the tensor to pad
        :return: ``inputs`` itself if it is already large enough, otherwise a padded copy
        :raises IndexError: if self.dim does not address a dimension of ``inputs``
        """
        shape_idx = _dim_index(inputs, self.dim)
        missing = self.size - inputs.shape[shape_idx]
        if missing <= 0:
            return inputs

        # F.pad expects (before, after) pairs starting at the LAST dimension, so the
        # pair for the target dimension is the final one of 2 * dim entries.
        padding = [0] * (2 * self.dim)
        padding[-2] = math.ceil(missing / 2)
        padding[-1] = math.floor(missing / 2)
        return torch.nn.functional.pad(inputs, padding, mode="constant", value=0)


class CroppingTransform(Transform):
    """
    A transformer that crops a specified dimension of tensors so that they have at most the given size.

    The crop keeps a window of the requested size around the center of the dimension.

    :param dim: the dimension to be cropped (from behind, see PaddingTransform)
    :param size: the size to which the dimension should be cropped if it is larger
    """

    def __init__(self, dim: int, size: int):
        self.dim = dim
        self.size = size

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Crop inputs and targets so that dimension self.dim has at most a size of self.size.
        """
        return self.crop(inputs), self.crop(targets)

    def crop(self, inputs: Tensor) -> Tensor:
        """
        Crop a single input tensor so that dimension self.dim has at most a size of self.size.

        :param inputs: the tensor to crop
        :return: ``inputs`` itself if it is already small enough, otherwise a view
            of the central window along the cropped dimension
        :raises IndexError: if self.dim does not address a dimension of ``inputs``
        """
        shape_idx = _dim_index(inputs, self.dim)
        length = inputs.shape[shape_idx]
        if length <= self.size:
            return inputs

        # Window [center - ceil(size/2), center + floor(size/2)); narrow() avoids
        # indexing with a list of slices, which relies on legacy sequence-as-tuple rules.
        start = length // 2 - math.ceil(self.size / 2)
        return inputs.narrow(shape_idx, start, self.size)


class PadCropTranform(SequenceTransform):
    """
    A combination of padding and cropping that makes sure that a specified dimension always has a given size.

    :param dim: the dimension to be padded and cropped (from behind, see PaddingTransform)
    :param size: the size to which the dimension should be padded and cropped
    """

    def __init__(self, dim: int, size: int):
        super().__init__(
            transforms=[PaddingTransform(dim, size), CroppingTransform(dim, size)]
        )


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
PadCropTransform = PadCropTranform


if __name__ == "__main__":
    transform = PadCropTranform(dim=3, size=32)

    # A depth of 29 gets padded up to 32; a depth of 45 gets cropped down to 32.
    for depth in (29, 45):
        shape = (8, 1, depth, 128, 128)
        inputs, targets = transform(torch.rand(shape), torch.rand(shape))
        assert inputs.shape[2] == 32
        assert targets.shape[2] == 32
        print(inputs.shape)