Commit 9daaf276 authored by Tamino Huxohl

implement padding and cropping transforms

parent 5f983873
+import math
 from typing import List, Tuple

+import torch
 from torch import Tensor
@@ -11,13 +12,11 @@ class Transform:
     used for normalization and data augmentation.
     """

-    def __call__(
-        self, inputs: Tensor, outputs_expected: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Apply the transformer to a pair of inputs and expected outputs in a dataset.
         """
-        return inputs, outputs_expected
+        return inputs, targets


 class SequenceTransform(Transform):
@@ -28,12 +27,10 @@ class SequenceTransform(Transform):
     def __init__(self, transforms: List[Transform]):
         self.transforms = transforms

-    def __call__(
-        self, inputs: Tensor, outputs_expected: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
         for transforms in self.transforms:
-            inputs, outputs_expected = transforms(inputs, outputs_expected)
-        return inputs, outputs_expected
+            inputs, targets = transforms(inputs, targets)
+        return inputs, targets

 class ScaleTransform(Transform):
@@ -51,10 +48,120 @@ class ScaleTransform(Transform):
         self.scale_inputs = scale_inputs
         self.scale_outputs = scale_outputs

-    def __call__(
-        self, inputs: Tensor, outputs_expected: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
         """
         Scale the inputs and the outputs by the factors defined in the constructor.
         """
-        return inputs * self.scale_inputs, outputs_expected * self.scale_outputs
+        return inputs * self.scale_inputs, targets * self.scale_outputs
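
Taken together, Transform, SequenceTransform, and ScaleTransform form a small composition API: every transform maps an (inputs, targets) pair to a new pair, and SequenceTransform chains them. A minimal usage sketch (not part of the commit; the scale factor mirrors the 40206.0 constant from the evaluation script below, and ScaleTransform's constructor is assumed to accept scale_inputs/scale_outputs keyword arguments):

    # normalize the inputs by an assumed dataset maximum, leave targets untouched
    normalize = ScaleTransform(scale_inputs=1.0 / 40206.0, scale_outputs=1.0)
    pipeline = SequenceTransform(transforms=[normalize])
    inputs, targets = pipeline(torch.rand(1, 1, 29, 128, 128), torch.rand(1, 1, 29, 128, 128))
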
+
+
+class PaddingTransform(Transform):
+    """
+    A transformer that pads a specified dimension of tensors
+    so that they have at least a given size.
+
+    :param dim: the dimension to be padded (from behind, see torch.nn.functional.pad)
+    :param size: the size to which the dimension should be padded if it is smaller
+    """
+
+    def __init__(self, dim: int, size: int):
+        self.dim = dim
+        self.size = size
+
+    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Pad inputs and targets so that dimension self.dim has at
+        least a size of self.size.
+        """
+        return self.pad(inputs), self.pad(targets)
+
+    def pad(self, inputs: Tensor):
+        """
+        Pad a single input tensor so that dimension self.dim has at
+        least a size of self.size.
+        """
+        shape_idx = len(inputs.shape) - self.dim
+        if inputs.shape[shape_idx] >= self.size:
+            return inputs
+
+        diff_half = (self.size - inputs.shape[shape_idx]) / 2
+        padding = [0] * 2 * self.dim
+        padding[-2] = math.ceil(diff_half)
+        padding[-1] = math.floor(diff_half)
+        return torch.nn.functional.pad(inputs, padding, mode="constant", value=0)
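
To make the index arithmetic concrete: torch.nn.functional.pad consumes the padding list from the last dimension backwards, so the final pair of the 2*dim-long list addresses exactly the dim-th dimension from behind. A quick check, reusing the 29-slice shape from the self-test at the bottom of this file (a sketch, assuming the classes above are in scope):

    t = torch.rand(8, 1, 29, 128, 128)
    padded = PaddingTransform(dim=3, size=32).pad(t)
    # diff_half = (32 - 29) / 2 = 1.5, so the list becomes [0, 0, 0, 0, 2, 1]:
    # 2 zeros before and 1 zero after along shape index 5 - 3 = 2
    assert padded.shape == (8, 1, 32, 128, 128)
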
+
+
+class CroppingTransform(Transform):
+    """
+    A transformer that crops a specified dimension of tensors
+    so that they have at most the given size.
+
+    :param dim: the dimension to be cropped (from behind, see PaddingTransform)
+    :param size: the size to which the dimension should be cropped if it is larger
+    """
+
+    def __init__(self, dim: int, size: int):
+        self.dim = dim
+        self.size = size
+
+    def __call__(self, inputs: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
+        """
+        Crop inputs and targets so that dimension self.dim has at
+        most a size of self.size.
+        """
+        return self.crop(inputs), self.crop(targets)
+
+    def crop(self, inputs: Tensor):
+        """
+        Crop a single input tensor so that dimension self.dim has at
+        most a size of self.size.
+        """
+        shape_idx = len(inputs.shape) - self.dim
+        if inputs.shape[shape_idx] <= self.size:
+            return inputs
+
+        # create slices selecting everything up to shape_idx
+        slices = map(slice, inputs.shape[:shape_idx])
+        slices = list(slices)
+
+        # add slice which performs the crop on the specified dimension
+        center = inputs.shape[shape_idx] // 2
+        size_half = self.size / 2
+        slices.append(
+            slice(center - math.ceil(size_half), center + math.floor(size_half))
+        )
+        return inputs[tuple(slices)]
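
The crop takes a centered window: for 45 slices, center = 45 // 2 = 22 and size_half = 16.0, so the appended slice is 6:38. This can be cross-checked against torch.narrow (a sketch, not part of the commit):

    t = torch.rand(8, 1, 45, 128, 128)
    cropped = CroppingTransform(dim=3, size=32).crop(t)
    assert cropped.shape == (8, 1, 32, 128, 128)
    # the same centered window, taken with torch.narrow on shape index 2
    assert torch.equal(cropped, t.narrow(2, 6, 32))
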
+
+
+class PadCropTranform(SequenceTransform):
+    """
+    A combination of padding and cropping that makes sure that a
+    specified dimension always has a given size.
+
+    :param dim: the dimension to be padded and cropped (from behind, see PaddingTransform)
+    :param size: the size to which the dimension should be padded and cropped
+    """
+
+    def __init__(self, dim: int, size: int):
+        super().__init__(
+            transforms=[PaddingTransform(dim, size), CroppingTransform(dim, size)]
+        )
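
The order inside the sequence matters: padding runs first and only grows undersized dimensions, so the crop that follows always sees a dimension of at least `size` and trims it to exactly `size`. An input that already has the target size passes through both steps unchanged (illustrative check):

    t = torch.rand(8, 1, 32, 128, 128)
    out, _ = PadCropTranform(dim=3, size=32)(t, t.clone())
    assert out.shape == t.shape and torch.equal(out, t)
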
+
+
+if __name__ == "__main__":
+    transform = PadCropTranform(dim=3, size=32)
+
+    # undersized dimension: padded from 29 up to 32
+    shape = (8, 1, 29, 128, 128)
+    inputs = torch.rand(shape)
+    targets = torch.rand(shape)
+    inputs, targets = transform(inputs, targets)
+    assert inputs.shape[2] == 32
+    assert targets.shape[2] == 32
+    print(inputs.shape)
+
+    # oversized dimension: cropped from 45 down to 32
+    shape = (8, 1, 45, 128, 128)
+    inputs = torch.rand(shape)
+    targets = torch.rand(shape)
+    inputs, targets = transform(inputs, targets)
+    assert inputs.shape[2] == 32
+    assert targets.shape[2] == 32
+    print(inputs.shape)
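
For context, this is how such a transform would typically be attached to a dataset so that every sample is padded/cropped on access; the wrapper below is a hypothetical sketch, not something introduced by this commit:

    from torch.utils.data import Dataset

    class TransformedDataset(Dataset):
        """Wrap a (recon, mu_map) dataset and apply a Transform per item."""

        def __init__(self, dataset, transform: Transform = Transform()):
            self.dataset = dataset
            self.transform = transform

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            recon, mu_map = self.dataset[idx]
            return self.transform(recon, mu_map)
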
@@ -15,7 +15,7 @@ dataset = MuMapMockDataset("data/initial/")
 model = UNet(in_channels=1, features=[8, 16])
 device = torch.device("cpu")
-weights = torch.load("tmp/10.pth", map_location=device)
+weights = torch.load("train_data/snapshots/10.pth", map_location=device)
 model.load_state_dict(weights)
 model = model.eval()
@@ -24,7 +24,7 @@ recon = recon.unsqueeze(dim=0)
 recon = norm_max(recon)
 output = model(recon)
-output = output * 40206.0
+# output = output * 40206.0
 diff = ((mu_map - output) ** 2).mean()
 print(f"Diff: {diff:.3f}")
@@ -66,78 +66,3 @@ while True:
     i = (i + 1) % output.shape[0]
-# dataset = MuMapDataset("data/initial")
-# # print(" Recon || MuMap")
-# # print(" Min | Max | Average || Min | Max | Average")
-# r_max = []
-# r_avg = []
-# r_max_p = []
-# r_avg_p = []
-# r_avg_x = []
-# m_max = []
-# for recon, mu_map in dataset:
-#     r_max.append(recon.max())
-#     r_avg.append(recon.mean())
-#     recon_p = recon[:, :, 16:112, 16:112]
-#     r_max_p.append(recon_p.max())
-#     r_avg_p.append(recon_p.mean())
-#     r_avg_x.append(recon.sum() / (recon > 0.0).sum())
-#     # r_min = f"{recon.min():5.3f}"
-#     # r_max = f"{recon.max():5.3f}"
-#     # r_avg = f"{recon.mean():5.3f}"
-#     # m_min = f"{mu_map.min():5.3f}"
-#     # m_max = f"{mu_map.max():5.3f}"
-#     # m_avg = f"{mu_map.mean():5.3f}"
-#     # print(f"{r_min} | {r_max} | {r_avg} || {m_min} | {m_max} | {m_avg}")
-#     m_max.append(mu_map.max())
-#     # print(mu_map.max())
-# r_max = np.array(r_max)
-# r_avg = np.array(r_avg)
-# r_max_p = np.array(r_max_p)
-# r_avg_p = np.array(r_avg_p)
-# r_avg_x = np.array(r_avg_x)
-# m_max = np.array(m_max)
-# fig, ax = plt.subplots()
-# ax.scatter(r_max, m_max)
-# # fig, axs = plt.subplots(4, 3, figsize=(16, 12))
-# # axs[0, 0].hist(r_max)
-# # axs[0, 0].set_title("Max")
-# # axs[1, 0].hist(r_avg)
-# # axs[1, 0].set_title("Mean")
-# # axs[2, 0].hist(r_max / r_avg)
-# # axs[2, 0].set_title("Max / Mean")
-# # axs[3, 0].hist(recon.flatten())
-# # axs[3, 0].set_title("Example Reconstruction")
-# # axs[0, 1].hist(r_max_p)
-# # axs[0, 1].set_title("Max")
-# # axs[1, 1].hist(r_avg_p)
-# # axs[1, 1].set_title("Mean")
-# # axs[2, 1].hist(r_max_p / r_avg_p)
-# # axs[2, 1].set_title("Max / Mean")
-# # axs[3, 1].hist(recon_p.flatten())
-# # axs[3, 1].set_title("Example Reconstruction")
-# # axs[0, 2].hist(r_max_p)
-# # axs[0, 2].set_title("Max")
-# # axs[1, 2].hist(r_avg_x)
-# # axs[1, 2].set_title("Mean")
-# # axs[2, 2].hist(r_max_p / r_avg_x)
-# # axs[2, 2].set_title("Max / Mean")
-# # axs[3, 2].hist(torch.masked_select(recon, (recon > 0.0)))
-# # axs[3, 2].set_title("Example Reconstruction")
-# plt.show()