Skip to content
Snippets Groups Projects
Commit 406133a8 authored by Markus Rothgänger's avatar Markus Rothgänger
Browse files

refactor vis stuff

parent 5d4d3bb0
No related branches found
No related tags found
No related merge requests found
......@@ -2,8 +2,7 @@ import os
# from zlib import compress
from bz2 import compress
import string
from textwrap import fill
from typing import Callable
import matplotlib
import matplotlib.pyplot as plt
......@@ -18,7 +17,6 @@ from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler, Dataset
from torchvision.datasets import ImageFolder
from torchvision.io import read_image
from torchvision.transforms import transforms
from torchvision.utils import make_grid, save_image
......@@ -488,7 +486,7 @@ def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
return prec, rec, comp_data
def distance_measure(model: VAE, img: Tensor):
def distance_measure(img: Tensor, model: VAE):
model.eval()
with torch.no_grad():
......@@ -511,9 +509,9 @@ def compression_complexity(img: Tensor, fill_ratio_norm=False):
if fill_ratio_norm:
fill_ratio = np_img.sum().item() / np.ones_like(np_img).sum().item()
return len(compressed) * (1 - fill_ratio)
return len(compressed) * (1 - fill_ratio), None
return len(compressed)
return len(compressed), None
def fft_measure(img: Tensor):
......@@ -533,13 +531,13 @@ def fft_measure(img: Tensor):
mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2))
# mean frequency in range 0 to 0.5
return mean_freq / 0.5
return mean_freq / 0.5, None
def complexity_measure(
img: Tensor,
model_gb: nn.Module,
model_lb: nn.Module,
img: Tensor,
epsilon=0.4,
fill_ratio_norm=False,
):
......@@ -581,7 +579,7 @@ def complexity_measure(
)
def mean_precision(models: list[nn.Module], img: Tensor, epsilon=0.4):
def mean_precision(img: Tensor, models: list[nn.Module], epsilon=0.4):
mask = img.to(device)
mask_bits = mask[0].cpu() > 0
......@@ -598,13 +596,13 @@ def mean_precision(models: list[nn.Module], img: Tensor, epsilon=0.4):
prec = tp / (tp + fp)
precisions[i] = prec
return 1 - precisions.mean()
return 1 - precisions.mean(), None
def complexity_measure_diff(
img: Tensor,
model_gb: nn.Module,
model_lb: nn.Module,
img: Tensor,
):
model_gb.eval()
model_lb.eval()
......@@ -628,7 +626,7 @@ def complexity_measure_diff(
)
def plot_samples(masks: Tensor, complexities: npt.NDArray):
def plot_samples(masks: Tensor, ratings: npt.NDArray):
dpi = 150
rows = cols = 20
total = rows * cols
......@@ -645,7 +643,7 @@ def plot_samples(masks: Tensor, complexities: npt.NDArray):
plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
ax.set_title(
f"{complexities[idx]:.4f}",
f"{ratings[idx]:.4f}",
fontdict={"fontsize": 6, "color": "orange"},
y=0.35,
)
......@@ -659,109 +657,52 @@ def plot_samples(masks: Tensor, complexities: npt.NDArray):
return fig
def visualize_sort_mean(data_loader: DataLoader, model: VAE):
    """Rank sampled masks by latent-mean L2 distance and plot the ordering.

    Draws one mask per iteration from ``data_loader`` (sampler yields 400
    batches of size 1 — see the RandomSampler setup in ``main``), scores each
    with ``distance_measure``, saves the sorted distance curve to
    ``shape_complexity/results/distance_plot.png``, and returns a pair of
    figures: (sorted input masks, sorted reconstruction grids).
    """
    # Pre-allocated for exactly 400 samples; recon grids are 3x64x128
    # — assumes the data_loader yields at most 400 batches, TODO confirm.
    recon_masks = torch.zeros((400, 3, 64, 128))
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        # NOTE(review): uses the pre-refactor (model, img) argument order of
        # distance_measure; this commit swaps the signature to (img, model).
        distance, mask_recon_grid = distance_measure(model, mask)
        masks[i] = mask[0]
        recon_masks[i] = mask_recon_grid
        distances[i] = distance
    # Ascending order: smallest latent distance first.
    sort_idx = torch.argsort(distances)
    recon_masks_sorted = recon_masks.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]
    # Diagnostic curve of the sorted distances (images vs. distance).
    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("latent mean L2 distance")
    plt.savefig("shape_complexity/results/distance_plot.png")
    return (
        plot_samples(masks_sorted, distances.numpy()[sort_idx]),
        plot_samples(recon_masks_sorted, distances.numpy()[sort_idx]),
    )
