Commit b0042de1 authored by Markus Rothgänger

wip

parent 293b9e2e
.gitignore:

@@ -5,4 +5,8 @@ activation_vis/out
 shape_complexity/results
 shape_complexity/trained
 out
-critic
\ No newline at end of file
+critic
+results/
+data/
+data_backup/
+__pycache__/
\ No newline at end of file
complexity.py (new file):

from bz2 import compress
from typing import Any, Callable

import numpy as np
import torch
from torch import Tensor
from torch import nn as nn
from torchvision.utils import make_grid

from models import CONVVAE
def l2_distance_measure(img: Tensor, model: CONVVAE):
    model.eval()
    with torch.no_grad():
        mask = img.to(model.device)
        recon, mean, _ = model(mask)
        # TODO: apply threshold here?!
        _, recon_mean, _ = model(recon)
        # L2 distance between the latent code of the mask and of its reconstruction
        distance = torch.norm(mean - recon_mean, p=2)

    return (
        distance,
        make_grid(torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0),
    )
def compression_measure(img: Tensor, fill_ratio_norm=False):
    # bz2-compressed size of the raw mask bytes as a complexity proxy
    np_img = img[0].numpy()
    compressed = compress(np_img)

    if fill_ratio_norm:
        fill_ratio = np_img.sum().item() / np.ones_like(np_img).sum().item()
        return len(compressed) * (1 - fill_ratio), None

    return len(compressed), None
def fft_measure(img: Tensor):
    np_img = img[0].numpy()

    fft = np.fft.fft2(np_img)
    fft_abs = np.abs(fft)

    n = fft.shape[0]
    pos_f_idx = n // 2
    df = np.fft.fftfreq(n=n)

    # amplitude-weighted mean frequency over the positive-frequency quadrant:
    # fft_abs * df scales columns by their x-frequency, (fft_abs.T * df).T
    # scales rows by their y-frequency
    amplitude_sum = fft_abs[:pos_f_idx, :pos_f_idx].sum()
    mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
    mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum

    mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2))

    # mean frequency in range 0 to 0.5
    return mean_freq / 0.5, None
def pixelwise_complexity_measure(
    img: Tensor,
    model_gb: nn.Module,
    model_lb: nn.Module,
    fill_ratio_norm=False,
):
    model_gb.eval()
    model_lb.eval()

    with torch.no_grad():
        mask = img.to(model_gb.device).unsqueeze(dim=0).float()
        recon_gb: Tensor
        recon_lb: Tensor
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        max_px_fill = torch.ones_like(mask).sum().item()
        abs_px_diff = (recon_gb - recon_lb).abs().sum().item()
        complexity = abs_px_diff / max_px_fill

        # this equals complexity = (1 - fill_rate) * diff_px / max_px
        if fill_ratio_norm:
            complexity -= abs_px_diff * mask.sum().item() / np.power(max_px_fill, 2)
            # complexity *= mask.sum().item() / max_px_fill

    return (
        complexity,
        make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        ),
    )
def complexity_measure(
    img: Tensor,
    model_gb: CONVVAE,
    model_lb: CONVVAE,
    epsilon=0.4,
    fill_ratio_norm=False,
):
    model_gb.eval()
    model_lb.eval()

    with torch.no_grad():
        mask = img.to(model_gb.device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        # binarize both reconstructions at `epsilon` and score them against the mask
        recon_bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
        recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
        mask_bits = mask[0].cpu() > 0

        tp_gb = (mask_bits & recon_bits_gb).sum()
        fp_gb = (recon_bits_gb & ~mask_bits).sum()
        tp_lb = (mask_bits & recon_bits_lb).sum()
        fp_lb = (recon_bits_lb & ~mask_bits).sum()

        prec_gb = tp_gb / (tp_gb + fp_gb)
        prec_lb = tp_lb / (tp_lb + fp_lb)
        prec_gb = 0 if torch.isnan(prec_gb) else prec_gb
        prec_lb = 0 if torch.isnan(prec_lb) else prec_lb

        complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))

        if fill_ratio_norm:
            fill_ratio = mask.sum().item() / torch.ones_like(img).sum().item()
            complexity *= 1 - fill_ratio

    return (
        complexity,
        make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        ),
    )
def mean_precision(img: Tensor, models: list[CONVVAE], epsilon=0.4):
    mask = img.to(models[0].device)
    mask_bits = mask[0].cpu() > 0

    precisions = np.zeros(len(models))
    for i, model in enumerate(models):
        recon_gb, _, _ = model(mask)
        recon_bits = recon_gb.view(-1, 64, 64).cpu() > epsilon

        tp = (mask_bits & recon_bits).sum()
        fp = (recon_bits & ~mask_bits).sum()
        prec = tp / (tp + fp)
        precisions[i] = prec

    return 1 - precisions.mean(), None
def multidim_complexity(
    img: Tensor,
    measures: list[
        tuple[str, Callable[[Tensor, Any], tuple[torch.float32, Tensor]], Any]
    ],
):
    # evaluate each named measure on the same image and collect a rating vector
    n_dim = len(measures)
    ratings = torch.zeros((n_dim,))
    for i, (_, fn, args) in enumerate(measures):
        rating, _ = fn(img, *args)
        ratings[i] = rating

    return ratings
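
For orientation, a minimal sketch (not part of the commit) of how these measures can be driven. It assumes binary 64x64 masks, as used throughout this file, and that load_models() from models.py can locate its hardcoded checkpoints:

# Hypothetical driver for the measures above.
import torch

from complexity import (
    compression_measure,
    fft_measure,
    multidim_complexity,
    pixelwise_complexity_measure,
)
from models import load_models

mask = (torch.rand(1, 64, 64) > 0.5).float()  # stand-in for a component mask

c_comp, _ = compression_measure(mask, fill_ratio_norm=True)
c_fft, _ = fft_measure(mask)

model_bn8, model_bn32 = load_models()
ratings = multidim_complexity(
    mask,
    [
        ("compression", compression_measure, [True]),
        ("fft", fft_measure, []),
        ("pixelwise", pixelwise_complexity_measure, [model_bn32, model_bn8, True]),
    ],
)
print(c_comp, c_fft, ratings)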
Imports of the BBoxTransform module:

@@ -2,9 +2,9 @@ import numpy as np
 import torch
 from kornia.morphology import closing
 from torch import Tensor
-from torchvision import transforms
-from visualize_results.utils import bbox
+from torchvision.transforms import transforms
+from utils import bbox


 class BBoxTransform:
Attention visualization script:

@@ -2,65 +2,94 @@ import glob
 import os
 import sys
 from typing import Generator
-from venv import create
-from matplotlib.pyplot import fill
+import matplotlib
 import numpy as np
 from PIL import Image as img
 from PIL.Image import Image
+import torch
 from torchvision.transforms import transforms
-
-from visualize_results.data import get_dino_transforms
-from visualize_results.utils import find_components, natsort
+from complexity import (
+    compression_measure,
+    fft_measure,
+    multidim_complexity,
+    pixelwise_complexity_measure,
+)
+from matplotlib import cm
+from data import get_dino_transforms
+from utils import find_components, natsort
+from models import load_models

 sys.setrecursionlimit(1000000)

 n_clusters = 5


-def create_vis(n_imgs, layer_range=range(12)) -> Generator[Image, Image, None]:
+def create_vis(
+    n_imgs, layer_range=range(12), sort=False
+) -> Generator[Image, Image, None]:
     """
     yields three channel PIL image
-    credit @wlad
+    original by @wlad
     """
-    inp = "./OutputDir"
-    os.makedirs("kmeanimgs", exist_ok=True)
-    rltrenner = img.open("./data/rltrenner.png")
+    inp = "./data/dino/"
+    out = "./results/comp_fft_px_sorted/"
+    os.makedirs(out, exist_ok=True)
+    rltrenner = img.open("./data/static/rltrenner.png")
     rltrenner = rltrenner.convert("RGB")
     yseperator = 20

     print("Creating images")
-    for idx, numimg in enumerate(range(n_imgs)):  # TODO: add tqdm again..
+    for i in range(n_imgs):  # TODO: maybe add tqdm again..
         imagesvert = []
-        name = "img" + str(numimg)
-        aimg = img.open("OutputDir/" + name + ".png")
+        name = "img" + str(i)
+        aimg = img.open(inp + name + ".png")
         aimg = aimg.convert("RGB")

         for depth in layer_range:
             imagesinline = [aimg]
             for attention_type in ["q", "k", "v"]:
-                if attention_type is not "v":
-                    imagesinline.append(rltrenner)
-
-                templist = glob.glob(
-                    os.path.join(
-                        inp, f"{name}{attention_type}depth{str(depth)}head*.png"
-                    )
-                ).sort(key=natsort)
-
-                for timg in templist:
-                    timg = img.open(timg)
-                    timg = yield timg
-                    yield
-                    timg = timg.convert("RGB")
-                    timg = timg.resize((480, 480), resample=img.NEAREST)
-                    imagesinline.append(timg)
-
-            # yield imagesinline
-            # imagesinline.insert(0, aimg)
+                complexities = []
+                _attention_images = []
+                img_paths = sorted(
+                    glob.glob(
+                        os.path.join(
+                            inp, f"{name}{attention_type}depth{str(depth)}head*.png"
+                        )
+                    ),
+                    key=natsort,
+                )
+
+                for path in img_paths:
+                    image = img.open(path)
+                    if sort:
+                        image, complexity = yield image
+                    else:
+                        image = yield image
+                    yield
+                    # image = image.convert("RGB")
+                    image = image.resize((480, 480), resample=img.NEAREST)
+                    _attention_images.append(image)
+                    if sort:
+                        complexities.append(complexity)
+
+                if sort:
+                    sort_idx = np.argsort(complexities)
+                    _attention_images = [_attention_images[i] for i in sort_idx]
+
+                _attention_images.insert(0, rltrenner)
+                imagesinline.extend(_attention_images)

             widths, heights = zip(*(i.size for i in imagesinline))
@@ -87,33 +116,136 @@ def create_vis(n_imgs, layer_range=range(12)) -> Generator[Image, Image, None]:
             final_img.paste(im, (0, y_offset))
             y_offset += im.size[1] + yseperator

-        final_img.save(os.path.join("kmeanimgs/", "img" + str(idx) + ".png"))
+        final_img.save(os.path.join(out, "img" + str(i) + ".png"))


-def main():
+def multi_dim(max_norm=True, sort=False):
     img_transformer = get_dino_transforms()
-    image_generator = create_vis()
+    image_generator = create_vis(n_imgs=25, sort=sort)
+    color_map = cm.get_cmap("plasma")
+    model_bn8, model_bn32 = load_models()

     for img in image_generator:
         np_img = np.array(img)
+        rgb_np_img = np.dstack((np_img, np_img, np_img))
         cluster_img = (np_img / (255 / (n_clusters - 1))).astype(np.int8)
+        start_label = 1
+        complexity_vectors = []
+        all_labels = np.zeros_like(np_img, dtype=np.uint8)

         for i in range(n_clusters):
             cluster_image = (cluster_img == i).astype(np.int8)
-            labels = find_components(cluster_image)
-            for l in range(1, labels.max() + 1):
-                mask = (labels == l).astype(np.int8)
+            labels = find_components(cluster_image, start_label)
+
+            for l in range(start_label, labels.max() + 1):
+                mask = (labels == l).astype(np.float32)
+                if not mask.sum() > 0:
+                    continue
+
                 normalized_img = img_transformer(transforms.F.to_pil_image(mask))
-                # TODO: calculate complexity for normalized image..
-                complexity = 1.0
-                np_img[labels == l] = complexity
-
-        # TODO: return manipulated PIL img / modify `img`
-        image_generator.send(transforms.F.to_pil_image(np_img))
-        continue
+
+                complexity_vector = multidim_complexity(
+                    normalized_img,
+                    [
+                        ("compression", compression_measure, [True]),
+                        ("fft", fft_measure, []),
+                        (
+                            "pixelwise",
+                            pixelwise_complexity_measure,
+                            [model_bn32, model_bn8, True],
+                        ),
+                    ],
+                )
+                complexity_vectors.append(complexity_vector)
+
+            start_label = labels.max() + 1 if labels.max() > 0 else start_label
+            all_labels += labels
+
+        complexity_vectors = torch.stack(complexity_vectors, dim=0)
+        if max_norm:
+            maxs = complexity_vectors.max(dim=0)
+            complexity_vectors /= maxs.values
+
+        complexity_norm = torch.linalg.vector_norm(complexity_vectors, dim=1)
+        # rank sorted complexities descending
+        color_idx = torch.argsort(torch.argsort(-complexity_norm))
+        color_values = (color_idx / len(color_idx)).numpy()
+
+        rgb_np_img[all_labels == 0, :] = np.array([0, 0, 0])
+        for l in range(1, all_labels.max() + 1):
+            r, g, b = matplotlib.colors.to_rgb(color_map(color_values[l - 1]))
+            rgb_np_img[all_labels == l, :] = np.array([r, g, b]) * 255.0
+
+        if sort:
+            image_generator.send(
+                (transforms.F.to_pil_image(rgb_np_img), complexity_norm.sum())
+            )
+        else:
+            image_generator.send(transforms.F.to_pil_image(rgb_np_img))
+
+
+def single_dim(sort=False):
+    img_transformer = get_dino_transforms()
+    image_generator = create_vis(n_imgs=25, sort=sort)
+    color_map = cm.get_cmap("plasma")
+    model_bn8, model_bn32 = load_models()
+
+    for img in image_generator:
+        np_img = np.array(img)
+        rgb_np_img = np.dstack((np_img, np_img, np_img))
+        cluster_img = (np_img / (255 / (n_clusters - 1))).astype(np.int8)
+        start_label = 1
+        complexities = []
+        all_labels = np.zeros_like(np_img, dtype=np.uint8)
+
+        for i in range(n_clusters):
+            cluster_image = (cluster_img == i).astype(np.int8)
+            labels = find_components(cluster_image, start_label)
+
+            for l in range(start_label, labels.max() + 1):
+                mask = (labels == l).astype(np.float32)
+                if not mask.sum() > 0:
+                    continue
+
+                normalized_img = img_transformer(transforms.F.to_pil_image(mask))
+                # complexity, _ = compression_measure(normalized_img, True)
+                complexity, _ = pixelwise_complexity_measure(
+                    normalized_img, model_bn32, model_bn8, fill_ratio_norm=True
+                )
+                # complexity, _ = fft_measure(normalized_img)
+                complexities.append(complexity)
+
+            start_label = labels.max() + 1 if labels.max() > 0 else start_label
+            all_labels += labels
+
+        # rank sorted complexities descending
+        color_idx = np.argsort(np.argsort(-np.array(complexities)))
+        color_values = color_idx / len(color_idx)
+
+        rgb_np_img[all_labels == 0, :] = np.array([0, 0, 0])
+        for l in range(1, all_labels.max() + 1):
+            r, g, b = matplotlib.colors.to_rgb(color_map(color_values[l - 1]))
+            rgb_np_img[all_labels == l, :] = np.array([r, g, b]) * 255.0
+
+        if sort:
+            image_generator.send(
+                (transforms.F.to_pil_image(rgb_np_img), np.sum(complexities))
+            )
+        else:
+            image_generator.send(transforms.F.to_pil_image(rgb_np_img))


 if __name__ == "__main__":
-    main()
+    multi_dim(sort=True)
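
A note on the control flow above: create_vis is used as a coroutine. It yields each raw attention image, the caller computes a recolored version (and optionally a complexity score) and hands it back via send(), and the bare yield directly after absorbs that send so the caller's for-loop stays in step. A stripped-down sketch of the same handshake, independent of PIL and the models:

from typing import Generator


def producer(items: list[int]) -> Generator[int, int, None]:
    collected = []
    for item in items:
        processed = yield item  # the caller's for-loop receives `item` here
        yield                   # bare yield: absorbs send() and re-syncs the loop
        collected.append(processed)
    print("collected:", collected)


gen = producer([1, 2, 3])
for raw in gen:
    gen.send(raw * 10)  # resumes the generator; prints "collected: [10, 20, 30]"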
models.py:

@@ -5,6 +5,8 @@ from torch import functional as F

 class CONVVAE(nn.Module):
+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+
     def __init__(
         self,
         bottleneck=2,
@@ -111,3 +113,24 @@ class CONVVAE(nn.Module):
         KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

         return BCE + KLD
+
+    def to(self, device=None):
+        if device is not None:
+            self.device = device
+        return super().to(self.device)
+
+
+def load_models():
+    bottlenecks = [8, 32]
+    models = {bn: CONVVAE(bottleneck=bn).to() for bn in bottlenecks}
+
+    for bn, model in models.items():
+        model.load_state_dict(
+            torch.load(
+                f"/home/markus/uni/navigation_project/shape_complexity/trained/CONVVAE_{bn}_split_data.pth"
+            )
+        )
+        model.eval()
+
+    return list(models.values())
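
The overridden to() keeps model.device in sync with wherever the module was moved, which call sites like img.to(model.device) rely on. A hedged usage sketch follows; note that load_models() reads checkpoints from an absolute path on the author's machine, and the (1, 1, 64, 64) input shape is an assumption inferred from the 64x64 masks used throughout:

import torch

from models import load_models

model_bn8, model_bn32 = load_models()  # ordered by bottleneck size: [8, 32]

x = torch.rand(1, 1, 64, 64).to(model_bn8.device)
recon, mu, logvar = model_bn8(x)  # forward returns (reconstruction, mean, logvar)
print(recon.shape, mu.shape, logvar.shape)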
utils.py:

-from imageio import v3 as iio
+import re
+import os
 import numpy as np
 import numpy.typing as npt
-import sys
-import re
+from imageio import v3 as iio


 def natsort(s, _nsre=re.compile("([0-9]+)")):
@@ -35,8 +33,8 @@ def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: i
             dfs(mask, x + dx[direction], y + dy[direction], labels, current_label)


-def find_components(mask: npt.NDArray):
-    label = 0
+def find_components(mask: npt.NDArray, start_label=1, min_mask_pixels=16):
+    label = start_label
     n_rows, n_cols = mask.shape
     labels = np.zeros(mask.shape, dtype=np.int8)
@@ -44,10 +42,20 @@ def find_components(mask: npt.NDArray):
     for i in range(n_rows):
         for j in range(n_cols):
             if not labels[i][j] and mask[i][j]:
-                label += 1
                 dfs(mask, i, j, labels, label)
+                label += 1
+
+    max_label = labels.max()
+    subtraction_matrix = np.zeros_like(labels)
+    for l in range(start_label, max_label + 1):
+        if (labels == l).sum() < min_mask_pixels:
+            labels[labels == l] = 0
+            subtraction_matrix[labels > l] += 1
+
+    labels -= subtraction_matrix
+    labels[labels < start_label] = 0

-    return labels
+    return labels.astype(np.uint8)

 # https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
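
To make the new find_components behavior concrete, a small sanity check (a sketch; the label values assume the row-major DFS scan above):

import numpy as np

from utils import find_components

mask = np.zeros((8, 8), dtype=np.int8)
mask[0:2, 0:2] = 1  # component A, 4 pixels
mask[5:8, 5:8] = 1  # component B, 9 pixels

# Both components survive a 3-pixel threshold and keep labels 1 and 2.
print(np.unique(find_components(mask, start_label=1, min_mask_pixels=3)))  # [0 1 2]

# A 5-pixel threshold drops A and compacts B down to label 1.
print(np.unique(find_components(mask, start_label=1, min_mask_pixels=5)))  # [0 1]

# With the default 16-pixel threshold both components are dropped as noise.
print(np.unique(find_components(mask, start_label=1, min_mask_pixels=16)))  # [0]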