Commit b0042de1 authored by Markus Rothgänger

wip

parent 293b9e2e
.gitignore
@@ -5,4 +5,8 @@ activation_vis/out
shape_complexity/results
shape_complexity/trained
out
-critic
\ No newline at end of file
+critic
+results/
+data/
+data_backup/
+__pycache__/
\ No newline at end of file
complexity.py

from bz2 import compress
from typing import Any, Callable

import numpy as np
import torch
from torch import Tensor, nn
from torchvision.utils import make_grid

from models import CONVVAE


def l2_distance_measure(img: Tensor, model: CONVVAE):
    """L2 distance in latent space between `img` and its own reconstruction."""
model.eval()
with torch.no_grad():
mask = img.to(model.device)
recon, mean, _ = model(mask)
# TODO: apply threshold here?!
_, recon_mean, _ = model(recon)
distance = torch.norm(mean - recon_mean, p=2)
return (
distance,
make_grid(torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0),
)


def compression_measure(img: Tensor, fill_ratio_norm=False):
    """Complexity proxy: bz2-compressed size of the mask, optionally damped by its fill ratio."""
np_img = img[0].numpy()
compressed = compress(np_img)
if fill_ratio_norm:
fill_ratio = np_img.sum().item() / np.ones_like(np_img).sum().item()
return len(compressed) * (1 - fill_ratio), None
return len(compressed), None


def fft_measure(img: Tensor):
    """Complexity proxy: amplitude-weighted mean spatial frequency of the mask, scaled to [0, 1]."""
np_img = img[0].numpy()
fft = np.fft.fft2(np_img)
fft_abs = np.abs(fft)
n = fft.shape[0]
pos_f_idx = n // 2
df = np.fft.fftfreq(n=n)
amplitude_sum = fft_abs[:pos_f_idx, :pos_f_idx].sum()
mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2))
# mean frequency in range 0 to 0.5
return mean_freq / 0.5, None


def pixelwise_complexity_measure(
    img: Tensor,
    model_gb: nn.Module,
    model_lb: nn.Module,
    fill_ratio_norm=False,
):
    """Mean absolute pixel difference between the reconstructions of a wide-bottleneck
    (model_gb) and a narrow-bottleneck (model_lb) VAE."""
model_gb.eval()
model_lb.eval()
with torch.no_grad():
mask = img.to(model_gb.device).unsqueeze(dim=0).float()
recon_gb: Tensor
recon_lb: Tensor
recon_gb, _, _ = model_gb(mask)
recon_lb, _, _ = model_lb(mask)
max_px_fill = torch.ones_like(mask).sum().item()
abs_px_diff = (recon_gb - recon_lb).abs().sum().item()
complexity = abs_px_diff / max_px_fill
# this equals complexity = (1 - fill_rate) * diff_px / max_px
if fill_ratio_norm:
complexity -= abs_px_diff * mask.sum().item() / np.power(max_px_fill, 2)
# complexity *= mask.sum().item() / max_px_fill
return (
complexity,
make_grid(
torch.stack(
[mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
).cpu(),
nrow=3,
padding=0,
),
)


def complexity_measure(
    img: Tensor,
    model_gb: CONVVAE,
    model_lb: CONVVAE,
    epsilon=0.4,
    fill_ratio_norm=False,
):
    """Precision-based complexity: reconstructions are thresholded at `epsilon` and the
    two models' reconstruction precisions are combined into a single score."""
model_gb.eval()
model_lb.eval()
with torch.no_grad():
mask = img.to(model_gb.device)
recon_gb, _, _ = model_gb(mask)
recon_lb, _, _ = model_lb(mask)
recon_bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
mask_bits = mask[0].cpu() > 0
tp_gb = (mask_bits & recon_bits_gb).sum()
fp_gb = (recon_bits_gb & ~mask_bits).sum()
tp_lb = (mask_bits & recon_bits_lb).sum()
fp_lb = (recon_bits_lb & ~mask_bits).sum()
prec_gb = tp_gb / (tp_gb + fp_gb)
prec_lb = tp_lb / (tp_lb + fp_lb)
prec_gb = 0 if torch.isnan(prec_gb) else prec_gb
prec_lb = 0 if torch.isnan(prec_lb) else prec_lb
complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
if fill_ratio_norm:
fill_ratio = mask.sum().item() / torch.ones_like(img).sum().item()
complexity *= 1 - fill_ratio
return (
complexity,
make_grid(
torch.stack(
[mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
).cpu(),
nrow=3,
padding=0,
),
)


def mean_precision(img: Tensor, models: list[CONVVAE], epsilon=0.4):
    """Complexity proxy: one minus the mean reconstruction precision across `models`."""
    mask = img.to(models[0].device)
    mask_bits = mask[0].cpu() > 0
    precisions = np.zeros(len(models))
    with torch.no_grad():
        for i, model in enumerate(models):
            model.eval()
            recon, _, _ = model(mask)
            recon_bits = recon.view(-1, 64, 64).cpu() > epsilon
            tp = (mask_bits & recon_bits).sum()
            fp = (recon_bits & ~mask_bits).sum()
            precisions[i] = tp / (tp + fp)
    return 1 - precisions.mean(), None


def multidim_complexity(
    img: Tensor,
    measures: list[tuple[str, Callable[..., tuple[float, Tensor]], list[Any]]],
):
    """Evaluate several complexity measures on `img` and stack the scalar ratings into one vector."""
n_dim = len(measures)
ratings = torch.zeros((n_dim,))
for i, (_, fn, args) in enumerate(measures):
rating, _ = fn(img, *args)
ratings[i] = rating
return ratings
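
A minimal smoke test for the measures above might look like this (illustrative only, not part of the commit; it assumes the 64x64 single-channel masks and the load_models() helper from models.py that the visualization script also uses):

if __name__ == "__main__":
    from models import load_models

    mask = (torch.rand(1, 64, 64) > 0.5).float()  # toy stand-in for a shape mask
    model_bn8, model_bn32 = load_models()

    c_comp, _ = compression_measure(mask, fill_ratio_norm=True)
    c_fft, _ = fft_measure(mask)
    c_px, grid = pixelwise_complexity_measure(
        mask, model_bn32, model_bn8, fill_ratio_norm=True  # wider bottleneck as model_gb
    )
    print(c_comp, c_fft, c_px)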
@@ -2,9 +2,9 @@ import numpy as np
import torch
from kornia.morphology import closing
from torch import Tensor
-from torchvision import transforms
+from torchvision.transforms import transforms
-from visualize_results.utils import bbox
+from utils import bbox
class BBoxTransform:
......
@@ -2,65 +2,94 @@ import glob
import os
import sys
from typing import Generator
from venv import create
import matplotlib
from matplotlib.pyplot import fill
import numpy as np
from PIL import Image as img
from PIL.Image import Image
import torch
from torchvision.transforms import transforms
-from visualize_results.data import get_dino_transforms
-from visualize_results.utils import find_components, natsort
from complexity import (
    compression_measure,
    fft_measure,
    multidim_complexity,
    pixelwise_complexity_measure,
)
from matplotlib import cm
from data import get_dino_transforms
from utils import find_components, natsort
from models import load_models
sys.setrecursionlimit(1000000)
n_clusters = 5


-def create_vis(n_imgs, layer_range=range(12)) -> Generator[Image, Image, None]:
def create_vis(
    n_imgs, layer_range=range(12), sort=False
) -> Generator[Image, Image, None]:
"""
yields three channel PIL image
-    credit @wlad
    original by @wlad
"""
-    inp = "./OutputDir"
-    os.makedirs("kmeanimgs", exist_ok=True)
-    rltrenner = img.open("./data/rltrenner.png")
    inp = "./data/dino/"
    out = "./results/comp_fft_px_sorted/"
    os.makedirs(out, exist_ok=True)
    rltrenner = img.open("./data/static/rltrenner.png")
rltrenner = rltrenner.convert("RGB")
yseperator = 20
print("Creating images")
-    for idx, numimg in enumerate(range(n_imgs)):  # TODO: add tqdm again..
    for i in range(n_imgs):  # TODO: maybe add tqdm again..
        imagesvert = []
-        name = "img" + str(numimg)
        name = "img" + str(i)
-        aimg = img.open("OutputDir/" + name + ".png")
        aimg = img.open(inp + name + ".png")
        aimg = aimg.convert("RGB")
        for depth in layer_range:
            imagesinline = [aimg]
            for attention_type in ["q", "k", "v"]:
-                if attention_type is not "v":
-                    imagesinline.append(rltrenner)
-                templist = glob.glob(
-                    os.path.join(
-                        inp, f"{name}{attention_type}depth{str(depth)}head*.png"
-                    )
-                ).sort(key=natsort)
-                for timg in templist:
-                    timg = img.open(timg)
-                    # image = image.convert("RGB")
-                    timg = yield timg
-                    yield
-                    timg = timg.convert("RGB")
-                    timg = timg.resize((480, 480), resample=img.NEAREST)
-                    imagesinline.append(timg)
                complexities = []
                _attention_images = []
                img_paths = sorted(
                    glob.glob(
                        os.path.join(
                            inp, f"{name}{attention_type}depth{str(depth)}head*.png"
                        )
                    ),
                    key=natsort,
                )
                for path in img_paths:
                    image = img.open(path)
                    if sort:
                        image, complexity = yield image
                    else:
                        image = yield image
                    yield
                    image = image.resize((480, 480), resample=img.NEAREST)
                    _attention_images.append(image)
                    if sort:
                        complexities.append(complexity)
                if sort:
                    sort_idx = np.argsort(complexities)
                    _attention_images = [_attention_images[i] for i in sort_idx]
                # yield imagesinline
                # imagesinline.insert(0, aimg)
                _attention_images.insert(0, rltrenner)
                imagesinline.extend(_attention_images)
widths, heights = zip(*(i.size for i in imagesinline))
@@ -87,33 +116,136 @@ def create_vis(n_imgs, layer_range=range(12)) -> Generator[Image, Image, None]:
final_img.paste(im, (0, y_offset))
y_offset += im.size[1] + yseperator
final_img.save(os.path.join("kmeanimgs/", "img" + str(idx) + ".png"))
final_img.save(os.path.join(out, "img" + str(i) + ".png"))


-def main():
def multi_dim(max_norm=True, sort=False):
    img_transformer = get_dino_transforms()
-    image_generator = create_vis()
    image_generator = create_vis(n_imgs=25, sort=sort)
    color_map = cm.get_cmap("plasma")
    model_bn8, model_bn32 = load_models()
for img in image_generator:
np_img = np.array(img)
rgb_np_img = np.dstack((np_img, np_img, np_img))
cluster_img = (np_img / (255 / (n_clusters - 1))).astype(np.int8)
start_label = 1
        complexity_vectors = []
all_labels = np.zeros_like(np_img, dtype=np.uint8)
for i in range(n_clusters):
cluster_image = (cluster_img == i).astype(np.int8)
-            labels = find_components(cluster_image)
-            for l in range(1, labels.max() + 1):
-                mask = (labels == l).astype(np.int8)
            labels = find_components(cluster_image, start_label)
            for l in range(start_label, labels.max() + 1):
                mask = (labels == l).astype(np.float32)
if not mask.sum() > 0:
continue
normalized_img = img_transformer(transforms.F.to_pil_image(mask))
-                # TODO: calculate complexity for normalized image..
-                complexity = 1.0
-                np_img[labels == l] = complexity
-        # TODO: return manipulated PIL img / modify `img`
-        image_generator.send(transforms.F.to_pil_image(np_img))
complexity_vector = multidim_complexity(
normalized_img,
[
("compression", compression_measure, [True]),
("fft", fft_measure, []),
(
"pixelwise",
pixelwise_complexity_measure,
[model_bn32, model_bn8, True],
),
],
)
                complexity_vectors.append(complexity_vector)
start_label = labels.max() + 1 if labels.max() > 0 else start_label
all_labels += labels
        complexity_vectors = torch.stack(complexity_vectors, dim=0)
if max_norm:
maxs = complexity_vectors.max(dim=0)
complexity_vectors /= maxs.values
complexity_norm = torch.linalg.vector_norm(complexity_vectors, dim=1)
# rank sorted complexities descending
color_idx = torch.argsort(torch.argsort(-complexity_norm))
color_values = (color_idx / len(color_idx)).numpy()
rgb_np_img[all_labels == 0, :] = np.array([0, 0, 0])
for l in range(1, all_labels.max() + 1):
r, g, b = matplotlib.colors.to_rgb(color_map(color_values[l - 1]))
rgb_np_img[all_labels == l, :] = np.array([r, g, b]) * 255.0
if sort:
image_generator.send(
(transforms.F.to_pil_image(rgb_np_img), complexity_norm.sum())
)
else:
image_generator.send(transforms.F.to_pil_image(rgb_np_img))


def single_dim(sort=False):
img_transformer = get_dino_transforms()
image_generator = create_vis(n_imgs=25, sort=sort)
color_map = cm.get_cmap("plasma")
model_bn8, model_bn32 = load_models()
for img in image_generator:
np_img = np.array(img)
rgb_np_img = np.dstack((np_img, np_img, np_img))
cluster_img = (np_img / (255 / (n_clusters - 1))).astype(np.int8)
start_label = 1
complexities = []
all_labels = np.zeros_like(np_img, dtype=np.uint8)
for i in range(n_clusters):
cluster_image = (cluster_img == i).astype(np.int8)
labels = find_components(cluster_image, start_label)
for l in range(start_label, labels.max() + 1):
mask = (labels == l).astype(np.float32)
if not mask.sum() > 0:
continue
normalized_img = img_transformer(transforms.F.to_pil_image(mask))
# complexity, _ = compression_measure(normalized_img, True)
complexity, _ = pixelwise_complexity_measure(
normalized_img, model_bn32, model_bn8, fill_ratio_norm=True
)
# complexity, _ = fft_measure(normalized_img)
complexities.append(complexity)
start_label = labels.max() + 1 if labels.max() > 0 else start_label
all_labels += labels
# rank sorted complexities descending
color_idx = np.argsort(np.argsort(-np.array(complexities)))
color_values = color_idx / len(color_idx)
rgb_np_img[all_labels == 0, :] = np.array([0, 0, 0])
for l in range(1, all_labels.max() + 1):
r, g, b = matplotlib.colors.to_rgb(color_map(color_values[l - 1]))
rgb_np_img[all_labels == l, :] = np.array([r, g, b]) * 255.0
if sort:
image_generator.send(
(transforms.F.to_pil_image(rgb_np_img), np.sum(complexities))
)
else:
image_generator.send(transforms.F.to_pil_image(rgb_np_img))
if __name__ == "__main__":
-    main()
    multi_dim(sort=True)
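
create_vis above is driven as a coroutine: each `image = yield image` hands a raw attention map to the caller and receives the processed (recolored) image back via send(), and the bare `yield` right after it absorbs the value that send() itself returns, so the caller's `for` loop stays aligned with the next image. A stripped-down sketch of the same handshake (illustrative only, not part of the commit):

def producer(items):
    for item in items:
        processed = yield item  # caller receives `item` here ...
        yield  # ... and this bare yield absorbs the resume caused by send()
        print("got back:", processed)

gen = producer([1, 2, 3])
for x in gen:
    gen.send(x * 10)  # prints "got back: 10", then 20, then 30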
models.py
@@ -5,6 +5,8 @@ from torch import functional as F
class CONVVAE(nn.Module):
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def __init__(
self,
bottleneck=2,
@@ -111,3 +113,24 @@ class CONVVAE(nn.Module):
KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
return BCE + KLD

    def to(self, device=None):
if device is not None:
self.device = device
return super().to(self.device)


def load_models():
bottlenecks = [8, 32]
models = {bn: CONVVAE(bottleneck=bn).to() for bn in bottlenecks}
for bn, model in models.items():
model.load_state_dict(
torch.load(
f"/home/markus/uni/navigation_project/shape_complexity/trained/CONVVAE_{bn}_split_data.pth"
)
)
model.eval()
return list(models.values())
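
The to() override above keeps model.device in sync with wherever the module is actually moved, which is what lets the measures call img.to(model.device). A minimal sketch of the intended call pattern (illustrative only; the 64x64 single-channel input shape is an assumption carried over from the rest of this commit):

model = CONVVAE(bottleneck=8).to()  # no argument: falls back to the stored device (CUDA if available)
model = model.to(torch.device("cpu"))  # an explicit move also updates model.device
x = torch.rand(1, 1, 64, 64).to(model.device)  # inputs can then follow the model around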
utils.py

-from imageio import v3 as iio
import os
import re
import numpy as np
import numpy.typing as npt
-import sys
-import re
+from imageio import v3 as iio
def natsort(s, _nsre=re.compile("([0-9]+)")):
@@ -35,8 +33,8 @@ def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: i
dfs(mask, x + dx[direction], y + dy[direction], labels, current_label)
-def find_components(mask: npt.NDArray):
-    label = 0
def find_components(mask: npt.NDArray, start_label=1, min_mask_pixels=16):
    label = start_label
n_rows, n_cols = mask.shape
labels = np.zeros(mask.shape, dtype=np.int8)
@@ -44,10 +42,20 @@ def find_components(mask: npt.NDArray):
for i in range(n_rows):
for j in range(n_cols):
if not labels[i][j] and mask[i][j]:
-                label += 1
                dfs(mask, i, j, labels, label)
+                label += 1
    # drop components smaller than min_mask_pixels, then renumber the
    # surviving components so labels stay contiguous from start_label on
    max_label = labels.max()
    subtraction_matrix = np.zeros_like(labels)
    for l in range(start_label, max_label + 1):
        if (labels == l).sum() < min_mask_pixels:
            labels[labels == l] = 0
            subtraction_matrix[labels > l] += 1
    labels -= subtraction_matrix
    labels[labels < start_label] = 0

-    return labels
    return labels.astype(np.uint8)
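
A toy check of the new behavior (illustrative only, not part of the commit): components smaller than min_mask_pixels are dropped and the surviving labels are renumbered contiguously from start_label.

mask = np.zeros((64, 64), dtype=np.int8)
mask[2:10, 2:10] = 1    # 64-pixel component: kept
mask[20:22, 20:22] = 1  # 4-pixel component: dropped (< 16 pixels)
labels = find_components(mask, start_label=1)
print(np.unique(labels))  # expected: [0 1]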
# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
......