def visualize_sort_compression(data_loader: DataLoader, fill_ratio_norm=False):
def visualize_sort(
data_loader: DataLoader,
metric_fn: Callable[[Tensor, any], tuple[torch.float32, Tensor]],
metric_name: str,
*fn_args: any,
plot_ratings=False,
**fn_kwargs: any,
):
recon_masks = None
masks = torch.zeros((400, 1, 64, 64))
distances = torch.zeros((400,))
for i, (mask, _) in enumerate(data_loader, 0):
masks[i] = mask[0]
distances[i] = compression_complexity(mask, fill_ratio_norm)
sort_idx = torch.argsort(distances)
masks_sorted = masks.numpy()[sort_idx]
plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
plt.xlabel("images")
plt.ylabel("compression length")
plt.savefig("shape_complexity/results/compression_plot.png")
return plot_samples(masks_sorted, distances.numpy()[sort_idx])
ratings = torch.zeros((400,))
plot_recons = True
def visualize_sort_fft(data_loader: DataLoader):
masks = torch.zeros((400, 1, 64, 64))
distances = torch.zeros((400,))
for i, (mask, _) in enumerate(data_loader, 0):
masks[i] = mask[0]
distances[i] = fft_measure(mask)
sort_idx = torch.argsort(distances)
masks_sorted = masks.numpy()[sort_idx]
rating, mask_recon_grid = metric_fn(mask, *fn_args, **fn_kwargs)
if plot_recons and mask_recon_grid == None:
plot_recons = False
elif plot_recons and recon_masks is None:
recon_masks = torch.zeros((400, *mask_recon_grid.shape))
plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
plt.xlabel("images")
plt.ylabel("mean unidirectional frequency")
plt.savefig("shape_complexity/results/fft_plot.png")
return plot_samples(masks_sorted, distances.numpy()[sort_idx])
def visualize_sort_mean_precision(models: list[nn.Module], data_loader: DataLoader):
masks = torch.zeros((400, 1, 64, 64))
precisions = torch.zeros((400,))
for i, (mask, _) in enumerate(data_loader, 0):
masks[i] = mask[0]
precisions[i] = mean_precision(models, mask)
sort_idx = torch.argsort(precisions)
masks_sorted = masks.numpy()[sort_idx]
ratings[i] = rating
plt.plot(np.arange(len(precisions)), np.sort(precisions.numpy()))
plt.xlabel("images")
plt.ylabel("mean precision")
plt.savefig("shape_complexity/results/mean_prec_plot.png")
if plot_recons:
recon_masks[i] = mask_recon_grid
return plot_samples(masks_sorted, precisions.numpy()[sort_idx])
if plot_ratings:
plt.plot(np.arange(len(ratings)), np.sort(ratings.numpy()))
plt.xlabel("images")
plt.ylabel(f"{metric_name} rating")
plt.savefig(f"shape_complexity/results/{metric_name}_rating_plot.png")
plt.close()
sort_idx = torch.argsort(ratings)
def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
masks_recon = torch.zeros((400, 3, 64, 192))
masks = torch.zeros((400, 1, 64, 64))
diffs = torch.zeros((400,))
for i, (mask, _) in enumerate(data_loader, 0):
diff, mask_recon_grid = complexity_measure_diff(model_gb, model_lb, mask)
masks_recon[i] = mask_recon_grid
masks[i] = mask[0]
diffs[i] = diff
sort_idx = np.argsort(np.array(diffs))
recon_masks_sorted = masks_recon.numpy()[sort_idx]
masks_sorted = masks.numpy()[sort_idx]
fig = plot_samples(masks_sorted, ratings.numpy()[sort_idx])
fig.savefig(f"shape_complexity/results/{metric_name}_sort.png")
plt.close(fig)
plt.plot(np.arange(len(diffs)), np.sort(diffs))
plt.xlabel("images")
plt.ylabel("pixelwise difference of reconstructions")
plt.savefig("shape_complexity/results/px_diff_plot.png")
return (
plot_samples(masks_sorted, diffs[sort_idx]),
plot_samples(recon_masks_sorted, diffs[sort_idx]),
)
if plot_recons:
recon_masks_sorted = recon_masks.numpy()[sort_idx]
fig_recon = plot_samples(recon_masks_sorted, ratings.numpy()[sort_idx])
fig_recon.savefig(f"shape_complexity/results/{metric_name}_sort_recon.png")
plt.close(fig_recon)
def visualize_sort_3dim(
......@@ -771,10 +712,12 @@ def visualize_sort_3dim(
masks = torch.zeros((400, 1, 64, 64))
measures = torch.zeros((400, 3))
for i, (mask, _) in enumerate(data_loader, 0):
c_compress = compression_complexity(mask)
c_fft = fft_measure(mask)
c_compress, _ = compression_complexity(mask, fill_ratio_norm=True)
c_fft, _ = fft_measure(mask)
# TODO: maybe exchange by diff or mean measure instead of precision
c_vae, mask_recon_grid = complexity_measure(model_gb, model_lb, mask)
c_vae, mask_recon_grid = complexity_measure(
mask, model_gb, model_lb, fill_ratio_norm=True
)
masks_recon[i] = mask_recon_grid
masks[i] = mask[0]
measures[i] = torch.tensor([c_compress, c_fft, c_vae])
......@@ -803,39 +746,10 @@ def visualize_sort_3dim(
)
def visualize_sort_complexity(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    """Rank sampled masks by VAE reconstruction complexity and save plots.

    Scores each mask with ``complexity_measure`` using a high-bottleneck /
    low-bottleneck model pair with fill-ratio normalisation, then writes two
    sorted 20x20 sample grids:
    ``shape_complexity/results/abs_recon.png`` (reconstruction grids) and
    ``shape_complexity/results/abs.png`` (input masks).
    """
    # Pre-allocated for exactly 400 samples; recon grids are 3x64x192
    # — assumes the data_loader yields at most 400 batches, TODO confirm.
    recon_masks = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    complexities = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        # NOTE(review): pre-refactor argument order (model_gb, model_lb, img);
        # this commit moves ``img`` to the first parameter position.
        (complexity, mask_recon_grid,) = complexity_measure(
            model_gb,
            model_lb,
            mask,
            fill_ratio_norm=True,
        )
        recon_masks[i] = mask_recon_grid
        masks[i] = mask[0]
        complexities[i] = complexity
    # Ascending complexity; numpy argsort because the sorted views below
    # index numpy arrays.
    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]
    recon_masks_sorted = recon_masks.numpy()[sort_idx]
    fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs_recon.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)
LR = 1.5e-3
EPOCHS = 100
LOAD_PRETRAINED = True
# TODO: refactor to have one plot function with callable metric function
# TODO: try out pixelwise loss again (in 3d as well)
# TODO: might be a good idea to implement a bbox cut preprocessing transform thingy
......@@ -886,49 +800,34 @@ def main():
bn_gt = 32
bn_lt = 8
# for i in range(10):
# figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
# figure.savefig(
# f"shape_complexity/results/this_{bn_gt}_to_{bn_lt}_sample{i}.png"
# )
# figure.clear()
# plt.close(figure)
# figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
# figure.savefig(f"shape_complexity/results/sort_{bn_gt}_to_{bn_lt}.png")
sampler = RandomSampler(dataset, replacement=True, num_samples=400)
data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)
visualize_sort_complexity(data_loader, models[bn_gt], models[bn_lt])
# visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])
fig, fig_recon = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt])
fig.savefig(f"shape_complexity/results/sort_comp_fft_prec.png")
fig_recon.savefig(f"shape_complexity/results/recon_sort_comp_fft_prec.png")
plt.close(fig)
plt.close(fig_recon)
fig = visualize_sort_mean_precision(list(models.values()), data_loader)
fig.savefig(f"shape_complexity/results/sort_mean_prec.png")
plt.close(fig)
fig = visualize_sort_fft(data_loader)
fig.savefig(f"shape_complexity/results/sort_fft.png")
plt.close(fig)
fig = visualize_sort_compression(data_loader)
fig.savefig(f"shape_complexity/results/sort_compression.png")
fig = visualize_sort_compression(data_loader, fill_ratio_norm=True)
fig.savefig(f"shape_complexity/results/sort_compression_fill_norm.png")
fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")
plt.close(fig)
plt.close()
# fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt])
# fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png")
# fig_recon.savefig(
# f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png"
# )
# plt.close(fig)
# plt.close(fig_recon)
visualize_sort(
data_loader,
complexity_measure,
"recon_complexity",
models[bn_gt],
models[bn_lt],
fill_ratio_norm=True,
)
visualize_sort(data_loader, mean_precision, "mean_precision", list(models.values()))
visualize_sort(data_loader, fft_measure, "fft")
visualize_sort(data_loader, compression_complexity, "compression")
visualize_sort(
data_loader,
compression_complexity,
"compression_fill_norm",
fill_ratio_norm=True,
)
visualize_sort(data_loader, distance_measure, "latent_l2_distance", models[bn_gt])
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment