Commit d8cb6252 authored by Markus Rothgänger

wip

parent c2dd93b0
import os
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
@@ -14,6 +15,7 @@ from torchvision.transforms import transforms
from torchvision.utils import save_image, make_grid
device = torch.device("cuda")
matplotlib.use("Agg")
dx = [+1, 0, -1, 0]
dy = [0, +1, 0, -1]
@@ -119,18 +121,81 @@ class VAE(nn.Module):
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
# Reconstruction + KL divergence losses summed over all elements and batch
def loss(self, recon_x, x, mu, logvar):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 4096), reduction="sum")
# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
BCE = F.binary_cross_entropy(recon_x, x.view(-1, 4096), reduction="sum")
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
return BCE + KLD
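# For reference, the closed-form KL term the comment above quotes (Appendix B of
# Kingma & Welling, 2014) is, for q(z|x) = N(mu, sigma^2 I) and p(z) = N(0, I):
#     KL(q || p) = -0.5 * sum_j (1 + log(sigma_j^2) - mu_j^2 - sigma_j^2)
# which is exactly the KLD expression computed above once logvar is read as log(sigma^2).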
class CONVVAE(nn.Module):
def __init__(
self,
bottleneck=2,
):
super(CONVVAE, self).__init__()
self.bottleneck = bottleneck
self.feature_dim = 32 * 56 * 56
self.conv = nn.Sequential(
nn.Conv2d(1, 16, 5), nn.ReLU(), nn.Conv2d(16, 32, 5), nn.ReLU()
)
self.encode_mu = nn.Sequential(
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
)
self.encode_logvar = nn.Sequential(
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
)
self.decode_linear = nn.Sequential(
nn.Linear(self.bottleneck, self.feature_dim),
nn.ReLU(),
)
self.decode_conv = nn.Sequential(
nn.ConvTranspose2d(32, 16, 5),
nn.ReLU(),
nn.ConvTranspose2d(16, 1, 5),
nn.Sigmoid(),
)
def encode(self, x):
x = self.conv(x)
return self.encode_mu(x), self.encode_logvar(x)
def decode(self, z):
z = self.decode_linear(z)
z = z.view(-1, 32, 56, 56)
return self.decode_conv(z)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
return mu + eps * std
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z), mu, logvar
def loss(self, recon_x, x, mu, logvar):
"""https://github.com/pytorch/examples/blob/main/vae/main.py"""
BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")
# see Appendix B from VAE paper:
# Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
# https://arxiv.org/abs/1312.6114
# 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD
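# Minimal shape sanity check (an illustrative sketch, not part of the training code;
# `_check_convvae_shapes` is a hypothetical helper): two unpadded 5x5 convolutions
# shrink a 64x64 mask to 60x60 and then 56x56, which is where
# feature_dim = 32 * 56 * 56 comes from; the transposed convolutions mirror this.
def _check_convvae_shapes():
    model = CONVVAE(bottleneck=2)
    x = torch.zeros(1, 1, 64, 64)  # dummy binary mask batch
    features = model.conv(x)  # (1, 32, 56, 56) after the two valid 5x5 convs
    assert features.shape == (1, 32, 56, 56)
    recon, mu, logvar = model(x)
    assert recon.shape == x.shape  # decoder restores the 64x64 resolution
    assert mu.shape == logvar.shape == (1, model.bottleneck)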
def load_data():
@@ -172,7 +237,7 @@ def load_data():
return data_loader, dataset
def train(epoch, model, optimizer, data_loader, log_interval=40):
def train(epoch, model: "VAE | CONVVAE", optimizer, data_loader, log_interval=40):
model.train()
train_loss = 0
for batch_idx, (data, _) in enumerate(data_loader):
@@ -180,7 +245,7 @@ def train(epoch, model, optimizer, data_loader, log_interval=40):
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss = loss_function(recon_batch, data, mu, logvar)
loss = model.loss(recon_batch, data, mu, logvar)
loss.backward()
train_loss += loss.item()
optimizer.step()
@@ -219,7 +284,7 @@ def test(epoch, models, dataset):
for j, model in enumerate(models):
recon_batch, mu, logvar = model(data)
test_loss[j] += loss_function(recon_batch, data, mu, logvar).item()
test_loss[j] += model.loss(recon_batch, data, mu, logvar).item()
if i == 0:
n = min(data.size(0), 20)
@@ -282,12 +347,13 @@ def distance_measure(model: VAE, img: Tensor):
with torch.no_grad():
mask = img.to(device)
recon, mean, _ = model(mask)
_, recon_mean, _ = model(recon.view(-1, 64, 64))
# TODO: apply threshold here?!
_, recon_mean, _ = model(recon)
distance = torch.norm(mean - recon_mean, p=2)
return distance, make_grid(
torch.stack([mask[0], recon.view(-1, 64, 64)]).cpu(), nrow=2, padding=0
torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0
)
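# Hypothetical usage sketch (illustration only, never called here): with a trained model
# on `device` and the batch_size=1 loader built in main(), score a single mask and save
# the mask/reconstruction grid returned by distance_measure for visual inspection.
def _distance_measure_example(model: nn.Module, data_loader: DataLoader):
    mask, _ = next(iter(data_loader))  # (1, 1, 64, 64) binary mask
    distance, grid = distance_measure(model, mask)
    save_image(grid, "shape_complexity/results/distance_example.png")
    return distance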
@@ -359,6 +425,33 @@ def complexity_measure(
)
def complexity_measure_diff(
model_gb: nn.Module,
model_lb: nn.Module,
img: Tensor,
):
model_gb.eval()
model_lb.eval()
with torch.no_grad():
mask = img.to(device)
recon_gb, _, _ = model_gb(mask)
recon_lb, _, _ = model_lb(mask)
diff = torch.abs((recon_gb - recon_lb).cpu().sum())
return (
diff,
make_grid(
torch.stack(
[mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
).cpu(),
nrow=3,
padding=0,
),
)
def alt_complexity_measure(
model_gb: nn.Module, model_lb: nn.Module, img: Tensor, epsilon=0.4
):
@@ -421,20 +514,51 @@ def plot_samples(masks: Tensor, complexities: npt.NDArray):
def visualize_sort_mean(data_loader: DataLoader, model: VAE):
masks = torch.zeros((400, 3, 64, 128))
recon_masks = torch.zeros((400, 3, 64, 128))
masks = torch.zeros((400, 1, 64, 64))
distances = torch.zeros((400,))
for i, (mask, _) in enumerate(data_loader, 0):
distance, mask_recon_grid = distance_measure(model, mask)
masks[i] = mask_recon_grid
masks[i] = mask[0]
recon_masks[i] = mask_recon_grid
distances[i] = distance
sort_idx = torch.argsort(distances)
recon_masks_sorted = recon_masks.numpy()[sort_idx]
masks_sorted = masks.numpy()[sort_idx]
plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
plt.xlabel("images")
plt.ylabel("latent mean L2 distance")
plt.savefig("shape_complexity/results/distance_plot.png")
return plot_samples(masks_sorted, distances.numpy()[sort_idx])
return plot_samples(masks_sorted, distances.numpy()[sort_idx]), plot_samples(
recon_masks_sorted, distances.numpy()[sort_idx]
)
def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
masks_recon = torch.zeros((400, 3, 64, 192))
masks = torch.zeros((400, 1, 64, 64))
diffs = torch.zeros((400,))
for i, (mask, _) in enumerate(data_loader, 0):
diff, mask_recon_grid = complexity_measure_diff(model_gb, model_lb, mask)
masks_recon[i] = mask_recon_grid
masks[i] = mask[0]
diffs[i] = diff
sort_idx = np.argsort(np.array(diffs))
recon_masks_sorted = masks_recon.numpy()[sort_idx]
masks_sorted = masks.numpy()[sort_idx]
plt.plot(np.arange(len(diffs)), np.sort(diffs))
plt.xlabel("images")
plt.ylabel("pixelwise difference of reconstructions")
plt.savefig("shape_complexity/results/px_diff_plot.png")
return plot_samples(masks_sorted, diffs[sort_idx]), plot_samples(
recon_masks_sorted, diffs[sort_idx]
)
def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
@@ -456,6 +580,8 @@ def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
masks_sorted = masks.numpy()[sort_idx]
plt.plot(np.arange(len(diffs)), np.sort(diffs))
plt.xlabel("images")
plt.ylabel("prec difference (L-H)")
plt.savefig("shape_complexity/results/diff_plot.png")
return plot_samples(masks_sorted, complexities[sort_idx])
@@ -500,13 +626,26 @@ def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(np.arange(len(prec_lbs)), np.array(prec_lbs)[diff_sort_idx], label="lower")
ax1.plot(
np.arange(len(prec_gbs)), np.array(prec_gbs)[diff_sort_idx], label="higher"
np.arange(len(prec_lbs)),
np.array(prec_lbs)[diff_sort_idx],
label=f"bottleneck {model_lb.bottleneck}",
)
ax1.plot(
np.arange(len(prec_gbs)),
np.array(prec_gbs)[diff_sort_idx],
label=f"bottleneck {model_gb.bottleneck}",
)
ax2.plot(
np.arange(len(diffs)),
np.sort(diffs),
color="red",
label="prec difference (H - L)",
)
ax2.plot(np.arange(len(diffs)), np.sort(diffs), color="red")
ax1.legend()
ax2.legend()
ax1.legend(loc="lower left")
ax2.legend(loc="lower right")
ax1.set_ylabel("precision")
ax2.set_ylabel("prec difference (H-L)")
plt.savefig("shape_complexity/results/prec_plot.png")
plt.clf()
@@ -520,7 +659,95 @@ def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
fig.savefig("shape_complexity/results/gb.png")
plt.close(fig)
# return plot_samples(masks_sorted, complexities[sort_idx])
def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
recon_masks = torch.zeros((400, 3, 64, 192))
masks = torch.zeros((400, 1, 64, 64))
complexities = torch.zeros((400,))
diffs = np.zeros((400,))
prec_gbs = np.zeros((400,))
prec_lbs = np.zeros((400,))
for i, (mask, _) in enumerate(data_loader, 0):
(
complexity,
_,
_,
diff,
prec_lb,
prec_gb,
mask_recon_grid,
) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
recon_masks[i] = mask_recon_grid
masks[i] = mask[0]
diffs[i] = diff
prec_gbs[i] = prec_gb
prec_lbs[i] = prec_lb
complexities[i] = complexity
sort_idx = np.argsort(np.array(complexities))
masks_sorted = masks.numpy()[sort_idx]
recon_masks_sorted = recon_masks.numpy()[sort_idx]
group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
bin_edges = [-np.inf, 0.0, 0.05, np.inf]
bins = np.digitize(diffs, bins=bin_edges, right=True)
for i in range(bins.min(), bins.max() + 1):
bin_idx = bins == i
binned_prec_gb = prec_gbs[bin_idx]
prec_mean = binned_prec_gb.mean()
prec_idx = prec_gbs > prec_mean
binned_masks_high = recon_masks[bin_idx & prec_idx]
binned_masks_low = recon_masks[bin_idx & ~prec_idx]
save_image(
binned_masks_high,
f"shape_complexity/results/diff_{group_labels[i-1]}_high.png",
padding=10,
)
save_image(
binned_masks_low,
f"shape_complexity/results/diff_{group_labels[i-1]}_low.png",
padding=10,
)
diff_sort_idx = np.argsort(diffs)
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(
np.arange(len(prec_lbs)),
np.array(prec_lbs)[diff_sort_idx],
label=f"bottleneck {model_lb.bottleneck}",
)
ax1.plot(
np.arange(len(prec_gbs)),
np.array(prec_gbs)[diff_sort_idx],
label=f"bottleneck {model_gb.bottleneck}",
)
ax2.plot(
np.arange(len(diffs)),
np.sort(diffs),
color="red",
label="prec difference (H - L)",
)
ax1.legend(loc="lower left")
ax2.legend(loc="lower right")
ax1.set_ylabel("precision")
ax2.set_ylabel("prec difference (H-L)")
ax1.set_xlabel("image")
plt.savefig("shape_complexity/results/prec_plot.png")
plt.tight_layout(pad=2)
plt.clf()
fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
fig.savefig("shape_complexity/results/abs_recon.png")
plt.close(fig)
fig = plot_samples(masks_sorted, complexities[sort_idx])
fig.savefig("shape_complexity/results/abs.png")
plt.close(fig)
LR = 1e-3
@@ -530,7 +757,7 @@ LOAD_PRETRAINED = True
def main():
bottlenecks = [2, 4, 8, 16]
models = {i: VAE(bottleneck=i).to(device) for i in bottlenecks}
models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}
data_loader, dataset = load_data()
@@ -538,7 +765,7 @@ def main():
if LOAD_PRETRAINED:
for i, model in models.items():
model.load_state_dict(
torch.load(f"shape_complexity/trained/VAE_{i}_split_data.pth")
torch.load(f"shape_complexity/trained/CONVVAE_{i}_split_data.pth")
)
else:
for epoch in range(EPOCHS):
@@ -550,10 +777,13 @@ def main():
test(epoch, models=list(models.values()), dataset=dataset)
for bn in bottlenecks:
if not os.path.exists("trained"):
os.makedirs("trained")
if not os.path.exists("shape_complexity/trained"):
os.makedirs("shape_complexity/trained")
torch.save(models[bn].state_dict(), f"trained/VAE_{bn}_split_data.pth")
torch.save(
models[bn].state_dict(),
f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth",
)
bn_gt = 16
bn_lt = 8
@@ -572,10 +802,20 @@ def main():
sampler = RandomSampler(dataset, replacement=True, num_samples=400)
data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)
visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])
fig = visualize_sort_mean(data_loader, models[bn_gt])
visualize_sort_group(data_loader, models[bn_gt], models[bn_lt])
# visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])
fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")
plt.close(fig)
plt.close(fig_recon)
fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt])
fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png")
fig_recon.savefig(
f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png"
)
plt.close(fig)
plt.close(fig_recon)
if __name__ == "__main__":
......