Skip to content
Snippets Groups Projects
Commit a3fbdf1c authored by Markus Rothgänger
Browse files

wip local

parent f12fecfb
No related branches found
No related tags found
No related merge requests found
......@@ -3,11 +3,10 @@ from typing import Callable
import numpy as np
import torch
from imageio import imopen
from kornia.morphology import closing
from PIL import Image
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import Dataset, Subset, WeightedRandomSampler
from torchvision.transforms import transforms
from utils import bbox
......@@ -29,7 +28,7 @@ class MPEG7ShapeDataset(Dataset):
for file in paths:
fp = os.path.join(self.img_dir, file)
if os.path.isfile(fp):
label = file.split("-")[0]
label = file.split("-")[0].lower()
self.filenames.append(fp)
labels.append(label)
......@@ -156,3 +155,17 @@ def load_mpeg7_data():
)
return MPEG7ShapeDataset("../shape_complexity/data/mpeg7", transform)
def get_weighted_sampler(
    dataset: MPEG7ShapeDataset, label_names: list, num_samples: int = 100
):
    """Build a sampler that only draws samples carrying one of the given labels.

    Each entry of ``dataset.labels`` that matches a requested label gets
    weight 1 and every other entry gets weight 0, so the returned
    ``WeightedRandomSampler`` never yields the index of a non-matching sample.

    Args:
        dataset: Dataset exposing ``label_index_dict`` (looked up by label
            name) and ``labels`` (one entry per sample, compared against the
            looked-up values).
        label_names: Names of the labels to sample from. Duplicate names are
            harmless.
        num_samples: Number of indices the sampler yields per pass. Defaults
            to 100, the value that was previously hard-coded.

    Returns:
        A ``WeightedRandomSampler`` over the sample indices of ``dataset``.

    Raises:
        KeyError: If a name is not a key of ``dataset.label_index_dict``.
        ValueError: If no sample carries any of the requested labels.
    """
    label_indices = [dataset.label_index_dict[name] for name in label_names]
    # List membership (==) is kept on purpose: the element type of
    # dataset.labels is not visible here and may not hash like a plain int.
    label_weights = [1 if label in label_indices else 0 for label in dataset.labels]
    if not any(label_weights):
        # An all-zero weight vector would only fail later, inside the sampler,
        # with an opaque multinomial error; fail early with a clear message.
        raise ValueError(f"no samples found for labels {label_names!r}")
    return WeightedRandomSampler(weights=label_weights, num_samples=num_samples)
import argparse
import glob
import os
import sys
from typing import Generator
import matplotlib
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader
from matplotlib import cm
from matplotlib.pyplot import fill
from PIL import Image as img
from PIL.Image import Image
from torchvision.transforms import transforms
from complexity import (
......@@ -20,10 +15,14 @@ from complexity import (
multidim_complexity,
pixelwise_complexity_measure,
)
from data import get_dino_transforms, load_mpeg7_data
from models import CONVVAE, load_models
from plot import create_vis, plot_samples, visualize_sort, visualize_sort_multidim
from utils import find_components, natsort
from data import (
get_dino_transforms,
get_weighted_sampler,
load_mpeg7_data,
)
from models import load_models
from plot import create_vis, visualize_sort_multidim
from utils import find_components
LOAD_PRETRAINED = True
......@@ -223,8 +222,24 @@ if __name__ == "__main__":
model_bn64.eval()
model_bn16.eval()
sampler = RandomSampler(test_dataset, replacement=True, num_samples=100)
data_loader = DataLoader(test_dataset, batch_size=1, sampler=sampler)
sampler = get_weighted_sampler(
dataset,
[
"apple",
"bone",
"butterfly",
"hammer",
"pocket",
"device0",
"crown",
"hammer",
"tree",
"rat",
],
)
# sampler = RandomSampler(label_subset, replacement=True, num_samples=100)
data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)
# visualize_sort(
# data_loader,
......
......@@ -107,13 +107,18 @@ def create_vis(
# TODO: instead of plotting each mask individually, create big image array/tensor
def plot_samples(masks: Tensor, ratings: npt.NDArray, classes: npt.NDArray = None):
# TODO: restrict plotting to a subset of labels (5-10?), e.g. 10 images for each of 10 classes
def plot_samples(masks: Tensor, ratings: npt.NDArray, labels: npt.NDArray = None):
dpi = 150
rows = cols = 10
total = rows * cols
n_samples, _, y, x = masks.shape
extent = (0, x - 1, 0, y - 1)
label_map = {
v: i + 1 for i, v in enumerate({int(v): int(v) for v in labels}.keys())
}
max_label = len(label_map) + 1
if total != n_samples:
raise Exception("shape mismatch")
......@@ -122,16 +127,22 @@ def plot_samples(masks: Tensor, ratings: npt.NDArray, classes: npt.NDArray = Non
for idx in np.arange(n_samples):
ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])
if classes is None:
if labels is None:
plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
else:
mask = masks[idx][0] * classes[idx].item() / classes.max().item()
plt.imshow(mask, extent=extent)
mask = masks[idx][0] * (label_map[int(labels[idx].item())])
plt.imshow(
mask,
cmap="turbo",
extent=extent,
vmax=max_label,
vmin=0,
)
rating = ratings[idx]
ax.set_title(
rating if isinstance(rating, str) else f"{ratings[idx]:.4f}",
fontdict={"fontsize": 6, "color": "orange"},
fontdict={"fontsize": 6, "color": "orange" if labels is None else "white"},
y=0.2 if isinstance(rating, str) else 0.35,
)
......@@ -235,6 +246,6 @@ def visualize_sort_multidim(
sort_idx = np.argsort(np.array(measure_norm))
rating_strings = [f"{r[0]:.4f}\n{r[1]:.4f}\n{r[2]:.4f}" for r in ratings[sort_idx]]
fig = plot_samples(masks.numpy()[sort_idx], rating_strings, labels)
fig = plot_samples(masks.numpy()[sort_idx], rating_strings, labels[sort_idx])
fig.savefig(f"results/{n_dim}dim_{'_'.join([m[0] for m in measures])}_sort.png")
plt.close(fig)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment