# Newer
# Older
import math
from typing import List, Tuple

import torch
from torch import Tensor
class Transform:
    """
    Interface of a transformer. A transformer can be initialized and then applied to
    an input tensor and expected output tensor as returned by a dataset. It can be
    used for normalization and data augmentation.

    Subclasses must override :meth:`__call__`.
    """

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Apply the transformer to a pair of inputs and expected outputs in a dataset.

        :param inputs: the input tensor
        :param targets: the expected output tensor
        :return: the transformed ``(inputs, targets)`` pair
        :raises NotImplementedError: always; concrete subclasses provide the logic
        """
        # Fail loudly instead of silently returning None when a subclass
        # forgets to implement the transformation.
        raise NotImplementedError(
            f"{type(self).__name__} does not implement __call__"
        )
class SequenceTransform(Transform):
    """
    A transformer that applies a sequence of transformers sequentially.

    Each transformer receives the ``(inputs, targets)`` pair produced by the
    previous one; the pair returned by the last transformer is the result.
    """

    def __init__(self, transforms: List[Transform]):
        """
        Initialize a sequence transformer.

        :param transforms: the transformers to apply, in order
        """
        self.transforms = transforms

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Apply every stored transformer in order, chaining their outputs.

        An empty sequence returns the inputs and targets unchanged.
        """
        for transform in self.transforms:
            inputs, targets = transform(inputs, targets)
        return inputs, targets
class ScaleTransform(Transform):
    """
    A transformer that scales the inputs and outputs by pre-defined factors.
    """

    def __init__(self, scale_inputs: float = 1.0, scale_outputs: float = 1.0):
        """
        Initialize a scale transformer.

        :param scale_inputs: the scale multiplied to the inputs
        :param scale_outputs: the scale multiplied to the outputs
        """
        self.scale_inputs = scale_inputs
        self.scale_outputs = scale_outputs

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Scale the inputs and the outputs by the factors defined in the constructor.

        Uses out-of-place multiplication, so the argument tensors are not modified.

        :return: ``(inputs * scale_inputs, targets * scale_outputs)``
        """
        return inputs * self.scale_inputs, targets * self.scale_outputs
class PaddingTransform(Transform):
    """
    A transformer that pads a specified dimension of tensors
    so that they have at least a given size.

    The padding is zero-valued and split between both ends of the dimension;
    when the total amount is odd, the extra element is added in front.

    :param dim: the dimension to be padded, counted from the back
        (1 is the last dimension; see torch.nn.functional.pad)
    :param size: the size to which the dimension should be padded if it is smaller
    """

    def __init__(self, dim: int, size: int):
        self.dim = dim
        self.size = size

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Pad inputs and targets so that dimension self.dim has at
        least a size of self.size.
        """
        return self.pad(inputs), self.pad(targets)

    def pad(self, inputs: Tensor) -> Tensor:
        """
        Pad a single input tensor so that dimension self.dim has at
        least a size of self.size.

        :param inputs: the tensor to pad
        :return: ``inputs`` itself if it is already large enough, otherwise a
            new zero-padded tensor
        :raises ValueError: if ``self.dim`` does not address an existing
            dimension of ``inputs`` (valid range is 1..inputs.dim())
        """
        ndim = len(inputs.shape)
        # Without this check an out-of-range dim would either wrap around via
        # negative indexing (inspecting the wrong axis) or fail obscurely.
        if not 1 <= self.dim <= ndim:
            raise ValueError(
                f"dim must be in [1, {ndim}] for a {ndim}-D tensor, got {self.dim}"
            )
        shape_idx = ndim - self.dim
        current = inputs.shape[shape_idx]
        if current >= self.size:
            return inputs
        diff_half = (self.size - current) / 2
        # F.pad takes (before, after) pairs starting at the LAST dimension, so
        # the final pair of a 2*dim list addresses dimension `dim` from the back.
        padding = [0] * (2 * self.dim)
        padding[-2] = math.ceil(diff_half)
        padding[-1] = math.floor(diff_half)
        return torch.nn.functional.pad(inputs, padding, mode="constant", value=0)
class CroppingTransform(Transform):
    """
    A transformer that crops a specified dimension of tensors
    so that they have at most the given size.

    The crop window is centered on the dimension; for an odd amount of removed
    elements the window is biased the same way PaddingTransform biases padding.

    :param dim: the dimension to be cropped, counted from the back
        (1 is the last dimension; see PaddingTransform)
    :param size: the size to which the dimension should be cropped if it is larger
    """

    def __init__(self, dim: int, size: int):
        self.dim = dim
        self.size = size

    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Crop inputs and targets so that dimension self.dim has at
        most a size of self.size.
        """
        return self.crop(inputs), self.crop(targets)

    def crop(self, inputs: Tensor) -> Tensor:
        """
        Crop a single input tensor so that dimension self.dim has at
        most a size of self.size.

        :param inputs: the tensor to crop
        :return: ``inputs`` itself if it is already small enough, otherwise a
            sliced view keeping the central ``self.size`` elements
        :raises ValueError: if ``self.dim`` does not address an existing
            dimension of ``inputs`` (valid range is 1..inputs.dim())
        """
        ndim = len(inputs.shape)
        # Guard against negative-index wrap-around selecting the wrong axis.
        if not 1 <= self.dim <= ndim:
            raise ValueError(
                f"dim must be in [1, {ndim}] for a {ndim}-D tensor, got {self.dim}"
            )
        shape_idx = ndim - self.dim
        if inputs.shape[shape_idx] <= self.size:
            return inputs
        center = inputs.shape[shape_idx] // 2
        size_half = self.size / 2
        window = slice(center - math.ceil(size_half), center + math.floor(size_half))
        # Take everything on the leading dimensions, crop the target one, and
        # leave trailing dimensions untouched. A tuple is used because indexing
        # with a list of slices is deprecated legacy behaviour.
        index = (slice(None),) * shape_idx + (window,)
        return inputs[index]
class PadCropTranform(SequenceTransform):
    """
    A combination of padding and cropping that makes sure that a
    specified dimension always has a given size.

    Padding is applied first, then cropping, so smaller dimensions are
    zero-padded up to ``size`` and larger ones are center-cropped down to it.

    :param dim: the dimension to be padded and cropped, counted from the back
        (see PaddingTransform)
    :param size: the size to which the dimension should be padded and cropped
    """

    def __init__(self, dim: int, size: int):
        super().__init__(
            transforms=[PaddingTransform(dim, size), CroppingTransform(dim, size)]
        )


# Correctly spelled alias; the original (misspelled) name is kept for
# backward compatibility with existing callers.
PadCropTransform = PadCropTranform
if __name__ == "__main__":
    # Smoke test: a depth below the target size must be padded up and a
    # depth above it must be cropped down, both ending at exactly 32.
    transform = PadCropTranform(dim=3, size=32)
    for depth in (29, 45):
        shape = (8, 1, depth, 128, 128)
        inputs = torch.rand(shape)
        targets = torch.rand(shape)
        inputs, targets = transform(inputs, targets)
        # dim=3 counted from the back of a 5-D shape is positional index 2.
        assert inputs.shape[2] == 32
        assert targets.shape[2] == 32
        print(inputs.shape)