Skip to content
Snippets Groups Projects
Commit 7d1a218e authored by Markus Rothgänger's avatar Markus Rothgänger
Browse files

wip

parent 406133a8
No related branches found
No related tags found
No related merge requests found
......@@ -486,7 +486,7 @@ def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
return prec, rec, comp_data
def distance_measure(img: Tensor, model: VAE):
def l2_distance_measure(img: Tensor, model: VAE):
model.eval()
with torch.no_grad():
......@@ -503,7 +503,7 @@ def distance_measure(img: Tensor, model: VAE):
)
def compression_complexity(img: Tensor, fill_ratio_norm=False):
def compression_measure(img: Tensor, fill_ratio_norm=False):
np_img = img[0].numpy()
compressed = compress(np_img)
......@@ -534,6 +534,46 @@ def fft_measure(img: Tensor):
return mean_freq / 0.5, None
def pixelwise_complexity_measure(
img: Tensor,
model_gb: nn.Module,
model_lb: nn.Module,
fill_ratio_norm=False,
):
model_gb.eval()
model_lb.eval()
with torch.no_grad():
mask = img.to(device)
recon_gb: Tensor
recon_lb: Tensor
recon_gb, _, _ = model_gb(mask)
recon_lb, _, _ = model_lb(mask)
max_px_fill = torch.ones_like(mask).sum().item()
abs_px_diff = (recon_gb - recon_lb).abs().sum().item()
complexity = abs_px_diff / max_px_fill
# this equals complexity = (1 - fill_rate) * diff_px / max_px
if fill_ratio_norm:
complexity -= abs_px_diff * mask.sum().item() / np.power(max_px_fill, 2)
# complexity *= mask.sum().item() / max_px_fill
return (
complexity,
make_grid(
torch.stack(
[mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
).cpu(),
nrow=3,
padding=0,
),
)
def complexity_measure(
img: Tensor,
model_gb: nn.Module,
......@@ -712,7 +752,7 @@ def visualize_sort_3dim(
masks = torch.zeros((400, 1, 64, 64))
measures = torch.zeros((400, 3))
for i, (mask, _) in enumerate(data_loader, 0):
c_compress, _ = compression_complexity(mask, fill_ratio_norm=True)
c_compress, _ = compression_measure(mask, fill_ratio_norm=True)
c_fft, _ = fft_measure(mask)
# TODO: maybe exchange by diff or mean measure instead of precision
c_vae, mask_recon_grid = complexity_measure(
......@@ -750,8 +790,8 @@ LR = 1.5e-3
EPOCHS = 100
LOAD_PRETRAINED = True
# TODO: try out pixelwise loss again (in 3d as well)
# TODO: might be a good idea to implement a bbox cut preprocessing transform thingy
# TODO: try 2dim rep (fft, comp) (fft, vae) (comp, vae)
def main():
......@@ -807,7 +847,7 @@ def main():
fig.savefig(f"shape_complexity/results/sort_comp_fft_prec.png")
fig_recon.savefig(f"shape_complexity/results/recon_sort_comp_fft_prec.png")
plt.close(fig)
plt.close(fig_recon)
plt.close()
visualize_sort(
data_loader,
......@@ -817,17 +857,27 @@ def main():
models[bn_lt],
fill_ratio_norm=True,
)
visualize_sort(data_loader, mean_precision, "mean_precision", list(models.values()))
visualize_sort(data_loader, fft_measure, "fft")
visualize_sort(data_loader, compression_complexity, "compression")
visualize_sort(
data_loader,
compression_complexity,
"compression_fill_norm",
fill_ratio_norm=True,
)
visualize_sort(data_loader, distance_measure, "latent_l2_distance", models[bn_gt])
# visualize_sort(
# data_loader,
# pixelwise_complexity_measure,
# "px_recon_complexity",
# models[bn_gt],
# models[bn_lt],
# fill_ratio_norm=True,
# )
# visualize_sort(data_loader, mean_precision, "mean_precision", list(models.values()))
# visualize_sort(data_loader, fft_measure, "fft")
# visualize_sort(data_loader, compression_measure, "compression")
# visualize_sort(
# data_loader,
# compression_measure,
# "compression_fill_norm",
# fill_ratio_norm=True,
# )
# visualize_sort(
# data_loader, l2_distance_measure, "latent_l2_distance", models[bn_gt]
# )
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment