Commit 32501532 authored by Markus Rothgänger

wip, multidim

parent f5860837
import os
from zlib import compress
import matplotlib
import matplotlib.pyplot as plt
@@ -7,12 +8,13 @@ import numpy.typing as npt
import torch
import torch.nn.functional as F
from PIL import Image
-from torch import Tensor, conv2d, nn
+from scipy.fft import fft
+from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
-from torchvision.utils import save_image, make_grid
+from torchvision.utils import make_grid, save_image
device = torch.device("cuda")
matplotlib.use("Agg")
@@ -400,6 +402,26 @@ def distance_measure(model: VAE, img: Tensor):
)
def compression_complexity(img: Tensor):
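    # zlib-compressed byte length as a crude complexity proxy: simple shapes
    # compress well, so a longer output indicates a more complex mask
    # (zlib.compress reads the contiguous ndarray via the buffer protocol)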
    np_img = img[0].numpy()
    compressed = compress(np_img)
    return len(compressed)
def fft_measure(img: Tensor):
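    # share of log-magnitude spectral energy outside the central
    # (low-frequency) third of the shifted spectrum; higher means finer detail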
    np_img = img[0][0].numpy()
    fft = np.fft.fft2(np_img)
    magnitude = np.fft.fftshift(np.abs(fft))
    spectrum = np.log(1 + magnitude)

    M, N = np_img.shape
    total_freq_value = spectrum.sum()
    # after fftshift the centre M/3 x N/3 block holds the low frequencies
    inner_sum = spectrum[M // 3 : 2 * (M // 3), N // 3 : 2 * (N // 3)].sum()

    return (total_freq_value - inner_sum) / total_freq_value
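# sanity check for fft_measure: a constant mask concentrates all its energy in
# the DC bin at the centre, giving a ratio near 0, while white noise spreads
# energy roughly evenly, pushing the ratio towards the outer-area fraction ~8/9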
def complexity_measure(
    model_gb: nn.Module,
    model_lb: nn.Module,
@@ -495,36 +517,6 @@ def complexity_measure_diff(
)
def alt_complexity_measure(
    model_gb: nn.Module, model_lb: nn.Module, img: Tensor, epsilon=0.4
):
    model_gb.eval()
    model_lb.eval()

    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        bce_gb = F.binary_cross_entropy(recon_gb, mask.view(-1, 4096), reduction="sum")
        bce_lb = F.binary_cross_entropy(recon_lb, mask.view(-1, 4096), reduction="sum")

    recon_bits_gb = recon_gb.view(-1, 64, 64).cpu().numpy() > epsilon
    recon_bits_lb = recon_lb.view(-1, 64, 64).cpu().numpy() > epsilon
    mask_bits = mask.cpu().numpy() > 0

    tp_gb = (mask_bits & recon_bits_gb).sum()
    fp_gb = (recon_bits_gb & ~mask_bits).sum()
    tp_lb = (mask_bits & recon_bits_lb).sum()
    fp_lb = (recon_bits_lb & ~mask_bits).sum()

    prec_gb = tp_gb / (tp_gb + fp_gb)
    prec_lb = tp_lb / (tp_lb + fp_lb)

    complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))

    return complexity
def plot_samples(masks: Tensor, complexities: npt.NDArray):
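    # lays out the 400 sampled masks on a 20x20 grid rendered at 150 dpi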
    dpi = 150
    rows = cols = 20
@@ -580,6 +572,42 @@ def visualize_sort_mean(data_loader: DataLoader, model: VAE):
)
def visualize_sort_compression(data_loader: DataLoader):
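    # sort 400 random samples (matching the RandomSampler in visualize_sort)
    # by zlib-compressed size and plot the sorted curve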
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = compression_complexity(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    # fresh figure so successive measure plots do not draw over each other
    plt.figure()
    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("compression length")
    plt.savefig("shape_complexity/results/compression_plot.png")
    plt.close()

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])
def visualize_sort_fft(data_loader: DataLoader):
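    # same pipeline as visualize_sort_compression, but ranked by the spectral
    # fft_measure instead of compressed size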
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = fft_measure(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    plt.figure()
    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("high-frequency energy ratio")
    plt.savefig("shape_complexity/results/fft_plot.png")
    plt.close()

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])
def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
@@ -604,6 +632,37 @@ def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
)
def visualize_sort_3dim(
    data_loader: DataLoader, model_gb: nn.Module, model_lb: nn.Module
):
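    # combines three per-image complexity measures (compression, FFT, VAE
    # reconstruction precision) into one score: the euclidean norm of the
    # max-normalised measure vector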
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    measures = torch.zeros((400, 3))
    for i, (mask, _) in enumerate(data_loader, 0):
        c_compress = compression_complexity(mask)
        c_fft = fft_measure(mask)
        # TODO: maybe use the diff measure instead of precision
        c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask
        )

        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        measures[i] = torch.tensor([c_compress, c_fft, c_vae])

    # scale each measure to [0, 1] so no single one dominates the norm
    measures /= measures.max(dim=0).values
    measure_norm = torch.linalg.vector_norm(measures, dim=1)

    sort_idx = torch.argsort(measure_norm)
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    # TODO: add 3d plot of measures
    return plot_samples(masks_sorted, measure_norm.numpy()[sort_idx]), plot_samples(
        recon_masks_sorted, measure_norm.numpy()[sort_idx]
    )
def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)
@@ -795,7 +854,7 @@ def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
LR = 1e-3
EPOCHS = 10
-LOAD_PRETRAINED = False
+LOAD_PRETRAINED = True
def main():
@@ -857,6 +916,12 @@ def main():
    visualize_sort_group(data_loader, models[bn_gt], models[bn_lt])
    # visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])
    fig, _ = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt])
    fig.savefig("shape_complexity/results/sort_comp_fft_prec.png")

    fig = visualize_sort_fft(data_loader)
    fig.savefig("shape_complexity/results/sort_fft.png")

    fig = visualize_sort_compression(data_loader)
    fig.savefig("shape_complexity/results/sort_compression.png")

    fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
    fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
    fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")