Skip to content
Snippets Groups Projects
Commit d4d9077e authored by Markus Rothgänger's avatar Markus Rothgänger
Browse files

wip

parent 5be9343c
No related branches found
No related tags found
No related merge requests found
......@@ -29,9 +29,10 @@ def l2_distance_measure(img: Tensor, model: CONVVAE):
def compression_measure(img: Tensor, fill_ratio_norm=False):
np_img = img[0].numpy()
compressed = compress(np_img)
np_img_bytes = np_img.tobytes()
compressed = compress(np_img_bytes)
complexity = len(compressed) / len(np_img.tobytes())
complexity = len(compressed) / len(np_img_bytes)
if fill_ratio_norm:
fill_ratio = np_img.sum().item() / np.ones_like(np_img).sum().item()
......@@ -71,7 +72,6 @@ def pixelwise_complexity_measure(
model_lb.eval()
with torch.no_grad():
# mask = img.to(model_gb.device).unsqueeze(dim=0).float()
mask = img.to(model_gb.device)
recon_gb: Tensor
......@@ -80,21 +80,14 @@ def pixelwise_complexity_measure(
recon_gb, mu_gb, logvar_gb = model_gb(mask)
recon_lb, mu_lb, logvar_lb = model_lb(mask)
# std_gb_mean = torch.exp(0.5 * logvar_gb).mean()
# std_lb_mean = torch.exp(0.5 * logvar_lb).mean()
# mu_gb_mean = mu_gb.mean()
# mu_lb_mean = mu_lb.mean()
max_px_fill = torch.ones_like(mask).sum().item()
abs_px_diff = (recon_gb - recon_lb).abs().sum().item()
# max_px_fill = torch.ones_like(mask).sum().item()
# complexity = abs_px_diff / max_px_fill
complexity = abs_px_diff / mask.sum()
# this equals complexity = (1 - fill_rate) * diff_px / max_px
if fill_ratio_norm:
complexity -= abs_px_diff * mask.sum().item() / np.power(max_px_fill, 2)
# complexity *= mask.sum().item() / max_px_fill
complexity *= mask.sum().item() / torch.ones_like(mask).sum().item()
if return_mean_std:
return (
......
......@@ -83,7 +83,9 @@ class BBoxTransform:
xcutmin, xcutmax = (
int(xmedian - n_med if xmedian >= n_med else 0),
int(xmedian + n_med if xmedian + n_med <= N else N,),
int(
xmedian + n_med if xmedian + n_med <= N else N,
),
)
if (ycutmax - ycutmin) % 2 != 0:
......@@ -98,9 +100,10 @@ class BBoxTransform:
dest_ymin, dest_ymax = int(n_med - squared_cut_y), int(n_med + squared_cut_y)
dest_xmin, dest_xmax = int(n_med - squared_cut_x), int(n_med + squared_cut_x)
squared_x[dest_ymin:dest_ymax, dest_xmin:dest_xmax,] = img[
ycutmin:ycutmax, xcutmin:xcutmax
]
squared_x[
dest_ymin:dest_ymax,
dest_xmin:dest_xmax,
] = img[ycutmin:ycutmax, xcutmin:xcutmax]
return transforms.F.to_pil_image(torch.from_numpy(squared_x).unsqueeze(dim=0))
......
......@@ -7,7 +7,7 @@ from matplotlib.pyplot import plot
import numpy as np
import torch
from matplotlib import cm
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, RandomSampler, Subset
from torchvision.transforms import transforms
from complexity import (
......@@ -21,6 +21,7 @@ from models import load_models
from plot import (
create_vis,
visualize_reconstructions,
visualize_sort,
visualize_sort_mean_std,
visualize_sort_multidim,
)
......@@ -239,8 +240,11 @@ if __name__ == "__main__":
# ],
# )
indices = torch.randperm(len(dataset))[:15]
subset = Subset(dataset, indices)
sampler = RandomSampler(dataset, replacement=True, num_samples=100)
data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)
subset_loader = DataLoader(subset, batch_size=1)
# visualize_reconstructions(
# data_loader, model_bn64, torch.argmax, plot_histogram=True
......@@ -249,23 +253,64 @@ if __name__ == "__main__":
# data_loader, model_bn64, torch.argmin, plot_histogram=True
# )
visualize_sort_mean_std(
data_loader,
pixelwise_complexity_measure,
"px_recon_complexity",
model_bn64,
model_bn16,
fill_ratio_norm=False,
return_mean_std=True,
)
# visualize_sort(
# visualize_sort_mean_std(
# data_loader,
# pixelwise_complexity_measure,
# "px_recon_complexity",
# model_bn64,
# model_bn16,
# fill_ratio_norm=False,
# return_mean_std=True,
# )
visualize_sort(
subset_loader,
pixelwise_complexity_measure,
"px_recon_complexity",
model_bn64,
model_bn16,
rows=1,
cols=len(indices),
fill_ratio_norm=False,
)
visualize_sort(
subset_loader,
compression_measure,
"compression",
rows=1,
cols=len(indices),
fill_ratio_norm=True,
)
visualize_sort(subset_loader, fft_measure, "fft", rows=1, cols=len(indices))
visualize_sort_multidim(
subset_loader,
[
("fft", fft_measure, []),
("compression", compression_measure, [True]),
(
"px_recon_comp",
pixelwise_complexity_measure,
[model_bn64, model_bn16, False],
),
],
max_norm=False,
rows=1,
cols=len(indices),
)
visualize_sort_multidim(
subset_loader,
[
("fft", fft_measure, []),
("compression", compression_measure, [True]),
(
"px_recon_comp",
pixelwise_complexity_measure,
[model_bn64, model_bn16, False],
),
],
max_norm=True,
rows=1,
cols=len(indices),
)
# visualize_sort_multidim(
# data_loader,
......
......@@ -112,9 +112,10 @@ def create_vis(
# TODO: instead of plotting each mask individually, create big image array/tensor
def plot_samples(masks: Tensor, ratings: npt.NDArray, labels: npt.NDArray = None):
def plot_samples(
masks: Tensor, ratings: npt.NDArray, rows=10, cols=10, labels: npt.NDArray = None
):
dpi = 150
rows = cols = 10
total = rows * cols
n_samples, _, y, x = masks.shape
......@@ -133,15 +134,17 @@ def plot_samples(masks: Tensor, ratings: npt.NDArray, labels: npt.NDArray = None
ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])
if labels is None:
plt.imshow(masks[idx].permute(1, 2, 0), extent=extent)
plt.imshow(
masks[idx].permute(1, 2, 0), extent=extent, cmap="gray", vmin=0, vmax=1
)
else:
mask = masks[idx][0] * (label_map[int(labels[idx].item())])
plt.imshow(
mask,
cmap="turbo",
extent=extent,
vmax=max_label,
vmin=0,
vmax=max_label,
)
rating = ratings[idx]
......@@ -295,6 +298,8 @@ def visualize_sort(
metric_name: str,
*fn_args: any,
plot_ratings=False,
rows=10,
cols=10,
**fn_kwargs: any,
):
n_samples = len(data_loader.sampler)
......@@ -326,14 +331,16 @@ def visualize_sort(
sort_idx = torch.argsort(ratings)
masks_sorted = masks.numpy()[sort_idx]
fig = plot_samples(masks_sorted, ratings.numpy()[sort_idx])
masks_sorted = masks[sort_idx]
fig = plot_samples(masks_sorted, ratings[sort_idx], rows=rows, cols=cols)
fig.savefig(f"results/{metric_name}_sort.png")
plt.close(fig)
if plot_recons:
recon_masks_sorted = recon_masks.numpy()[sort_idx]
fig_recon = plot_samples(recon_masks_sorted, ratings.numpy()[sort_idx])
recon_masks_sorted = recon_masks[sort_idx]
fig_recon = plot_samples(
recon_masks_sorted, ratings[sort_idx], rows=rows, cols=cols
)
fig_recon.savefig(f"results/{metric_name}_sort_recon.png")
plt.close(fig_recon)
......@@ -343,7 +350,10 @@ def visualize_sort_multidim(
measures: list[
tuple[str, Callable[[Tensor, any], tuple[torch.float32, Tensor]], any]
],
rows=10,
cols=10,
max_norm=False,
use_labels=False,
):
n_samples = len(data_loader.sampler)
n_dim = len(measures)
......@@ -372,13 +382,21 @@ def visualize_sort_multidim(
ax.set_xlabel(measures[0][0])
ax.set_ylabel(measures[1][0])
ax.set_zlabel(measures[2][0])
plt.savefig("results/3d_plot.png")
plt.savefig(f"results/3d_plot{'_norm' if max_norm else ''}.png")
plt.close()
measure_norm = torch.linalg.vector_norm(ratings, dim=1)
sort_idx = np.argsort(np.array(measure_norm))
rating_strings = [f"{r[0]:.4f}\n{r[1]:.4f}\n{r[2]:.4f}" for r in ratings[sort_idx]]
fig = plot_samples(masks.numpy()[sort_idx], rating_strings, labels[sort_idx])
fig.savefig(f"results/{n_dim}dim_{'_'.join([m[0] for m in measures])}_sort.png")
fig = plot_samples(
masks[sort_idx],
rating_strings,
labels=labels[sort_idx] if use_labels else None,
rows=rows,
cols=cols,
)
fig.savefig(
f"results/{n_dim}dim_{'_'.join([m[0] for m in measures])}_sort{'_norm' if max_norm else ''}.png"
)
plt.close(fig)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment