Commit 647cdb15 authored by Markus Rothgänger
wip

parent 7bf5c381
@@ -3,6 +3,7 @@ import os
 # from zlib import compress
 from bz2 import compress
+import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
 import numpy.typing as npt
@@ -18,7 +19,7 @@ from torchvision.transforms import transforms
 from torchvision.utils import make_grid, save_image

 device = torch.device("cuda")
-# matplotlib.use("Agg")
+matplotlib.use("Agg")

 dx = [+1, 0, -1, 0]
 dy = [0, +1, 0, -1]
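
Note on the backend switch above (general matplotlib behaviour, not specific to this commit): the non-interactive Agg backend lets the savefig() calls further down run headless, e.g. in a job without a display, and it should be selected before any figures are created. A minimal sketch with a hypothetical output path:

    import matplotlib
    matplotlib.use("Agg")          # must come before pyplot creates figures
    import matplotlib.pyplot as plt

    plt.plot([0, 1], [0, 1])
    plt.savefig("example.png")     # hypothetical path; no display needed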
@@ -463,29 +464,12 @@ def fft_measure(img: Tensor):
     total = np.sqrt(np.power(avg_x, 2) + np.power(avg_y, 2))
     df = np.fft.fftfreq(n=len(total))
+    pos_f_idx = len(total) // 2
     # mean frequency in range 0 to 0.5
-    mean_freq = (total * df)[: len(total) // 2].sum() / total[: len(total) // 2].sum()
+    mean_freq = (total * df)[:pos_f_idx].sum() / total[:pos_f_idx].sum()
     return mean_freq / 0.5
-    # magnitude = np.fft.fftshift(np.abs(fft))
-    # spectrum = np.log(1 + magnitude)
-    # flattened_spectrum = np.sort(magnitude.flatten())
-    # plt.plot(np.linspace(flattened_spectrum.min(), flattened_spectrum.max(), num=len(flattened_spectrum)), flattened_spectrum)
-    # frequencies = np.fft.fftfreq()
-    # plt.plot(radial_profile(magnitude))
-    # plt.show()
-    # M, N = np_img.shape
-    # total_freq_value = spectrum.sum()
-    # inner_sum = spectrum[M // 3 : 2 * (M // 3), N // 3 : 2 * (N // 3)].sum()
-    # return (total_freq_value - inner_sum) / total_freq_value

 def complexity_measure(
     model_gb: nn.Module,
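
Note on the pos_f_idx refactor above: np.fft.fftfreq returns the non-negative frequency bins first and the negative ones second, so slicing with len(total) // 2 keeps only the 0 to just-below-0.5 range that the magnitude-weighted mean is taken over; dividing by 0.5 then maps the measure to roughly [0, 1]. A minimal standalone sketch (not repo code):

    import numpy as np

    n = 8
    df = np.fft.fftfreq(n=n)   # [0., 0.125, 0.25, 0.375, -0.5, -0.375, -0.25, -0.125]
    pos_f_idx = n // 2         # 4
    print(df[:pos_f_idx])      # [0., 0.125, 0.25, 0.375] -> non-negative bins below 0.5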
@@ -555,6 +539,26 @@ def complexity_measure(
     )

+def mean_precision(models: list[nn.Module], img: Tensor, epsilon=0.4):
+    mask = img.to(device)
+    mask_bits = mask[0].cpu() > 0
+    precisions = np.zeros(len(models))
+
+    for i, model in enumerate(models):
+        recon_gb, _, _ = model(mask)
+        recon_bits = recon_gb.view(-1, 64, 64).cpu() > epsilon
+
+        tp = (mask_bits & recon_bits).sum()
+        fp = (recon_bits & ~mask_bits).sum()
+        prec = tp / (tp + fp)
+        precisions[i] = prec
+
+    return 1 - precisions.mean()
+
 def complexity_measure_diff(
     model_gb: nn.Module,
     model_lb: nn.Module,
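
The new mean_precision above scores a mask by how precisely the given models reconstruct it and returns 1 - mean precision, so harder-to-reconstruct shapes get higher values. A toy example of the per-model term TP / (TP + FP) on thresholded boolean masks, with hypothetical 4x4 tensors standing in for the 64x64 masks:

    import torch

    target = torch.zeros(4, 4, dtype=torch.bool)
    target[1:3, 1:3] = True              # 4 foreground pixels
    recon = torch.zeros(4, 4, dtype=torch.bool)
    recon[1:3, 1:4] = True               # 6 predicted pixels, 4 of them correct

    tp = (target & recon).sum()          # 4
    fp = (recon & ~target).sum()         # 2
    precision = tp / (tp + fp)           # ~0.667, so 1 - precision is ~0.333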
@@ -667,12 +671,31 @@ def visualize_sort_fft(data_loader: DataLoader):
     plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
     plt.xlabel("images")
-    plt.ylabel("compression length")
+    plt.ylabel("mean unidirectional frequency")
     plt.savefig("shape_complexity/results/fft_plot.png")

     return plot_samples(masks_sorted, distances.numpy()[sort_idx])

+def visualize_sort_mean_precision(models: list[nn.Module], data_loader: DataLoader):
+    masks = torch.zeros((400, 1, 64, 64))
+    precisions = torch.zeros((400,))
+    for i, (mask, _) in enumerate(data_loader, 0):
+        masks[i] = mask[0]
+        precisions[i] = mean_precision(models, mask)
+
+    sort_idx = torch.argsort(precisions)
+    masks_sorted = masks.numpy()[sort_idx]
+
+    plt.plot(np.arange(len(precisions)), np.sort(precisions.numpy()))
+    plt.xlabel("images")
+    plt.ylabel("mean precision")
+    plt.savefig("shape_complexity/results/mean_prec_plot.png")
+
+    return plot_samples(masks_sorted, precisions.numpy()[sort_idx])
+
 def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
     masks_recon = torch.zeros((400, 3, 64, 192))
     masks = torch.zeros((400, 1, 64, 64))
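
One assumption worth flagging in the new visualize_sort_mean_precision: it fills fixed 400-entry buffers and takes mask[0] from each batch, which implies a data_loader with batch_size=1 over exactly 400 single-channel 64x64 masks (more batches would raise an IndexError on the buffers). A hypothetical setup it would work with:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    masks = (torch.rand(400, 1, 64, 64) > 0.5).float()   # stand-in binary shapes
    labels = torch.zeros(400, dtype=torch.long)
    data_loader = DataLoader(TensorDataset(masks, labels), batch_size=1, shuffle=False)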
@@ -725,6 +748,7 @@ def visualize_sort_3dim(
     ax.set_ylabel("FFT ratio")
     ax.set_zlabel(f"VAE ratio {model_gb.bottleneck}/{model_lb.bottleneck}")
     plt.savefig("shape_complexity/results/3d_plot.png")
+    plt.clf()

     sort_idx = np.argsort(np.array(measure_norm))
     recon_masks_sorted = masks_recon.numpy()[sort_idx]
@@ -757,6 +781,7 @@ def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
     plt.xlabel("images")
     plt.ylabel("prec difference (L-H)")
     plt.savefig("shape_complexity/results/diff_plot.png")
+    plt.clf()

     return plot_samples(masks_sorted, complexities[sort_idx])
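
The plt.clf() calls added in the two hunks above (and the plt.close(fig) calls added in main() below) address the usual pyplot pitfall: plt.clf() clears the shared current figure so successive plots do not draw on top of each other, while plt.close(fig) releases finished figures so they do not accumulate in memory. A minimal sketch of the overdraw failure mode, assuming the Agg backend and hypothetical output paths:

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np

    for k in (1, 2):
        plt.plot(np.arange(10) * k)
        plt.savefig(f"curve_{k}.png")   # hypothetical paths
        plt.clf()                       # without this, curve_2.png would also contain curve 1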
@@ -993,8 +1018,14 @@ def main():
     fig, fig_recon = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt])
     fig.savefig(f"shape_complexity/results/sort_comp_fft_prec.png")
     fig_recon.savefig(f"shape_complexity/results/recon_sort_comp_fft_prec.png")
+    plt.close(fig)
+    plt.close(fig_recon)
+    fig = visualize_sort_mean_precision(list(models.values()), data_loader)
+    fig.savefig(f"shape_complexity/results/sort_mean_prec.png")
+    plt.close(fig)
     fig = visualize_sort_fft(data_loader)
     fig.savefig(f"shape_complexity/results/sort_fft.png")
+    plt.close(fig)
     fig = visualize_sort_compression(data_loader)
     fig.savefig(f"shape_complexity/results/sort_compression.png")
     # fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
...