from bz2 import compress
from typing import Any, Callable, Optional

import numpy as np
import torch
from torch import Tensor
from torch import nn
from torchvision.utils import make_grid

from models import CONVVAE


def l2_distance_measure(img: Tensor, model: CONVVAE):
    """Distance between the latent means of an image and of its own reconstruction."""
    model.eval()
    with torch.no_grad():
        mask = img.to(model.device)
        recon, mean, _ = model(mask)
        # TODO: apply threshold here?!
        _, recon_mean, _ = model(recon)

        distance = torch.norm(mean - recon_mean, p=2)

    return (
        distance,
        make_grid(torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0),
    )


def compression_measure(img: Tensor, fill_ratio_norm=False):
    """Complexity as the bz2 compression ratio of the raw image bytes."""
    np_img = img[0].numpy()
    np_img_bytes = np_img.tobytes()
    compressed = compress(np_img_bytes)
    complexity = len(compressed) / len(np_img_bytes)

    if fill_ratio_norm:
        # Down-weight masks that fill most of the image.
        fill_ratio = np_img.sum() / np_img.size
        return complexity * (1 - fill_ratio), None

    return complexity, None


def fft_measure(img: Tensor):
    """Complexity as the amplitude-weighted mean spatial frequency of the image."""
    np_img = img[0][0].numpy()
    fft = np.fft.fft2(np_img)
    fft_abs = np.abs(fft)

    n = fft.shape[0]
    pos_f_idx = n // 2  # keep the positive-frequency quadrant only
    df = np.fft.fftfreq(n=n)

    amplitude_sum = fft_abs[:pos_f_idx, :pos_f_idx].sum()

    # Amplitude-weighted mean frequency along each axis.
    mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
    mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum

    mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2))

    # mean frequency lies in the range 0 to 0.5
    return mean_freq / 0.5, None


def pixelwise_complexity_measure(
    img: Tensor,
    model_gb: nn.Module,
    model_lb: nn.Module,
    fill_ratio_norm=False,
    return_mean_std=False,
):
    """Complexity as the absolute pixel difference between the reconstructions
    of the two VAEs (model_gb, model_lb), normalized by the mask fill."""
    model_gb.eval()
    model_lb.eval()
    with torch.no_grad():
        mask = img.to(model_gb.device)
        recon_gb, mu_gb, logvar_gb = model_gb(mask)
        recon_lb, mu_lb, logvar_lb = model_lb(mask)

        abs_px_diff = (recon_gb - recon_lb).abs().sum().item()
        complexity = abs_px_diff / mask.sum().item()

        if fill_ratio_norm:
            complexity *= mask.sum().item() / mask.numel()

    if return_mean_std:
        # Convert log-variance to standard deviation: sigma = exp(0.5 * logvar).
        return (
            complexity,
            (
                [mu_gb, mu_lb],
                [torch.exp(0.5 * logvar_gb), torch.exp(0.5 * logvar_lb)],
            ),
        )

    return (
        complexity,
        make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        ),
    )


def complexity_measure(
    img: Tensor,
    model_gb: CONVVAE,
    model_lb: CONVVAE,
    epsilon=0.4,
    fill_ratio_norm=False,
):
    """Complexity from the reconstruction precision of the two VAEs
    (model_gb, model_lb), with reconstructions binarized at epsilon."""
    model_gb.eval()
    model_lb.eval()
    with torch.no_grad():
        mask = img.to(model_gb.device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        recon_bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
        recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
        mask_bits = mask[0].cpu() > 0

        tp_gb = (mask_bits & recon_bits_gb).sum()
        fp_gb = (recon_bits_gb & ~mask_bits).sum()
        tp_lb = (mask_bits & recon_bits_lb).sum()
        fp_lb = (recon_bits_lb & ~mask_bits).sum()

        prec_gb = tp_gb / (tp_gb + fp_gb)
        prec_lb = tp_lb / (tp_lb + fp_lb)
        # An empty reconstruction yields 0/0 = NaN; treat its precision as 0.
        prec_gb = 0 if torch.isnan(prec_gb) else prec_gb
        prec_lb = 0 if torch.isnan(prec_lb) else prec_lb

        complexity = 1 - (prec_gb - abs(prec_gb - prec_lb))

        if fill_ratio_norm:
            fill_ratio = mask.sum().item() / mask.numel()
            complexity *= 1 - fill_ratio

    return (
        complexity,
        make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        ),
    )


def mean_precision(img: Tensor, models: list[CONVVAE], epsilon=0.4):
    """One minus the mean reconstruction precision over several models."""
    mask = img.to(models[0].device)
    mask_bits = mask[0].cpu() > 0

    precisions = np.zeros(len(models))

    with torch.no_grad():
        for i, model in enumerate(models):
            model.eval()
            recon, _, _ = model(mask)
            recon_bits = recon.view(-1, 64, 64).cpu() > epsilon

            tp = (mask_bits & recon_bits).sum()
            fp = (recon_bits & ~mask_bits).sum()

            precisions[i] = tp / (tp + fp)

    return 1 - precisions.mean(), None


def multidim_complexity(
    img: Tensor,
    measures: list[
        tuple[str, Callable[..., tuple[float, Optional[Tensor]]], tuple[Any, ...]]
    ],
):
    """Evaluate several (name, measure_fn, extra_args) triples on one image
    and collect the scalar ratings into a single tensor."""
    n_dim = len(measures)
    ratings = torch.zeros((n_dim,))

    for i, (_, fn, args) in enumerate(measures):
        rating, _ = fn(img, *args)
        ratings[i] = rating

    return ratings
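

# A minimal usage sketch, not part of the measure definitions above. It
# exercises only the model-free measures (compression_measure, fft_measure)
# plus their aggregation via multidim_complexity, since the VAE-based
# measures require trained CONVVAE checkpoints that are assumed to be
# provided elsewhere in the repo. The synthetic (1, 1, 64, 64) binary mask
# and the 0.5 threshold are arbitrary illustration choices.
if __name__ == "__main__":
    img = (torch.rand(1, 1, 64, 64) > 0.5).float()

    complexity, _ = compression_measure(img, fill_ratio_norm=True)
    print(f"bz2 compression complexity: {complexity:.4f}")

    freq, _ = fft_measure(img)
    print(f"normalized mean FFT frequency: {freq:.4f}")

    # Each entry is (name, callable, extra-args tuple); the args tuple is
    # splatted after the image, so (True,) sets fill_ratio_norm=True.
    ratings = multidim_complexity(
        img,
        measures=[
            ("compression", compression_measure, (True,)),
            ("fft", fft_measure, ()),
        ],
    )
    print("multidim ratings:", ratings)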