import os
from typing import Union
from zlib import compress

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
import torch.nn.functional as F
from PIL import Image
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchvision import transforms
from torchvision.transforms import functional as TF
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image

# fall back to CPU when no GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
matplotlib.use("Agg")

dx = [+1, 0, -1, 0]
dy = [0, +1, 0, -1]
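# Together, dx/dy enumerate the 4-connected neighborhood of a pixel:
# (x+1, y), (x, y+1), (x-1, y), (x, y-1). Diagonally touching pixels are
# therefore treated as belonging to separate components.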
# perform depth first search for each candidate/unlabeled region
# reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation
def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int):
    n_rows, n_cols = mask.shape
    if x < 0 or x == n_rows:
        return
    if y < 0 or y == n_cols:
        return
    if labels[x][y] or not mask[x][y]:
        return  # already labeled or not marked with 1 in image

    # mark the current cell
    labels[x][y] = current_label

    # recursively mark the neighbors
    for direction in range(4):
        dfs(mask, x + dx[direction], y + dy[direction], labels, current_label)

def find_components(mask: npt.NDArray):
    label = 0
    n_rows, n_cols = mask.shape
    # int32 rather than int8, so masks with more than 127 components cannot
    # overflow the label map
    labels = np.zeros(mask.shape, dtype=np.int32)
    for i in range(n_rows):
        for j in range(n_cols):
            if not labels[i][j] and mask[i][j]:
                label += 1
                dfs(mask, i, j, labels, label)

    return labels

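# Illustrative example (not part of the pipeline): under 4-connectivity two
# diagonally touching pixels form two separate components:
#   find_components(np.array([[1, 0], [0, 1]]))  ->  [[1, 0], [0, 2]]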
# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
def bbox(img):
    max_x, max_y = img.shape
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

    # pad the box by one pixel on each side where it does not touch the
    # image border (max index is max_x - 1, hence the - 1 in the guard)
    rmin = rmin - 1 if rmin > 0 else rmin
    cmin = cmin - 1 if cmin > 0 else cmin
    rmax = rmax + 1 if rmax < max_x - 1 else rmax
    cmax = cmax + 1 if cmax < max_y - 1 else cmax

    return rmin, rmax, cmin, cmax

def extract_single_masks(labels: npt.NDArray):
    masks = []
    # label 0 is the unlabeled background, so component labels start at 1
    for l in range(1, labels.max() + 1):
        mask = (labels == l).astype(np.int8)
        rmin, rmax, cmin, cmax = bbox(mask)
        masks.append(mask[rmin : rmax + 1, cmin : cmax + 1])

    return masks

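# Each entry of `masks` is a tight, one-pixel-padded crop of one connected
# component; the crops have varying sizes and are resized to 64x64 before
# being fed to the models (see test_mask below).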
class VAE(nn.Module):
    """
    https://github.com/pytorch/examples/blob/main/vae/main.py
    """

    def __init__(self, bottleneck=2, image_dim=4096):
        super(VAE, self).__init__()

        self.bottleneck = bottleneck
        self.image_dim = image_dim
        self.prelim_encode = nn.Sequential(
            nn.Flatten(), nn.Linear(image_dim, 400), nn.ReLU()
        )
        self.encode_mu = nn.Sequential(nn.Linear(400, bottleneck))
        self.encode_logvar = nn.Sequential(nn.Linear(400, bottleneck))
        self.decode = nn.Sequential(
            nn.Linear(bottleneck, 400),
            nn.ReLU(),
            nn.Linear(400, image_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        # h1 = F.relu(self.encode(x))
        # return self.encode_mu(h1), self.encode_logvar(h1)
        x = self.prelim_encode(x)
        return self.encode_mu(x), self.encode_logvar(x)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss(self, recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(
            recon_x, x.view(-1, self.image_dim), reduction="sum"
        )

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD

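# For a diagonal Gaussian posterior q(z|x) = N(mu, sigma^2) and a standard
# normal prior p(z) = N(0, I), the KL term has the closed form
#   KL(q || p) = -0.5 * sum_j (1 + log sigma_j^2 - mu_j^2 - sigma_j^2),
# which is exactly what both loss() implementations compute from mu and logvar.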
class CONVVAE(nn.Module):
    def __init__(
        self,
        bottleneck=2,
    ):
        super(CONVVAE, self).__init__()

        self.bottleneck = bottleneck
        self.feature_dim = 6 * 6 * 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), return_indices=True),  # -> 30x30x16
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), return_indices=True),  # -> 14x14x32
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), return_indices=True),  # -> 6x6x64
        )
        # self.conv4 = nn.Sequential(
        #     nn.Conv2d(32, self.bottleneck, 5),
        #     nn.ReLU(),
        #     nn.MaxPool2d((2, 2), return_indices=True),  # -> 1x1xbottleneck
        # )
        self.encode_mu = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.feature_dim, self.bottleneck),
        )
        self.encode_logvar = nn.Sequential(
            nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
        )
        self.decode_linear = nn.Linear(self.bottleneck, self.feature_dim)
        # self.decode4 = nn.Sequential(
        #     nn.ConvTranspose2d(self.bottleneck, 32, 5),
        #     nn.ReLU(),
        # )
        self.decode3 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3),
            nn.ReLU(),
        )
        self.decode2 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3),
            nn.ReLU(),
        )
        self.decode1 = nn.Sequential(
            nn.ConvTranspose2d(16, 1, 5),
            nn.Sigmoid(),
        )

    def encode(self, x):
        x, idx1 = self.conv1(x)
        x, idx2 = self.conv2(x)
        x, idx3 = self.conv3(x)
        # x, idx4 = self.conv4(x)
        mu = self.encode_mu(x)
        logvar = self.encode_logvar(x)
        return mu, logvar, (idx1, idx2, idx3)

    def decode(self, z: Tensor, indexes: tuple):
        (idx1, idx2, idx3) = indexes
        z = self.decode_linear(z)
        z = z.view((-1, 64, 6, 6))
        # z = F.max_unpool2d(z, idx4, (2, 2))
        # z = self.decode4(z)
        z = F.max_unpool2d(z, idx3, (2, 2))
        z = self.decode3(z)
        z = F.max_unpool2d(z, idx2, (2, 2))
        z = self.decode2(z)
        z = F.max_unpool2d(z, idx1, (2, 2))
        z = self.decode1(z)
        # z = z.view(-1, 128, 1, 1)
        # return self.decode_conv(z)
        return z

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar, indexes = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, indexes), mu, logvar

    def loss(self, recon_x, x, mu, logvar):
        """https://github.com/pytorch/examples/blob/main/vae/main.py"""
        BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD

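# Note that CONVVAE's decoder re-uses the encoder's max-pooling indices via
# F.max_unpool2d, so every reconstruction is conditioned on the input's
# pooling pattern: the model can reconstruct inputs, but it cannot sample
# unconditionally from the prior the way the plain VAE decoder can.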
def load_data():
    transform = transforms.Compose(
        [
            transforms.Grayscale(),
            transforms.Resize(
                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    )
    trajectories = [
        # "v3_subtle_iceberg_lettuce_nymph-6_203-2056",
        "v3_absolute_grape_changeling-16_2277-4441",
        "v3_content_squash_angel-3_16074-17640",
        "v3_smooth_kale_loch_ness_monster-1_4439-6272",
        "v3_cute_breadfruit_spirit-6_17090-19102",
        "v3_key_nectarine_spirit-2_7081-9747",
        "v3_subtle_iceberg_lettuce_nymph-6_3819-6049",
        "v3_juvenile_apple_angel-30_396415-398113",
        "v3_subtle_iceberg_lettuce_nymph-6_6100-8068",
    ]
    datasets = []
    for trj in trajectories:
        datasets.append(
            ImageFolder(
                f"activation_vis/out/critic/masks/{trj}/0/4", transform=transform
            )
        )

    dataset = torch.utils.data.ConcatDataset(datasets)
    data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
    return data_loader, dataset

def train(epoch, model: Union[VAE, CONVVAE], optimizer, data_loader, log_interval=40):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(data_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(data_loader.dataset),
                    100.0 * batch_idx / len(data_loader),
                    loss.item() / len(data),
                )
            )

    print(
        "====> Epoch: {} Average loss: {:.4f}".format(
            epoch, train_loss / len(data_loader.dataset)
        )
    )

def test(epoch, models, dataset):
    for model in models:
        model.eval()

    test_loss = [0 for _ in models]
    test_batch_size = 32
    sampler = RandomSampler(dataset, replacement=True, num_samples=64)
    test_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=sampler)
    comp_data = None
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            for j, model in enumerate(models):
                recon_batch, mu, logvar = model(data)
                test_loss[j] += model.loss(recon_batch, data, mu, logvar).item()
                if i == 0:
                    n = min(data.size(0), 20)
                    if comp_data is None:
                        comp_data = data[:n]

                    # -1 instead of test_batch_size: the last batch can be smaller
                    comp_data = torch.cat(
                        [comp_data, recon_batch.view(-1, 1, 64, 64)[:n]]
                    )

            if i == 0:
                if not os.path.exists("results"):
                    os.makedirs("results")

                save_image(
                    comp_data.cpu(),
                    "results/reconstruction_" + str(epoch) + ".png",
                    nrow=min(data.size(0), 20),
                )

    for i, model in enumerate(models):
        test_loss[i] /= len(test_loader.dataset)
        print(f"====> Test set loss model {i}: {test_loss[i]:.4f}")

def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
    model.eval()
    image = TF.to_tensor(TF.to_grayscale(Image.open(path)))
    labels = find_components(image[0])
    single_masks = extract_single_masks(labels)
    mask = TF.to_tensor(
        TF.resize(
            TF.to_pil_image((single_masks[label] * 255).astype(np.uint8)),
            (64, 64),
        )
    )
    with torch.no_grad():
        mask = mask.to(device)
        recon_x, _, _ = model(mask)
        recon_bits = recon_x.view(64, 64).cpu().numpy() > epsilon
        mask_bits = mask.cpu().numpy() > 0
        # pixelwise true/false positives and false negatives w.r.t. the input mask
        TP = (mask_bits & recon_bits).sum()
        FP = (recon_bits & ~mask_bits).sum()
        FN = (mask_bits & ~recon_bits).sum()
        prec = TP / (TP + FP)
        rec = TP / (TP + FN)
        # loss = pixelwise_loss(recon_x, mask)
        comp_data = torch.cat(
            [mask[0].cpu(), recon_x.view(64, 64).cpu(), torch.from_numpy(recon_bits)]
        )
        # print(f"mask loss: {loss:.4f}")

    return prec, rec, comp_data

def distance_measure(model: VAE, img: Tensor):
    model.eval()
    with torch.no_grad():
        mask = img.to(device)
        recon, mean, _ = model(mask)
        # TODO: apply threshold here?!
        _, recon_mean, _ = model(recon)
        distance = torch.norm(mean - recon_mean, p=2)

    return distance, make_grid(
        torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0
    )

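# This measure compares the latent mean of the input with the latent mean of
# its own reconstruction; the idea appears to be that shapes the model
# captures well re-encode to nearby latent codes, so a small re-encoding
# distance indicates a shape that is "simple" for this model.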
def compression_complexity(img: Tensor):
    np_img = img[0].numpy()
    # zlib expects a bytes-like object, so pass the raw array buffer explicitly
    compressed = compress(np_img.tobytes())
    return len(compressed)

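# The zlib-compressed length is a crude, computable stand-in for Kolmogorov
# complexity: highly regular shapes compress well and score low, irregular
# boundaries compress poorly and score high.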
def fft_measure(img: Tensor):
    np_img = img[0][0].numpy()
    freq = np.fft.fft2(np_img)
    magnitude = np.fft.fftshift(np.abs(freq))
    spectrum = np.log(1 + magnitude)

    M, N = np_img.shape
    total_freq_value = spectrum.sum()
    # central third of the shifted spectrum = lowest frequencies
    inner_sum = spectrum[M // 3 : 2 * (M // 3), N // 3 : 2 * (N // 3)].sum()

    return (total_freq_value - inner_sum) / total_freq_value

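# After fftshift, the DC/low-frequency content sits in the central third of
# the spectrum, so the returned value is the fraction of log-magnitude energy
# outside that block, i.e. the high-frequency share: close to 0 for smooth
# blobs, larger for shapes with fine, jagged detail.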
def complexity_measure(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
    epsilon=0.4,
    save_preliminary=False,
):
    model_gb.eval()
    model_lb.eval()
    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)
        recon_bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
        recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
        mask_bits = mask[0].cpu() > 0

        if save_preliminary:
            save_image(
                torch.stack(
                    [mask_bits.float(), recon_bits_gb.float(), recon_bits_lb.float()]
                ).cpu(),
                f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )
            save_image(
                torch.stack(
                    [
                        (mask_bits & recon_bits_gb).float(),
                        (recon_bits_gb & ~mask_bits).float(),
                        (mask_bits & recon_bits_lb).float(),
                        (recon_bits_lb & ~mask_bits).float(),
                    ]
                ).cpu(),
                f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )

        tp_gb = (mask_bits & recon_bits_gb).sum()
        fp_gb = (recon_bits_gb & ~mask_bits).sum()
        tp_lb = (mask_bits & recon_bits_lb).sum()
        fp_lb = (recon_bits_lb & ~mask_bits).sum()
        prec_gb = tp_gb / (tp_gb + fp_gb)
        prec_lb = tp_lb / (tp_lb + fp_lb)

        complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
        complexity_lb = 1 - prec_lb
        complexity_gb = 1 - prec_gb
        # e.g. prec_gb = 0.4, prec_lb = 0.7:
        # 1 - (0.4 - abs(0.4 - 0.7)) = 0.9
        # 1 - 0.7 = 0.3

    return (
        complexity,
        complexity_lb,
        complexity_gb,
        prec_gb - prec_lb,
        prec_lb,
        prec_gb,
        make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        ),
    )

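# Interpretation of the combined score: with prec_gb the precision of the
# larger-bottleneck model and prec_lb that of the smaller one,
#   complexity = 1 - (prec_gb - |prec_gb - prec_lb|)
# is low only when the large bottleneck reconstructs the shape precisely AND
# the small bottleneck does nearly as well (a simple shape); it grows when
# even the large bottleneck struggles or when the two models disagree.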
def complexity_measure_diff(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
):
    model_gb.eval()
    model_lb.eval()
    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)
        diff = torch.abs((recon_gb - recon_lb).cpu().sum())

    return (
        diff,
        make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        ),
    )

def plot_samples(masks: Tensor, complexities: npt.NDArray):
    dpi = 150
    rows = cols = 20
    total = rows * cols
    n_samples, _, y, x = masks.shape
    extent = (0, x - 1, 0, y - 1)
    if total != n_samples:
        raise Exception("shape mismatch")

    fig = plt.figure(figsize=(32, 16), dpi=dpi)
    for idx in np.arange(n_samples):
        ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])
        plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
        ax.set_title(
            f"{complexities[idx]:.4f}",
            fontdict={"fontsize": 6, "color": "orange"},
            y=0.35,
        )

    fig.patch.set_facecolor("#292929")
    height_px = y * rows
    width_px = x * cols
    fig.set_size_inches(width_px / (dpi / 2), height_px / (dpi / 2), forward=True)
    fig.tight_layout(pad=0)
    return fig

def visualize_sort_mean(data_loader: DataLoader, model: VAE):
    recon_masks = torch.zeros((400, 3, 64, 128))
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        distance, mask_recon_grid = distance_measure(model, mask)
        masks[i] = mask[0]
        recon_masks[i] = mask_recon_grid
        distances[i] = distance

    sort_idx = torch.argsort(distances)
    recon_masks_sorted = recon_masks.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("latent mean L2 distance")
    plt.savefig("shape_complexity/results/distance_plot.png")
    plt.clf()  # avoid leaking this line plot into later figures

    return plot_samples(masks_sorted, distances.numpy()[sort_idx]), plot_samples(
        recon_masks_sorted, distances.numpy()[sort_idx]
    )

def visualize_sort_compression(data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = compression_complexity(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("compression length")
    plt.savefig("shape_complexity/results/compression_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])

def visualize_sort_fft(data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = fft_measure(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("high-frequency ratio")
    plt.savefig("shape_complexity/results/fft_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])

def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    diffs = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        diff, mask_recon_grid = complexity_measure_diff(model_gb, model_lb, mask)
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff

    sort_idx = np.argsort(np.array(diffs))
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(diffs)), np.sort(diffs))
    plt.xlabel("images")
    plt.ylabel("pixelwise difference of reconstructions")
    plt.savefig("shape_complexity/results/px_diff_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, diffs[sort_idx]), plot_samples(
        recon_masks_sorted, diffs[sort_idx]
    )

def visualize_sort_3dim(
    data_loader: DataLoader, model_gb: nn.Module, model_lb: nn.Module
):
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    measures = torch.zeros((400, 3))
    for i, (mask, _) in enumerate(data_loader, 0):
        c_compress = compression_complexity(mask)
        c_fft = fft_measure(mask)
        # TODO: maybe exchange by diff measure instead of precision
        c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask
        )
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        measures[i] = torch.tensor([c_compress, c_fft, c_vae])

    measures[:] /= measures.max(dim=0).values
    measure_norm = torch.linalg.vector_norm(measures, dim=1)
    sort_idx = np.argsort(np.array(measure_norm))
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    # TODO: add 3d plot of measures
    return plot_samples(masks_sorted, measure_norm[sort_idx]), plot_samples(
        recon_masks_sorted, measure_norm[sort_idx]
    )

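# The three measures are each rescaled by their maximum over the sample and
# then combined via the L2 norm, which implicitly weights them equally;
# shapes are sorted by that combined magnitude.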
def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)
    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    diffs = []
    for i, (mask, _) in enumerate(data_loader, 0):
        # complexity_measure returns 7 values; only the combined score, the
        # precision difference and the image grid are needed here
        complexity, _, _, diff, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask, save_preliminary=True
        )
        masks[i] = mask_recon_grid
        diffs.append(diff)
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(diffs)), np.sort(diffs))
    plt.xlabel("images")
    plt.ylabel("prec difference (H-L)")
    plt.savefig("shape_complexity/results/diff_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, complexities[sort_idx])

def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    complexities_lb = torch.zeros((400,))
    complexities_gb = torch.zeros((400,))
    diffs = []
    prec_lbs = []
    prec_gbs = []
    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            lb,
            gb,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        masks[i] = mask_recon_grid
        diffs.append(diff)
        prec_lbs.append(prec_lb)
        prec_gbs.append(prec_gb)
        complexities[i] = complexity
        complexities_lb[i] = lb
        complexities_gb[i] = gb

    sort_idx = np.argsort(np.array(complexities))
    sort_idx_lb = np.argsort(np.array(complexities_lb))
    sort_idx_gb = np.argsort(np.array(complexities_gb))
    masks_sorted = masks.numpy()[sort_idx]
    masks_sorted_lb = masks.numpy()[sort_idx_lb]
    masks_sorted_gb = masks.numpy()[sort_idx_gb]
    diff_sort_idx = np.argsort(diffs)

    # plt.savefig("shape_complexity/results/diff_plot.png")
    # plt.clf

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(
        np.arange(len(prec_lbs)),
        np.array(prec_lbs)[diff_sort_idx],
        label=f"bottleneck {model_lb.bottleneck}",
    )
    ax1.plot(
        np.arange(len(prec_gbs)),
        np.array(prec_gbs)[diff_sort_idx],
        label=f"bottleneck {model_gb.bottleneck}",
    )
    ax2.plot(
        np.arange(len(diffs)),
        np.sort(diffs),
        color="red",
        label="prec difference (H-L)",
    )
    ax1.legend(loc="lower left")
    ax2.legend(loc="lower right")
    ax1.set_ylabel("precision")
    ax2.set_ylabel("prec difference (H-L)")
    plt.savefig("shape_complexity/results/prec_plot.png")
    plt.clf()

    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted_lb, complexities_lb[sort_idx_lb])
    fig.savefig("shape_complexity/results/lb.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted_gb, complexities_gb[sort_idx_gb])
    fig.savefig("shape_complexity/results/gb.png")
    plt.close(fig)

def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    recon_masks = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    complexities = torch.zeros((400,))
    diffs = np.zeros((400,))
    prec_gbs = np.zeros((400,))
    prec_lbs = np.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            _,
            _,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        recon_masks[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff
        prec_gbs[i] = prec_gb
        prec_lbs[i] = prec_lb
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]
    recon_masks_sorted = recon_masks.numpy()[sort_idx]

    # group images by the precision difference between the two bottlenecks
    group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
    bin_edges = [-np.inf, 0.0, 0.05, np.inf]
    bins = np.digitize(diffs, bins=bin_edges, right=True)
    for i in range(bins.min(), bins.max() + 1):
        bin_idx = bins == i
        binned_prec_gb = prec_gbs[bin_idx]
        prec_mean = binned_prec_gb.mean()
        prec_idx = prec_gbs > prec_mean
        binned_masks_high = recon_masks[bin_idx & prec_idx]
        binned_masks_low = recon_masks[bin_idx & ~prec_idx]
        # a bin half can be empty; save_image cannot build a grid from it
        if len(binned_masks_high):
            save_image(
                binned_masks_high,
                f"shape_complexity/results/diff_{group_labels[i-1]}_high.png",
                padding=10,
            )
        if len(binned_masks_low):
            save_image(
                binned_masks_low,
                f"shape_complexity/results/diff_{group_labels[i-1]}_low.png",
                padding=10,
            )

    diff_sort_idx = np.argsort(diffs)

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(
        np.arange(len(prec_lbs)),
        np.array(prec_lbs)[diff_sort_idx],
        label=f"bottleneck {model_lb.bottleneck}",
    )
    ax1.plot(
        np.arange(len(prec_gbs)),
        np.array(prec_gbs)[diff_sort_idx],
        label=f"bottleneck {model_gb.bottleneck}",
    )
    ax2.plot(
        np.arange(len(diffs)),
        np.sort(diffs),
        color="red",
        label="prec difference (H-L)",
    )
    ax1.legend(loc="lower left")
    ax2.legend(loc="lower right")
    ax1.set_ylabel("precision")
    ax2.set_ylabel("prec difference (H-L)")
    ax1.set_xlabel("image")
    plt.tight_layout(pad=2)  # must run before savefig to take effect
    plt.savefig("shape_complexity/results/prec_plot.png")
    plt.clf()

    fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs_recon.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)

LR = 1e-3
EPOCHS = 10
LOAD_PRETRAINED = True


def main():
    bottlenecks = [1, 16]
    models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
    optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}
    data_loader, dataset = load_data()

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

    if LOAD_PRETRAINED:
        for i, model in models.items():
            model.load_state_dict(
                torch.load(f"shape_complexity/trained/CONVVAE_{i}_split_data.pth")
            )
    else:
        for epoch in range(EPOCHS):
            for i, model in models.items():
                train(
                    epoch,
                    model=model,
                    optimizer=optimizers[i],
                    data_loader=train_loader,
                )

            test(epoch, models=list(models.values()), dataset=test_dataset)

        for bn in bottlenecks:
            if not os.path.exists("shape_complexity/trained"):
                os.makedirs("shape_complexity/trained")

            torch.save(
                models[bn].state_dict(),
                f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth",
            )

    bn_gt = 16
    bn_lt = 1
    # for i in range(10):
    #     figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
    #     figure.savefig(
    #         f"shape_complexity/results/this_{bn_gt}_to_{bn_lt}_sample{i}.png"
    #     )
    #     figure.clear()
    #     plt.close(figure)

    # figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
    # figure.savefig(f"shape_complexity/results/sort_{bn_gt}_to_{bn_lt}.png")

    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

    visualize_sort_group(data_loader, models[bn_gt], models[bn_lt])
    # visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])

    fig, _ = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt])
    fig.savefig("shape_complexity/results/sort_comp_fft_prec.png")

    fig = visualize_sort_fft(data_loader)
    fig.savefig("shape_complexity/results/sort_fft.png")

    fig = visualize_sort_compression(data_loader)
    fig.savefig("shape_complexity/results/sort_compression.png")

    fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
    fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
    fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")
    plt.close(fig)
    plt.close(fig_recon)

    fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt])
    fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png")
    fig_recon.savefig(
        f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png"
    )
    plt.close(fig)
    plt.close(fig_recon)


if __name__ == "__main__":
    main()