Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
M
minerl-indexing
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Markus Rothgänger
minerl-indexing
Commits
a3fbdf1c
Commit
a3fbdf1c
authored
2 years ago
by
Markus Rothgänger
Browse files
Options
Downloads
Patches
Plain Diff
wip local
parent
f12fecfb
No related branches found
Branches containing commit
No related tags found
No related merge requests found
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
complexity_estimator/data.py
+16
-3
16 additions, 3 deletions
complexity_estimator/data.py
complexity_estimator/main.py
+27
-12
27 additions, 12 deletions
complexity_estimator/main.py
complexity_estimator/plot.py
+17
-6
17 additions, 6 deletions
complexity_estimator/plot.py
with
60 additions
and
21 deletions
complexity_estimator/data.py
+
16
−
3
View file @
a3fbdf1c
...
...
@@ -3,11 +3,10 @@ from typing import Callable
import
numpy
as
np
import
torch
from
imageio
import
imopen
from
kornia.morphology
import
closing
from
PIL
import
Image
from
torch
import
Tensor
from
torch.utils.data
import
Data
Loader
,
Dataset
from
torch.utils.data
import
Data
set
,
Subset
,
WeightedRandomSampler
from
torchvision.transforms
import
transforms
from
utils
import
bbox
...
...
@@ -29,7 +28,7 @@ class MPEG7ShapeDataset(Dataset):
for
file
in
paths
:
fp
=
os
.
path
.
join
(
self
.
img_dir
,
file
)
if
os
.
path
.
isfile
(
fp
):
label
=
file
.
split
(
"
-
"
)[
0
]
label
=
file
.
split
(
"
-
"
)[
0
]
.
lower
()
self
.
filenames
.
append
(
fp
)
labels
.
append
(
label
)
...
...
@@ -156,3 +155,17 @@ def load_mpeg7_data():
)
return
MPEG7ShapeDataset
(
"
../shape_complexity/data/mpeg7
"
,
transform
)
def
get_weighted_sampler
(
dataset
:
MPEG7ShapeDataset
,
label_names
:
list
):
label_indices
=
[
dataset
.
label_index_dict
[
name
]
for
name
in
label_names
]
# indices = [
# idx for idx, label in enumerate(dataset.labels) if label in label_indices
# ]
# return Subset(dataset, indices)
label_weights
=
[
1
if
label
in
label_indices
else
0
for
_
,
label
in
enumerate
(
dataset
.
labels
)
]
return
WeightedRandomSampler
(
weights
=
label_weights
,
num_samples
=
100
)
This diff is collapsed.
Click to expand it.
complexity_estimator/main.py
+
27
−
12
View file @
a3fbdf1c
import
argparse
import
glob
import
os
import
sys
from
typing
import
Generator
import
matplotlib
import
numpy
as
np
import
torch
from
torch.utils.data
import
DataLoader
,
RandomSampler
from
torch.utils.data
import
DataLoader
from
matplotlib
import
cm
from
matplotlib.pyplot
import
fill
from
PIL
import
Image
as
img
from
PIL.Image
import
Image
from
torchvision.transforms
import
transforms
from
complexity
import
(
...
...
@@ -20,10 +15,14 @@ from complexity import (
multidim_complexity
,
pixelwise_complexity_measure
,
)
from
data
import
get_dino_transforms
,
load_mpeg7_data
from
models
import
CONVVAE
,
load_models
from
plot
import
create_vis
,
plot_samples
,
visualize_sort
,
visualize_sort_multidim
from
utils
import
find_components
,
natsort
from
data
import
(
get_dino_transforms
,
get_weighted_sampler
,
load_mpeg7_data
,
)
from
models
import
load_models
from
plot
import
create_vis
,
visualize_sort_multidim
from
utils
import
find_components
LOAD_PRETRAINED
=
True
...
...
@@ -223,8 +222,24 @@ if __name__ == "__main__":
model_bn64
.
eval
()
model_bn16
.
eval
()
sampler
=
RandomSampler
(
test_dataset
,
replacement
=
True
,
num_samples
=
100
)
data_loader
=
DataLoader
(
test_dataset
,
batch_size
=
1
,
sampler
=
sampler
)
sampler
=
get_weighted_sampler
(
dataset
,
[
"
apple
"
,
"
bone
"
,
"
butterfly
"
,
"
hammer
"
,
"
pocket
"
,
"
device0
"
,
"
crown
"
,
"
hammer
"
,
"
tree
"
,
"
rat
"
,
],
)
# sampler = RandomSampler(label_subset, replacement=True, num_samples=100)
data_loader
=
DataLoader
(
dataset
,
batch_size
=
1
,
sampler
=
sampler
)
# visualize_sort(
# data_loader,
...
...
This diff is collapsed.
Click to expand it.
complexity_estimator/plot.py
+
17
−
6
View file @
a3fbdf1c
...
...
@@ -107,13 +107,18 @@ def create_vis(
# TODO: instead of plotting each mask individually, create big image array/tensor
def
plot_samples
(
masks
:
Tensor
,
ratings
:
npt
.
NDArray
,
classes
:
npt
.
NDArray
=
None
):
# TODO: restrict to subset of labels.. (5-10?!) maybe 10 images of 10 classes..
def
plot_samples
(
masks
:
Tensor
,
ratings
:
npt
.
NDArray
,
labels
:
npt
.
NDArray
=
None
):
dpi
=
150
rows
=
cols
=
10
total
=
rows
*
cols
n_samples
,
_
,
y
,
x
=
masks
.
shape
extent
=
(
0
,
x
-
1
,
0
,
y
-
1
)
label_map
=
{
v
:
i
+
1
for
i
,
v
in
enumerate
({
int
(
v
):
int
(
v
)
for
v
in
labels
}.
keys
())
}
max_label
=
len
(
label_map
)
+
1
if
total
!=
n_samples
:
raise
Exception
(
"
shape mismatch
"
)
...
...
@@ -122,16 +127,22 @@ def plot_samples(masks: Tensor, ratings: npt.NDArray, classes: npt.NDArray = Non
for
idx
in
np
.
arange
(
n_samples
):
ax
=
fig
.
add_subplot
(
rows
,
cols
,
idx
+
1
,
xticks
=
[],
yticks
=
[])
if
c
la
sse
s
is
None
:
if
la
bel
s
is
None
:
plt
.
imshow
(
masks
[
idx
][
0
],
cmap
=
plt
.
cm
.
gray
,
extent
=
extent
)
else
:
mask
=
masks
[
idx
][
0
]
*
classes
[
idx
].
item
()
/
classes
.
max
().
item
()
plt
.
imshow
(
mask
,
extent
=
extent
)
mask
=
masks
[
idx
][
0
]
*
(
label_map
[
int
(
labels
[
idx
].
item
())])
plt
.
imshow
(
mask
,
cmap
=
"
turbo
"
,
extent
=
extent
,
vmax
=
max_label
,
vmin
=
0
,
)
rating
=
ratings
[
idx
]
ax
.
set_title
(
rating
if
isinstance
(
rating
,
str
)
else
f
"
{
ratings
[
idx
]
:
.
4
f
}
"
,
fontdict
=
{
"
fontsize
"
:
6
,
"
color
"
:
"
orange
"
},
fontdict
=
{
"
fontsize
"
:
6
,
"
color
"
:
"
orange
"
if
labels
is
None
else
"
white
"
},
y
=
0.2
if
isinstance
(
rating
,
str
)
else
0.35
,
)
...
...
@@ -235,6 +246,6 @@ def visualize_sort_multidim(
sort_idx
=
np
.
argsort
(
np
.
array
(
measure_norm
))
rating_strings
=
[
f
"
{
r
[
0
]
:
.
4
f
}
\n
{
r
[
1
]
:
.
4
f
}
\n
{
r
[
2
]
:
.
4
f
}
"
for
r
in
ratings
[
sort_idx
]]
fig
=
plot_samples
(
masks
.
numpy
()[
sort_idx
],
rating_strings
,
labels
)
fig
=
plot_samples
(
masks
.
numpy
()[
sort_idx
],
rating_strings
,
labels
[
sort_idx
]
)
fig
.
savefig
(
f
"
results/
{
n_dim
}
dim_
{
'
_
'
.
join
([
m
[
0
]
for
m
in
measures
])
}
_sort.png
"
)
plt
.
close
(
fig
)
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment