Skip to content
Snippets Groups Projects
Commit a3fbdf1c authored by Markus Rothgänger's avatar Markus Rothgänger
Browse files

wip local

parent f12fecfb
No related branches found
No related tags found
No related merge requests found
...@@ -3,11 +3,10 @@ from typing import Callable ...@@ -3,11 +3,10 @@ from typing import Callable
import numpy as np import numpy as np
import torch import torch
from imageio import imopen
from kornia.morphology import closing from kornia.morphology import closing
from PIL import Image from PIL import Image
from torch import Tensor from torch import Tensor
from torch.utils.data import DataLoader, Dataset from torch.utils.data import Dataset, Subset, WeightedRandomSampler
from torchvision.transforms import transforms from torchvision.transforms import transforms
from utils import bbox from utils import bbox
...@@ -29,7 +28,7 @@ class MPEG7ShapeDataset(Dataset): ...@@ -29,7 +28,7 @@ class MPEG7ShapeDataset(Dataset):
for file in paths: for file in paths:
fp = os.path.join(self.img_dir, file) fp = os.path.join(self.img_dir, file)
if os.path.isfile(fp): if os.path.isfile(fp):
label = file.split("-")[0] label = file.split("-")[0].lower()
self.filenames.append(fp) self.filenames.append(fp)
labels.append(label) labels.append(label)
...@@ -156,3 +155,17 @@ def load_mpeg7_data(): ...@@ -156,3 +155,17 @@ def load_mpeg7_data():
) )
return MPEG7ShapeDataset("../shape_complexity/data/mpeg7", transform) return MPEG7ShapeDataset("../shape_complexity/data/mpeg7", transform)
def get_weighted_sampler(
    dataset: "MPEG7ShapeDataset", label_names: list, num_samples: int = 100
):
    """Build a sampler that draws only from the given MPEG-7 label classes.

    Samples (with replacement, the WeightedRandomSampler default) are drawn
    uniformly from the dataset entries whose label is in ``label_names``;
    all other entries get weight 0 and are never selected.

    Args:
        dataset: dataset exposing ``label_index_dict`` (name -> label index)
            and ``labels`` (per-sample label indices).
        label_names: class names to keep, e.g. ``["apple", "bone"]``.
        num_samples: number of samples the sampler yields per epoch
            (default 100, matching the previous hard-coded value).

    Returns:
        A ``WeightedRandomSampler`` suitable for ``DataLoader(sampler=...)``.
    """
    # Resolve names to label indices once; a set gives O(1) membership
    # tests in the per-sample comprehension below.
    wanted = {dataset.label_index_dict[name] for name in label_names}
    # 1.0 for samples in the requested classes, 0.0 (never drawn) otherwise.
    weights = [1.0 if label in wanted else 0.0 for label in dataset.labels]
    return WeightedRandomSampler(weights=weights, num_samples=num_samples)
import argparse import argparse
import glob
import os import os
import sys import sys
from typing import Generator
import matplotlib import matplotlib
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader
from matplotlib import cm from matplotlib import cm
from matplotlib.pyplot import fill
from PIL import Image as img
from PIL.Image import Image
from torchvision.transforms import transforms from torchvision.transforms import transforms
from complexity import ( from complexity import (
...@@ -20,10 +15,14 @@ from complexity import ( ...@@ -20,10 +15,14 @@ from complexity import (
multidim_complexity, multidim_complexity,
pixelwise_complexity_measure, pixelwise_complexity_measure,
) )
from data import get_dino_transforms, load_mpeg7_data from data import (
from models import CONVVAE, load_models get_dino_transforms,
from plot import create_vis, plot_samples, visualize_sort, visualize_sort_multidim get_weighted_sampler,
from utils import find_components, natsort load_mpeg7_data,
)
from models import load_models
from plot import create_vis, visualize_sort_multidim
from utils import find_components
LOAD_PRETRAINED = True LOAD_PRETRAINED = True
...@@ -223,8 +222,24 @@ if __name__ == "__main__": ...@@ -223,8 +222,24 @@ if __name__ == "__main__":
model_bn64.eval() model_bn64.eval()
model_bn16.eval() model_bn16.eval()
sampler = RandomSampler(test_dataset, replacement=True, num_samples=100) sampler = get_weighted_sampler(
data_loader = DataLoader(test_dataset, batch_size=1, sampler=sampler) dataset,
[
"apple",
"bone",
"butterfly",
"hammer",
"pocket",
"device0",
"crown",
"hammer",
"tree",
"rat",
],
)
# sampler = RandomSampler(label_subset, replacement=True, num_samples=100)
data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)
# visualize_sort( # visualize_sort(
# data_loader, # data_loader,
......
...@@ -107,13 +107,18 @@ def create_vis( ...@@ -107,13 +107,18 @@ def create_vis(
# TODO: instead of plotting each mask individually, create big image array/tensor # TODO: instead of plotting each mask individually, create big image array/tensor
def plot_samples(masks: Tensor, ratings: npt.NDArray, classes: npt.NDArray = None): # TODO: restrict to subset of labels.. (5-10?!) maybe 10 images of 10 classes..
def plot_samples(masks: Tensor, ratings: npt.NDArray, labels: npt.NDArray = None):
dpi = 150 dpi = 150
rows = cols = 10 rows = cols = 10
total = rows * cols total = rows * cols
n_samples, _, y, x = masks.shape n_samples, _, y, x = masks.shape
extent = (0, x - 1, 0, y - 1) extent = (0, x - 1, 0, y - 1)
label_map = {
v: i + 1 for i, v in enumerate({int(v): int(v) for v in labels}.keys())
}
max_label = len(label_map) + 1
if total != n_samples: if total != n_samples:
raise Exception("shape mismatch") raise Exception("shape mismatch")
...@@ -122,16 +127,22 @@ def plot_samples(masks: Tensor, ratings: npt.NDArray, classes: npt.NDArray = Non ...@@ -122,16 +127,22 @@ def plot_samples(masks: Tensor, ratings: npt.NDArray, classes: npt.NDArray = Non
for idx in np.arange(n_samples): for idx in np.arange(n_samples):
ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[]) ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])
if classes is None: if labels is None:
plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent) plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
else: else:
mask = masks[idx][0] * classes[idx].item() / classes.max().item() mask = masks[idx][0] * (label_map[int(labels[idx].item())])
plt.imshow(mask, extent=extent) plt.imshow(
mask,
cmap="turbo",
extent=extent,
vmax=max_label,
vmin=0,
)
rating = ratings[idx] rating = ratings[idx]
ax.set_title( ax.set_title(
rating if isinstance(rating, str) else f"{ratings[idx]:.4f}", rating if isinstance(rating, str) else f"{ratings[idx]:.4f}",
fontdict={"fontsize": 6, "color": "orange"}, fontdict={"fontsize": 6, "color": "orange" if labels is None else "white"},
y=0.2 if isinstance(rating, str) else 0.35, y=0.2 if isinstance(rating, str) else 0.35,
) )
...@@ -235,6 +246,6 @@ def visualize_sort_multidim( ...@@ -235,6 +246,6 @@ def visualize_sort_multidim(
sort_idx = np.argsort(np.array(measure_norm)) sort_idx = np.argsort(np.array(measure_norm))
rating_strings = [f"{r[0]:.4f}\n{r[1]:.4f}\n{r[2]:.4f}" for r in ratings[sort_idx]] rating_strings = [f"{r[0]:.4f}\n{r[1]:.4f}\n{r[2]:.4f}" for r in ratings[sort_idx]]
fig = plot_samples(masks.numpy()[sort_idx], rating_strings, labels) fig = plot_samples(masks.numpy()[sort_idx], rating_strings, labels[sort_idx])
fig.savefig(f"results/{n_dim}dim_{'_'.join([m[0] for m in measures])}_sort.png") fig.savefig(f"results/{n_dim}dim_{'_'.join([m[0] for m in measures])}_sort.png")
plt.close(fig) plt.close(fig)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment