Commit 4e9197d0 authored by Markus Rothgänger

remove stuff, add mpeg7 dataset

parent e218f84e
@@ -2,9 +2,11 @@ import os
# from zlib import compress
from bz2 import compress
import string
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.transforms import Transform
import numpy as np
import numpy.typing as npt
import torch
@@ -13,8 +15,9 @@ from kornia.morphology import closing
from PIL import Image
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, RandomSampler, Dataset
from torchvision.datasets import ImageFolder
from torchvision.io import read_image
from torchvision.transforms import transforms
from torchvision.utils import make_grid, save_image
@@ -265,6 +268,65 @@ class CloseTransform:
        return transforms.F.to_pil_image(closing(x, self.kernel))
class MPEG7ShapeDataset(Dataset):
    img_dir: str
    filenames: list[str]
    labels: list[int]
    label_dict: dict[int, str]

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        self.transform = transform
        # per-instance lists (a class-level default list would be shared across instances)
        self.filenames = []
        paths = os.listdir(self.img_dir)
        labels = []
        for file in paths:
            fp = os.path.join(self.img_dir, file)
            if os.path.isfile(fp):
                label = file.split("-")[0]
                self.filenames.append(fp)
                labels.append(label)

        # map integer indices to class names (and back), preserving first-seen order
        label_name_dict = dict.fromkeys(labels)
        self.label_dict = {i: v for (i, v) in enumerate(label_name_dict.keys())}
        self.label_index_dict = {v: i for (i, v) in self.label_dict.items()}
        self.labels = [self.label_index_dict[l] for l in labels]

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        img_path = self.filenames[idx]
        # Image.convert returns a new image, so the result must be kept
        image = Image.open(img_path).convert("RGB")
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, label
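For orientation only (not part of the commit), the new dataset can be exercised roughly like this; the data path matches the one used in load_mpeg7_data() below, and the printed class name assumes the MPEG-7 file naming scheme ("apple-1.gif", "apple-2.gif", ...):

ds = MPEG7ShapeDataset("shape_complexity/data/mpeg7")
print(len(ds))               # number of shape files found in img_dir
image, label = ds[0]         # PIL image (no transform passed) and integer class index
print(ds.label_dict[label])  # class name, e.g. "apple" for "apple-1.gif"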
def load_mpeg7_data():
    transform = transforms.Compose(
        [
            transforms.Grayscale(),
            # transforms.RandomApply([CloseTransform()], p=0.25),
            transforms.Resize(
                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    )

    dataset = MPEG7ShapeDataset("shape_complexity/data/mpeg7", transform)
    data_loader = DataLoader(dataset, batch_size=128, shuffle=True)

    return data_loader, dataset
def load_data():
    transform = transforms.Compose(
        [
@@ -474,7 +536,6 @@ def complexity_measure(
    model_lb: nn.Module,
    img: Tensor,
    epsilon=0.4,
    save_preliminary=False,
):
    model_gb.eval()
    model_lb.eval()
@@ -488,25 +549,6 @@
    recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
    mask_bits = mask[0].cpu() > 0

    if save_preliminary:
        save_image(
            torch.stack(
                [mask_bits.float(), recon_bits_gb.float(), recon_bits_lb.float()]
            ).cpu(),
            f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
        )
        save_image(
            torch.stack(
                [
                    (mask_bits & recon_bits_gb).float(),
                    (recon_bits_gb & ~mask_bits).float(),
                    (mask_bits & recon_bits_lb).float(),
                    (recon_bits_lb & ~mask_bits).float(),
                ]
            ).cpu(),
            f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
        )

    tp_gb = (mask_bits & recon_bits_gb).sum()
    fp_gb = (recon_bits_gb & ~mask_bits).sum()
    tp_lb = (mask_bits & recon_bits_lb).sum()
@@ -515,18 +557,9 @@
    prec_gb = tp_gb / (tp_gb + fp_gb)
    prec_lb = tp_lb / (tp_lb + fp_lb)

    complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
    complexity_lb = 1 - prec_lb
    complexity_gb = 1 - prec_gb

    # e.g. prec_gb = 0.4, prec_lb = 0.7: complexity = 1 - (0.4 - abs(0.4 - 0.7)) = 0.9
    #      whereas complexity_lb = 1 - 0.7 = 0.3

    return (
        complexity,
        complexity_lb,
        complexity_gb,
        prec_gb - prec_lb,
        prec_lb,
        prec_gb,
        make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
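Read on its own, the precision-based measure above boils down to the following self-contained sketch (a toy helper, names not from the repository): prec_gb is reduced by the absolute precision gap between the two reconstructions, and the result is subtracted from 1.

def precision_complexity(mask_bits, recon_bits_gb, recon_bits_lb):
    # precision of each binarized reconstruction against the ground-truth mask
    prec_gb = (mask_bits & recon_bits_gb).sum() / recon_bits_gb.sum()
    prec_lb = (mask_bits & recon_bits_lb).sum() / recon_bits_lb.sum()
    # e.g. prec_gb = 0.4, prec_lb = 0.7 -> 1 - (0.4 - 0.3) = 0.9
    return 1 - (prec_gb - abs(prec_gb - prec_lb))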
@@ -730,9 +763,7 @@ def visualize_sort_3dim(
        c_compress = compression_complexity(mask)
        c_fft = fft_measure(mask)
        # TODO: maybe exchange by diff or mean measure instead of precision
        c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask
        )
        c_vae, mask_recon_grid = complexity_measure(model_gb, model_lb, mask)
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        measures[i] = torch.tensor([c_compress, c_fft, c_vae])
@@ -767,181 +798,55 @@ def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    diffs = []

    for i, (mask, _) in enumerate(data_loader, 0):
        complexity, _, _, diff, mask_recon_grid = complexity_measure(
        complexity, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask, save_preliminary=True
        )
        masks[i] = mask_recon_grid
        diffs.append(diff)
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(diffs)), np.sort(diffs))
    plt.xlabel("images")
    plt.ylabel("prec difference (L-H)")
    plt.savefig("shape_complexity/results/diff_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, complexities[sort_idx])
def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    complexities_lb = torch.zeros((400,))
    complexities_gb = torch.zeros((400,))
    diffs = []
    prec_lbs = []
    prec_gbs = []

    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            lb,
            gb,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        masks[i] = mask_recon_grid
        diffs.append(diff)
        prec_lbs.append(prec_lb)
        prec_gbs.append(prec_gb)
        complexities[i] = complexity
        complexities_lb[i] = lb
        complexities_gb[i] = gb

    sort_idx = np.argsort(np.array(complexities))
    sort_idx_lb = np.argsort(np.array(complexities_lb))
    sort_idx_gb = np.argsort(np.array(complexities_gb))
    masks_sorted = masks.numpy()[sort_idx]
    masks_sorted_lb = masks.numpy()[sort_idx_lb]
    masks_sorted_gb = masks.numpy()[sort_idx_gb]

    diff_sort_idx = np.argsort(diffs)

    # plt.savefig("shape_complexity/results/diff_plot.png")
    # plt.clf

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(
        np.arange(len(prec_lbs)),
        np.array(prec_lbs)[diff_sort_idx],
        label=f"bottleneck {model_lb.bottleneck}",
    )
    ax1.plot(
        np.arange(len(prec_gbs)),
        np.array(prec_gbs)[diff_sort_idx],
        label=f"bottleneck {model_gb.bottleneck}",
    )
    ax2.plot(
        np.arange(len(diffs)),
        np.sort(diffs),
        color="red",
        label="prec difference (H - L)",
    )
    ax1.legend(loc="lower left")
    ax2.legend(loc="lower right")
    ax1.set_ylabel("precision")
    ax2.set_ylabel("prec difference (H-L)")
    plt.savefig("shape_complexity/results/prec_plot.png")
    plt.clf()

    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)

    fig = plot_samples(masks_sorted_lb, complexities_lb[sort_idx_lb])
    fig.savefig("shape_complexity/results/lb.png")
    plt.close(fig)

    fig = plot_samples(masks_sorted_gb, complexities_gb[sort_idx_gb])
    fig.savefig("shape_complexity/results/gb.png")
    plt.close(fig)
def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    recon_masks = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    complexities = torch.zeros((400,))
    diffs = np.zeros((400,))
    prec_gbs = np.zeros((400,))
    prec_lbs = np.zeros((400,))

    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            _,
            _,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        recon_masks[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff
        prec_gbs[i] = prec_gb
        prec_lbs[i] = prec_lb
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]
    recon_masks_sorted = recon_masks.numpy()[sort_idx]

    # group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
    # bin_edges = [-np.inf, 0.0, 0.05, np.inf]
    # bins = np.digitize(diffs, bins=bin_edges, right=True)
    # for i in range(bins.min(), bins.max() + 1):
    #     bin_idx = bins == i
    #     binned_prec_gb = prec_gbs[bin_idx]
    #     prec_mean = binned_prec_gb.mean()
    #     prec_idx = prec_gbs > prec_mean
    #     binned_masks_high = recon_masks[bin_idx & prec_idx]
    #     binned_masks_low = recon_masks[bin_idx & ~prec_idx]
    #     save_image(
    #         binned_masks_high,
    #         f"shape_complexity/results/diff_{group_labels[i-1]}_high.png",
    #         padding=10,
    #     )
    #     save_image(
    #         binned_masks_low,
    #         f"shape_complexity/results/diff_{group_labels[i-1]}_low.png",
    #         padding=10,
    #     )

    # diff_sort_idx = np.argsort(diffs)
    # fig, ax1 = plt.subplots()
    # ax2 = ax1.twinx()
    # ax1.plot(
    #     np.arange(len(prec_lbs)),
    #     np.array(prec_lbs)[diff_sort_idx],
    #     label=f"bottleneck {model_lb.bottleneck}",
    # )
    # ax1.plot(
    #     np.arange(len(prec_gbs)),
    #     np.array(prec_gbs)[diff_sort_idx],
    #     label=f"bottleneck {model_gb.bottleneck}",
    # )
    # ax2.plot(
    #     np.arange(len(diffs)),
    #     np.sort(diffs),
    #     color="red",
    #     label="prec difference (H - L)",
    # )
    # ax1.legend(loc="lower left")
    # ax2.legend(loc="lower right")
    # ax1.set_ylabel("precision")
    # ax2.set_ylabel("prec difference (H-L)")
    # ax1.set_xlabel("image")
    # plt.savefig("shape_complexity/results/prec_plot.png")
    # plt.tight_layout(pad=2)
    # plt.clf()

    fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs_recon.png")
    plt.close(fig)
@@ -951,9 +856,15 @@ def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    plt.close(fig)
LR = 1e-3
EPOCHS = 10
LOAD_PRETRAINED = True
LR = 1.5e-3
EPOCHS = 100
LOAD_PRETRAINED = False
# TODO: build pixel ratio normalization for large black areas
# -> ideally, this fixes the compression metric
# TODO: try out pixelwise loss again (in 3d as well)
# TODO: might be a good idea to implement a bbox cut preprocessing transform thingy
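The bbox-cut TODO above could take roughly the following shape, mirroring the callable-transform style of CloseTransform (a hedged sketch, not part of the commit; the class name is made up):

class BBoxCropTransform:
    """Crop a PIL mask to the bounding box of its non-black pixels before resizing."""

    def __call__(self, img):
        bbox = img.getbbox()  # bounding box of non-zero regions, or None if the image is empty
        return img.crop(bbox) if bbox is not None else img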
def main():
@@ -961,7 +872,8 @@ def main():
    models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
    optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}

    data_loader, dataset = load_data()
    # data_loader, dataset = load_data()
    data_loader, dataset = load_mpeg7_data()

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
......
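Below the visible part of the diff, train_size and test_size are presumably fed to torch.utils.data.random_split; a minimal sketch of that standard pattern (an assumption, not shown in this commit):

from torch.utils.data import random_split

train_set, test_set = random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False)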