# transform.py
# Authored by Tamino Huxohl
import math
from typing import List, Tuple
import torch
from torch import Tensor
class Transform:
    """
    Base interface for transformers.

    A transformer receives the input tensor and the expected output (target)
    tensor of a single dataset sample and returns a (possibly modified) pair.
    Subclasses override ``__call__`` to implement normalization or data
    augmentation; this base implementation is the identity.
    """

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Apply the transformer to an input/target pair.

        :param inputs: the input tensor of a dataset sample
        :param targets: the expected output tensor of the same sample
        :return: the pair ``(inputs, targets)``, unchanged
        """
        result = (inputs, targets)
        return result
class SequenceTransform(Transform):
    """
    A transformer that chains several transformers.

    The transformers are applied in list order; each one receives the
    input/target pair produced by its predecessor.
    """

    def __init__(self, transforms: List[Transform]):
        """
        Initialize a sequence transformer.

        :param transforms: the transformers to apply, in order of application
        """
        self.transforms = transforms

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Feed the pair through every transformer in ``self.transforms``.

        :return: the pair produced by the last transformer, or the original
            pair if the sequence is empty
        """
        current_inputs, current_targets = inputs, targets
        for step in self.transforms:
            current_inputs, current_targets = step(current_inputs, current_targets)
        return current_inputs, current_targets
class ScaleTransform(Transform):
    """
    A transformer that multiplies inputs and targets by fixed factors.
    """

    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
        """
        Initialize a scale transformer.

        :param scale_inputs: factor the inputs are multiplied by
        :param scale_outputs: factor the targets (expected outputs) are multiplied by
        """
        self.scale_inputs = scale_inputs
        self.scale_outputs = scale_outputs

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Scale inputs and targets by the factors given at construction.

        :return: the pair ``(inputs * scale_inputs, targets * scale_outputs)``
        """
        scaled_inputs = inputs * self.scale_inputs
        scaled_targets = targets * self.scale_outputs
        return scaled_inputs, scaled_targets
class PaddingTransform(Transform):
    """
    A transformer that pads a specified dimension of tensors
    so that they have at least a given size.

    Padding uses zeros and is split as evenly as possible between both sides;
    if the amount is odd, the extra element is added in front.

    :param dim: the dimension to be padded, counted from behind and starting
        at 1 (``dim=1`` is the last dimension, ``dim=2`` the second to last,
        ...; see torch.nn.functional.pad)
    :param size: the size to which the dimension should be padded if it is smaller
    """

    def __init__(self, dim: int, size: int):
        self.dim = dim
        self.size = size

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Pad inputs and targets so that dimension self.dim has at
        least a size of self.size.
        """
        return self.pad(inputs), self.pad(targets)

    def pad(self, inputs: Tensor) -> Tensor:
        """
        Pad a single input tensor so that dimension self.dim has at
        least a size of self.size.

        :param inputs: the tensor to pad; must have at least ``self.dim`` dimensions
        :return: ``inputs`` itself if the dimension is already large enough,
            otherwise a new zero-padded tensor
        :raises IndexError: if ``self.dim`` is not in ``[1, number of dimensions]``
        """
        n_dims = len(inputs.shape)
        # Without this check an out-of-range dim turns into a negative index and
        # silently inspects a different dimension than the one requested.
        if not 1 <= self.dim <= n_dims:
            raise IndexError(
                f"Cannot pad dimension {self.dim} (counted from behind) of a "
                f"tensor with {n_dims} dimensions"
            )
        shape_idx = n_dims - self.dim

        missing = self.size - inputs.shape[shape_idx]
        if missing <= 0:
            return inputs

        # F.pad takes (before, after) pairs starting at the LAST dimension, so
        # the pair for self.dim is the final one in a list of 2 * dim values.
        padding = [0] * (2 * self.dim)
        padding[-2] = math.ceil(missing / 2)
        padding[-1] = missing // 2
        return torch.nn.functional.pad(inputs, padding, mode="constant", value=0)
class CroppingTransform(Transform):
    """
    A transformer that crops a specified dimension of tensors
    so that they have at most the given size.

    The crop keeps a window of ``size`` elements around the center of the
    dimension.

    :param dim: the dimension to be cropped, counted from behind and starting
        at 1 (``dim=1`` is the last dimension; see PaddingTransform)
    :param size: the size to which the dimension should be cropped if it is larger
    """

    def __init__(self, dim: int, size: int):
        self.dim = dim
        self.size = size

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Crop inputs and targets so that dimension self.dim has at
        most a size of self.size.
        """
        return self.crop(inputs), self.crop(targets)

    def crop(self, inputs: Tensor) -> Tensor:
        """
        Crop a single input tensor so that dimension self.dim has at
        most a size of self.size.

        :param inputs: the tensor to crop; must have at least ``self.dim`` dimensions
        :return: ``inputs`` itself if the dimension is already small enough,
            otherwise a view of the centered window
        :raises IndexError: if ``self.dim`` is not in ``[1, number of dimensions]``
        """
        n_dims = len(inputs.shape)
        # Without this check an out-of-range dim turns into a negative index and
        # silently crops a different dimension than the one requested.
        if not 1 <= self.dim <= n_dims:
            raise IndexError(
                f"Cannot crop dimension {self.dim} (counted from behind) of a "
                f"tensor with {n_dims} dimensions"
            )
        shape_idx = n_dims - self.dim

        length = inputs.shape[shape_idx]
        if length <= self.size:
            return inputs

        # Centered window: start = center - ceil(size / 2), length exactly size.
        center = length // 2
        start = center - math.ceil(self.size / 2)
        stop = start + self.size

        # Take everything along the leading dimensions; trailing dimensions are
        # implicitly kept in full. Use a tuple index: indexing with a list of
        # slices relies on a legacy NumPy-compat path.
        index = (slice(None),) * shape_idx + (slice(start, stop),)
        return inputs[index]
class PadCropTranform(SequenceTransform):
    """
    A combination of padding and cropping that makes sure that a
    specified dimension always has a given size.

    Padding runs first, so a smaller dimension is zero-padded up to ``size``;
    cropping then trims a larger dimension to a centered window of ``size``.

    :param dim: the dimension to be padded and cropped, counted from behind
        and starting at 1 (see PaddingTransform)
    :param size: the size to which the dimension should be padded and cropped
    """

    def __init__(self, dim: int, size: int):
        super().__init__(
            transforms=[PaddingTransform(dim, size), CroppingTransform(dim, size)]
        )


# Correctly spelled alias. The misspelled class name above is kept so that
# existing callers keep working.
PadCropTransform = PadCropTranform
if __name__ == "__main__":
    # Smoke test: a depth below (29) and a depth above (45) the target size
    # must both end up with exactly 32 entries along dimension 3 from behind,
    # which is index 2 of these 5-dimensional shapes.
    transform = PadCropTranform(dim=3, size=32)
    for depth in (29, 45):
        shape = (8, 1, depth, 128, 128)
        inputs = torch.rand(shape)
        targets = torch.rand(shape)
        inputs, targets = transform(inputs, targets)
        assert inputs.shape[2] == 32
        assert targets.shape[2] == 32
        print(inputs.shape)