from bz2 import compress
from typing import Any, Callable, Optional

import numpy as np
import torch
from torch import Tensor, nn
from torchvision.utils import make_grid

from models import CONVVAE


def l2_distance_measure(img: Tensor, model: CONVVAE):
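    """Rate ``img`` by the L2 distance between its latent mean and the
    latent mean of its own reconstruction (re-encoding consistency)."""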
    model.eval()

    with torch.no_grad():
        mask = img.to(model.device)
        recon, mean, _ = model(mask)
        # TODO: apply threshold here?!
        _, recon_mean, _ = model(recon)

        distance = torch.norm(mean - recon_mean, p=2)

        return (
            distance,
            make_grid(torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0),
        )


def compression_measure(img: Tensor, fill_ratio_norm=False):
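    """Rate ``img`` by its bz2 compression ratio (compressed size over raw
    size); incompressible images are considered more complex."""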
    np_img = img[0].numpy()
    np_img_bytes = np_img.tobytes()
    compressed = compress(np_img_bytes)

    complexity = len(compressed) / len(np_img_bytes)

    if fill_ratio_norm:
        fill_ratio = np_img.sum() / np_img.size
        return complexity * (1 - fill_ratio), None

    return complexity, None


def fft_measure(img: Tensor):
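    """Rate ``img`` by the amplitude-weighted mean spatial frequency of its
    2D Fourier spectrum."""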
    np_img = img[0][0].numpy()
    fft = np.fft.fft2(np_img)

    fft_abs = np.abs(fft)

    # assumes a square image; keep only the positive-frequency quadrant
    n = fft.shape[0]
    pos_f_idx = n // 2
    df = np.fft.fftfreq(n=n)

    # amplitude-weighted mean frequency along each axis
    amplitude_sum = fft_abs[:pos_f_idx, :pos_f_idx].sum()
    mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
    mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum

    mean_freq = np.sqrt(mean_x_freq**2 + mean_y_freq**2)

    # positive fftfreq values lie in [0, 0.5), so this maps the result to [0, 1)
    return mean_freq / 0.5, None


def pixelwise_complexity_measure(
    img: Tensor,
    model_gb: nn.Module,
    model_lb: nn.Module,
    fill_ratio_norm=False,
    return_mean_std=False,
):
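    """Rate ``img`` by how strongly the reconstructions of two VAEs
    (``model_gb`` and ``model_lb``) disagree, per active mask pixel."""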
    model_gb.eval()
    model_lb.eval()

    with torch.no_grad():
        mask = img.to(model_gb.device)

        recon_gb: Tensor
        recon_lb: Tensor

        recon_gb, mu_gb, logvar_gb = model_gb(mask)
        recon_lb, mu_lb, logvar_lb = model_lb(mask)

        abs_px_diff = (recon_gb - recon_lb).abs().sum().item()

        # normalise by the number of active mask pixels (not the full image
        # area) so that sparse masks are not trivially rated as simple
        complexity = abs_px_diff / mask.sum().item()

        if fill_ratio_norm:
            complexity *= mask.sum().item() / mask.numel()

        if return_mean_std:
            return (
                complexity,
                (
                    [mu_gb, mu_lb],
                    [torch.exp(0.5 * logvar_gb), torch.exp(0.5 * logvar_lb)],
                ),
            )

        return (
            complexity,
            make_grid(
                torch.stack(
                    [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
                ).cpu(),
                nrow=3,
                padding=0,
            ),
        )


def complexity_measure(
    img: Tensor,
    model_gb: CONVVAE,
    model_lb: CONVVAE,
    epsilon=0.4,
    fill_ratio_norm=False,
):
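    """Rate ``img`` via the binarised reconstruction precision of two VAEs:
    ``1 - (prec_gb - |prec_gb - prec_lb|)``, optionally scaled down for
    densely filled masks."""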
    model_gb.eval()
    model_lb.eval()

    with torch.no_grad():
        mask = img.to(model_gb.device)

        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        recon_bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
        recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
        mask_bits = mask[0].cpu() > 0

        tp_gb = (mask_bits & recon_bits_gb).sum()
        fp_gb = (recon_bits_gb & ~mask_bits).sum()
        tp_lb = (mask_bits & recon_bits_lb).sum()
        fp_lb = (recon_bits_lb & ~mask_bits).sum()

        prec_gb = (tp_gb / (tp_gb + fp_gb)).item()
        prec_lb = (tp_lb / (tp_lb + fp_lb)).item()
        # an empty reconstruction gives 0/0 = NaN; count it as zero precision
        prec_gb = 0.0 if np.isnan(prec_gb) else prec_gb
        prec_lb = 0.0 if np.isnan(prec_lb) else prec_lb

        complexity = 1 - (prec_gb - abs(prec_gb - prec_lb))

        if fill_ratio_norm:
            fill_ratio = mask.sum().item() / mask.numel()
            complexity *= 1 - fill_ratio

        return (
            complexity,
            make_grid(
                torch.stack(
                    [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
                ).cpu(),
                nrow=3,
                padding=0,
            ),
        )


def mean_precision(img: Tensor, models: list[CONVVAE], epsilon=0.4):
    """Mean binarised reconstruction precision over an ensemble of VAEs,
    returned as ``1 - mean precision`` so that harder samples score higher."""
    mask = img.to(models[0].device)
    mask_bits = mask[0].cpu() > 0

    precisions = np.zeros(len(models))

    with torch.no_grad():
        for i, model in enumerate(models):
            model.eval()
            recon, _, _ = model(mask)

            recon_bits = recon.view(-1, 64, 64).cpu() > epsilon

            tp = (mask_bits & recon_bits).sum()
            fp = (recon_bits & ~mask_bits).sum()

            # guard against 0/0 = NaN, as in complexity_measure
            prec = tp / (tp + fp)
            precisions[i] = 0.0 if torch.isnan(prec) else prec.item()

    return 1 - precisions.mean(), None


def multidim_complexity(
    img: Tensor,
    measures: list[
        tuple[str, Callable[..., tuple[float, Optional[Tensor]]], tuple[Any, ...]]
    ],
):
    """Evaluate several measures on ``img`` and stack the scalar ratings
    into one vector; each entry of ``measures`` is ``(name, fn, extra_args)``."""
    n_dim = len(measures)
    ratings = torch.zeros((n_dim,))

    for i, (_, fn, args) in enumerate(measures):
        rating, _ = fn(img, *args)
        ratings[i] = rating

    return ratings
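

# Minimal usage sketch. Assumptions: the CONVVAE constructor signature and
# the ``bottleneck`` argument below are illustrative only, and inputs are
# single 64x64 masks batched to shape (1, 1, 64, 64) as used above.
if __name__ == "__main__":
    model_gb = CONVVAE(bottleneck=32)  # hypothetical high-capacity model
    model_lb = CONVVAE(bottleneck=2)  # hypothetical low-capacity model

    img = torch.rand(1, 1, 64, 64)

    measures = [
        ("fft", fft_measure, ()),
        ("compression", compression_measure, (True,)),
        ("pixelwise", pixelwise_complexity_measure, (model_gb, model_lb)),
    ]
    print(multidim_complexity(img, measures))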