import os

# from zlib import compress
from bz2 import compress

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
import torch.nn.functional as F
from kornia.morphology import closing
from PIL import Image
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.utils import make_grid, save_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
matplotlib.use("Agg")

dx = [+1, 0, -1, 0]
dy = [0, +1, 0, -1]


# perform a depth-first search (4-connectivity) from each unlabeled foreground
# pixel; note: the recursion can hit Python's recursion limit on very large
# regions, in which case an explicit stack would be needed
# reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation
def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int):
    n_rows, n_cols = mask.shape
    if x < 0 or x == n_rows:
        return
    if y < 0 or y == n_cols:
        return
    if labels[x][y] or not mask[x][y]:
        return  # already labeled or not marked with 1 in image

    # mark the current cell
    labels[x][y] = current_label

    # recursively mark the neighbors
    for direction in range(4):
        dfs(mask, x + dx[direction], y + dy[direction], labels, current_label)


def find_components(mask: npt.NDArray):
    label = 0

    n_rows, n_cols = mask.shape
    labels = np.zeros(mask.shape, dtype=np.int32)  # int8 would overflow past 127 components

    for i in range(n_rows):
        for j in range(n_cols):
            if not labels[i][j] and mask[i][j]:
                label += 1
                dfs(mask, i, j, labels, label)

    return labels
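

# A minimal sketch of how find_components labels a mask (hypothetical toy
# input, not part of the pipeline): the two blobs below are not 4-connected,
# so they receive distinct labels 1 and 2.
def _demo_find_components():
    toy_mask = np.array([[1, 1, 0], [0, 0, 0], [0, 0, 1]], dtype=np.int8)
    labels = find_components(toy_mask)
    # expected labeling: [[1, 1, 0], [0, 0, 0], [0, 0, 2]]
    assert labels.max() == 2
    return labels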


# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
def bbox(img):
    max_x, max_y = img.shape
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

    # pad the box by one pixel on each side where the image bounds allow it
    rmin = rmin - 1 if rmin > 0 else rmin
    cmin = cmin - 1 if cmin > 0 else cmin
    rmax = rmax + 1 if rmax < max_x - 1 else rmax
    cmax = cmax + 1 if cmax < max_y - 1 else cmax

    return rmin, rmax, cmin, cmax


def extract_single_masks(labels: npt.NDArray):
    masks = []
    # label 0 is the background, so only crop the foreground components
    for label in range(1, labels.max() + 1):
        mask = (labels == label).astype(np.int8)
        rmin, rmax, cmin, cmax = bbox(mask)
        masks.append(mask[rmin : rmax + 1, cmin : cmax + 1])

    return masks
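

# Hypothetical end-to-end example of the cropping path: label a toy mask, then
# cut the single component to its one-pixel-padded bounding box, which turns a
# 2x2 blob inside a 5x5 grid into a 4x4 crop.
def _demo_extract_single_masks():
    toy_mask = np.zeros((5, 5), dtype=np.int8)
    toy_mask[2:4, 2:4] = 1
    crops = extract_single_masks(find_components(toy_mask))
    assert len(crops) == 1 and crops[0].shape == (4, 4)
    return crops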


class VAE(nn.Module):
    """
    https://github.com/pytorch/examples/blob/main/vae/main.py
    """

    def __init__(self, bottleneck=2, image_dim=4096):
        super().__init__()

        self.bottleneck = bottleneck
        self.image_dim = image_dim

        self.prelim_encode = nn.Sequential(
            nn.Flatten(), nn.Linear(image_dim, 400), nn.ReLU()
        )
        self.encode_mu = nn.Sequential(nn.Linear(400, bottleneck))
        self.encode_logvar = nn.Sequential(nn.Linear(400, bottleneck))

        self.decode = nn.Sequential(
            nn.Linear(bottleneck, 400),
            nn.ReLU(),
            nn.Linear(400, image_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        # h1 = F.relu(self.encode(x))
        # return self.encode_mu(h1), self.encode_logvar(h1)
        x = self.prelim_encode(x)
        return self.encode_mu(x), self.encode_logvar(x)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss(self, recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(
            recon_x, x.view(-1, self.image_dim), reduction="sum"
        )

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD
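

# A smoke test for the fully-connected VAE on random data (hypothetical input,
# not the project's masks): one forward pass and one loss evaluation with
# 64x64 images flattened to image_dim = 4096.
def _demo_vae_forward():
    model = VAE(bottleneck=2, image_dim=4096)
    x = torch.rand(8, 1, 64, 64)  # values in [0, 1] so BCE is well-defined
    recon, mu, logvar = model(x)
    assert recon.shape == (8, 4096) and mu.shape == (8, 2)
    return model.loss(recon, x, mu, logvar)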


class CONVVAE(nn.Module):
    def __init__(self, bottleneck=2):
        super().__init__()

        self.bottleneck = bottleneck
        self.feature_dim = 6 * 6 * 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 30x30x16
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 14x14x32
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 6x6x64
        )
        # self.conv4 = nn.Sequential(
        #     nn.Conv2d(32, self.bottleneck, 5),
        #     nn.ReLU(),
        #     nn.MaxPool2d((2, 2), return_indices=True),  # -> 1x1xbottleneck
        # )

        self.encode_mu = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.feature_dim, self.bottleneck),
        )
        self.encode_logvar = nn.Sequential(
            nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
        )

        self.decode_linear = nn.Linear(self.bottleneck, self.feature_dim)

        # self.decode4 = nn.Sequential(
        #     nn.ConvTranspose2d(self.bottleneck, 32, 5),
        #     nn.ReLU(),
        # )
        self.decode3 = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 2, stride=2),  # -> 12x12x64
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3),  # -> 14x14x32
            nn.ReLU(),
        )
        self.decode2 = nn.Sequential(
            nn.ConvTranspose2d(32, 32, 2, stride=2),  # -> 28x28x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3),  # -> 30x30x16
            nn.ReLU(),
        )
        self.decode1 = nn.Sequential(
            nn.ConvTranspose2d(16, 16, 2, stride=2),  # -> 60x60x16
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 5),  # -> 64x64x1
            nn.Sigmoid(),
        )

    def encode(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # x, idx4 = self.conv4(x)
        mu = self.encode_mu(x)
        logvar = self.encode_logvar(x)

        return mu, logvar

    def decode(self, z: Tensor):
        z = self.decode_linear(z)
        z = z.view((-1, 64, 6, 6))
        # z = F.max_unpool2d(z, idx4, (2, 2))
        # z = self.decode4(z)
        z = self.decode3(z)
        z = self.decode2(z)
        z = self.decode1(z)
        # z = z.view(-1, 128, 1, 1)
        # return self.decode_conv(z)
        return z

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def loss(self, recon_x, x, mu, logvar):
        """https://github.com/pytorch/examples/blob/main/vae/main.py"""
        BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD
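

# Sketch verifying the feature-map sizes annotated in CONVVAE above
# (hypothetical random batch): 64x64 -> 30x30x16 -> 14x14x32 -> 6x6x64 through
# the encoder, mirrored back to 64x64 by the decoder.
def _demo_convvae_shapes():
    model = CONVVAE(bottleneck=4)
    x = torch.rand(2, 1, 64, 64)
    features = model.conv3(model.conv2(model.conv1(x)))
    assert features.shape == (2, 64, 6, 6)
    recon, mu, _ = model(x)
    assert recon.shape == (2, 1, 64, 64) and mu.shape == (2, 4)
    return recon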


class CloseTransform:
    kernel = torch.ones(5, 5)

    def __init__(self, kernel=None):
        if kernel is not None:
            self.kernel = kernel

    def __call__(self, x):
        # to_tensor yields (C, H, W); kornia's closing expects a (B, C, H, W) batch
        x = transforms.F.to_tensor(x)

        if len(x.shape) < 4:
            return transforms.F.to_pil_image(
                closing(x.unsqueeze(dim=0), self.kernel).squeeze(dim=0)
            )

        return transforms.F.to_pil_image(closing(x, self.kernel))
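

# Hypothetical usage of CloseTransform: morphological closing fills small holes
# in a mask before resizing/augmentation. Input and output are PIL images,
# matching its position in the transforms.Compose pipeline below.
def _demo_close_transform():
    noisy = np.full((32, 32), 255, dtype=np.uint8)
    noisy[15, 15] = 0  # a one-pixel hole that the 5x5 closing should fill
    closed = CloseTransform()(Image.fromarray(noisy))
    assert np.asarray(closed)[15, 15] == 255
    return closed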


def load_data():
    transform = transforms.Compose(
        [
            transforms.Grayscale(),
            transforms.RandomApply([CloseTransform()], p=0.25),
            transforms.Resize(
                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    )

    trajectories = [
        # "v3_subtle_iceberg_lettuce_nymph-6_203-2056",
        "v3_absolute_grape_changeling-16_2277-4441",
        "v3_content_squash_angel-3_16074-17640",
        "v3_smooth_kale_loch_ness_monster-1_4439-6272",
        "v3_cute_breadfruit_spirit-6_17090-19102",
        "v3_key_nectarine_spirit-2_7081-9747",
        "v3_subtle_iceberg_lettuce_nymph-6_3819-6049",
        "v3_juvenile_apple_angel-30_396415-398113",
        "v3_subtle_iceberg_lettuce_nymph-6_6100-8068",
    ]

    datasets = []
    for trj in trajectories:
        datasets.append(
            ImageFolder(
                f"activation_vis/out/critic/masks/{trj}/0/4", transform=transform
            )
        )

    datasets.append(
        ImageFolder("shape_complexity/data/simple_shapes", transform=transform)
    )

    dataset = torch.utils.data.ConcatDataset(datasets)

    data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
    return data_loader, dataset


def train(epoch, model: "VAE | CONVVAE", optimizer, data_loader, log_interval=40):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(data_loader):
        data = data.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(data_loader.dataset),
                    100.0 * batch_idx / len(data_loader),
                    loss.item() / len(data),
                )
            )

    print(
        "====> Epoch: {} Average loss: {:.4f}".format(
            epoch, train_loss / len(data_loader.dataset)
        )
    )
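

# Hypothetical smoke test for the train loop on random tensors (stands in for
# the ImageFolder datasets, which require mask files on disk): one epoch over
# a tiny synthetic dataset of 64x64 "masks".
def _demo_train_smoke():
    from torch.utils.data import TensorDataset

    model = CONVVAE(bottleneck=2).to(device)
    optimizer = Adam(model.parameters(), lr=1e-3)
    data = torch.rand(16, 1, 64, 64)
    dummy_targets = torch.zeros(16, dtype=torch.long)  # stand-in class labels
    loader = DataLoader(TensorDataset(data, dummy_targets), batch_size=8)
    train(0, model=model, optimizer=optimizer, data_loader=loader)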


def test(epoch, models: "list[VAE | CONVVAE]", dataset, save_results=False):
    for model in models:
        model.eval()
    test_loss = [0 for _ in models]

    test_batch_size = 32
    sampler = RandomSampler(dataset, replacement=True, num_samples=64)
    test_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=sampler)
    comp_data = None

    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)

            for j, model in enumerate(models):
                recon_batch, mu, logvar = model(data)
                test_loss[j] += model.loss(recon_batch, data, mu, logvar).item()

                if i == 0:
                    n = min(data.size(0), 20)
                    if comp_data is None:
                        comp_data = data[:n]
                    comp_data = torch.cat(
                        [comp_data, recon_batch.view(-1, 1, 64, 64)[:n]]
                    )

            if i == 0 and save_results:
                if not os.path.exists("results"):
                    os.makedirs("results")
                save_image(
                    comp_data.cpu(),
                    "results/reconstruction_" + str(epoch) + ".png",
                    nrow=min(data.size(0), 20),
                )

    for i, model in enumerate(models):
        # average over the number of sampled test items, not the full dataset
        test_loss[i] /= len(test_loader.sampler)
        print(f"====> Test set loss model {model.bottleneck}: {test_loss[i]:.4f}")

    if save_results:
        plt.figure()
        bns = [m.bottleneck for m in models]
        plt.plot(bns, test_loss)
        plt.xticks(bns)
        plt.savefig("shape_complexity/results/vae_graph.png")
        plt.clf()


def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
    model.eval()
    image = transforms.F.to_tensor(transforms.F.to_grayscale(Image.open(path)))
    labels = find_components(image[0].numpy())
    single_masks = extract_single_masks(labels)
    mask = transforms.F.to_tensor(
        transforms.F.resize(
            transforms.F.to_pil_image((single_masks[label] * 255).astype(np.uint8)),
            (64, 64),
        )
    )

    with torch.no_grad():
        mask = mask.to(device)
        recon_x, _, _ = model(mask)
        recon_bits = recon_x.view(64, 64).cpu().numpy() > epsilon
        mask_bits = mask.cpu().numpy() > 0

        TP = (mask_bits & recon_bits).sum()
        FP = (recon_bits & ~mask_bits).sum()
        FN = (mask_bits & ~recon_bits).sum()

        prec = TP / (TP + FP)
        rec = TP / (TP + FN)
        # loss = pixelwise_loss(recon_x, mask)
        comp_data = torch.cat(
            [
                mask[0].cpu(),
                recon_x.view(64, 64).cpu(),
                torch.from_numpy(recon_bits).float(),  # cast: cat requires matching dtypes
            ]
        )
        # print(f"mask loss: {loss:.4f}")

        return prec, rec, comp_data
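

# Toy illustration (hypothetical arrays) of the bitwise precision/recall used
# above: one true positive, one spurious pixel, one missed pixel.
def _demo_bitwise_prec_rec():
    mask_bits = np.array([[1, 0], [1, 0]], dtype=bool)
    recon_bits = np.array([[1, 1], [0, 0]], dtype=bool)
    tp = (mask_bits & recon_bits).sum()   # 1: on in both
    fp = (recon_bits & ~mask_bits).sum()  # 1: on only in the reconstruction
    fn = (mask_bits & ~recon_bits).sum()  # 1: on only in the mask
    return tp / (tp + fp), tp / (tp + fn)  # precision 0.5, recall 0.5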


def distance_measure(model: nn.Module, img: Tensor):
    model.eval()

    with torch.no_grad():
        mask = img.to(device)
        recon, mean, _ = model(mask)
        # TODO: apply threshold here?!
        _, recon_mean, _ = model(recon)

        distance = torch.norm(mean - recon_mean, p=2)

        return (
            distance,
            make_grid(torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0),
        )
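

# Sketch of the latent round-trip distance with an untrained model
# (hypothetical input; a trained checkpoint would give meaningful distances):
# encode, decode, re-encode the reconstruction, and compare latent means.
def _demo_distance_measure():
    model = CONVVAE(bottleneck=4).to(device)
    distance, grid = distance_measure(model, torch.rand(1, 1, 64, 64))
    assert distance.item() >= 0 and grid.shape == (3, 64, 128)
    return distance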


def compression_complexity(img: Tensor):
    # proxy for shape complexity: length of the bz2-compressed raw mask bytes
    np_img = img[0].numpy()
    compressed = compress(np_img.tobytes())

    return len(compressed)
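

# Hypothetical sanity check for the compression proxy: a constant mask
# compresses far better than random noise, so it scores a lower complexity.
def _demo_compression_complexity():
    flat = torch.zeros(1, 1, 64, 64)
    noise = torch.rand(1, 1, 64, 64)
    assert compression_complexity(flat) < compression_complexity(noise)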


def fft_measure(img: Tensor):
    np_img = img[0][0].numpy()
    fft = np.fft.fft2(np_img)

    fft_abs = np.abs(fft)

    n = fft.shape[0]
    pos_f_idx = n // 2
    df = np.fft.fftfreq(n=n)

    amplitude_sum = fft_abs[:pos_f_idx, :pos_f_idx].sum()
    mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
    mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum

    mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2))

    # per-axis mean frequencies lie in [0, 0.5] cycles/pixel (the Nyquist
    # limit), so dividing by 0.5 roughly normalizes the measure to [0, 1]
    return mean_freq / 0.5
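

# Hypothetical contrast for fft_measure: a large centered square concentrates
# spectral energy at low frequencies, while a small one spreads energy to
# higher frequencies and therefore scores a higher mean frequency.
def _demo_fft_measure():
    big = torch.zeros(1, 1, 64, 64)
    small = torch.zeros(1, 1, 64, 64)
    big[:, :, 16:48, 16:48] = 1.0  # 32x32 square
    small[:, :, 30:34, 30:34] = 1.0  # 4x4 square
    assert fft_measure(big) < fft_measure(small)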


def complexity_measure(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
    epsilon=0.4,
    save_preliminary=False,
):
    model_gb.eval()
    model_lb.eval()

    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        recon_bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
        recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
        mask_bits = mask[0].cpu() > 0

        if save_preliminary:
            save_image(
                torch.stack(
                    [mask_bits.float(), recon_bits_gb.float(), recon_bits_lb.float()]
                ).cpu(),
                f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )
            save_image(
                torch.stack(
                    [
                        (mask_bits & recon_bits_gb).float(),
                        (recon_bits_gb & ~mask_bits).float(),
                        (mask_bits & recon_bits_lb).float(),
                        (recon_bits_lb & ~mask_bits).float(),
                    ]
                ).cpu(),
                f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )

        tp_gb = (mask_bits & recon_bits_gb).sum()
        fp_gb = (recon_bits_gb & ~mask_bits).sum()
        tp_lb = (mask_bits & recon_bits_lb).sum()
        fp_lb = (recon_bits_lb & ~mask_bits).sum()

        prec_gb = tp_gb / (tp_gb + fp_gb)
        prec_lb = tp_lb / (tp_lb + fp_lb)
        complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
        complexity_lb = 1 - prec_lb
        complexity_gb = 1 - prec_gb
        # example: prec_gb = 0.4, prec_lb = 0.7 ->
        #   combined:      1 - (0.4 - abs(0.4 - 0.7)) = 0.9
        #   lb complexity: 1 - 0.7                    = 0.3

        return (
            complexity,
            complexity_lb,
            complexity_gb,
            prec_gb - prec_lb,
            prec_lb,
            prec_gb,
            make_grid(
                torch.stack(
                    [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
                ).cpu(),
                nrow=3,
                padding=0,
            ),
        )
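

# Worked example of the combined score with hypothetical precisions
# prec_gb = 0.4 and prec_lb = 0.7: the disagreement penalty dominates, so the
# combined complexity saturates well above either single-model score.
def _demo_complexity_score(prec_gb=0.4, prec_lb=0.7):
    complexity = 1 - (prec_gb - abs(prec_gb - prec_lb))  # 0.9
    complexity_lb = 1 - prec_lb  # 0.3
    complexity_gb = 1 - prec_gb  # 0.6
    return complexity, complexity_lb, complexity_gb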


def mean_precision(models: list[nn.Module], img: Tensor, epsilon=0.4):
    mask = img.to(device)
    mask_bits = mask[0].cpu() > 0

    precisions = np.zeros(len(models))

    # inference only: switch the models to eval mode and disable autograd
    with torch.no_grad():
        for i, model in enumerate(models):
            model.eval()
            recon, _, _ = model(mask)

            recon_bits = recon.view(-1, 64, 64).cpu() > epsilon

            tp = (mask_bits & recon_bits).sum()
            fp = (recon_bits & ~mask_bits).sum()

            prec = tp / (tp + fp)
            precisions[i] = prec

    return 1 - precisions.mean()


def complexity_measure_diff(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
):
    model_gb.eval()
    model_lb.eval()

    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        # sum of absolute pixelwise differences (abs before sum, so opposite
        # deviations cannot cancel out)
        diff = torch.abs(recon_gb - recon_lb).sum().cpu()

        return (
            diff,
            make_grid(
                torch.stack(
                    [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
                ).cpu(),
                nrow=3,
                padding=0,
            ),
        )


def plot_samples(masks: Tensor, complexities: npt.NDArray):
    dpi = 150
    rows = cols = 20
    total = rows * cols
    n_samples, _, y, x = masks.shape

    extent = (0, x - 1, 0, y - 1)

    if total != n_samples:
        raise ValueError(
            f"expected {total} samples for a {rows}x{cols} grid, got {n_samples}"
        )

    fig = plt.figure(figsize=(32, 16), dpi=dpi)
    for idx in np.arange(n_samples):
        ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])

        plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
        ax.set_title(
            f"{complexities[idx]:.4f}",
            fontdict={"fontsize": 6, "color": "orange"},
            y=0.35,
        )

    fig.patch.set_facecolor("#292929")
    height_px = y * rows
    width_px = x * cols
    fig.set_size_inches(width_px / (dpi / 2), height_px / (dpi / 2), forward=True)
    fig.tight_layout(pad=0)

    return fig


def visualize_sort_mean(data_loader: DataLoader, model: nn.Module):
    recon_masks = torch.zeros((400, 3, 64, 128))
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        distance, mask_recon_grid = distance_measure(model, mask)
        masks[i] = mask[0]
        recon_masks[i] = mask_recon_grid
        distances[i] = distance

    sort_idx = torch.argsort(distances)
    recon_masks_sorted = recon_masks.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    plt.figure()
    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("latent mean L2 distance")
    plt.savefig("shape_complexity/results/distance_plot.png")
    plt.close()

    return (
        plot_samples(masks_sorted, distances.numpy()[sort_idx]),
        plot_samples(recon_masks_sorted, distances.numpy()[sort_idx]),
    )


def visualize_sort_compression(data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = compression_complexity(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    plt.figure()
    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("compression length")
    plt.savefig("shape_complexity/results/compression_plot.png")
    plt.close()

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])


def visualize_sort_fft(data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = fft_measure(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    plt.figure()
    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("mean unidirectional frequency")
    plt.savefig("shape_complexity/results/fft_plot.png")
    plt.close()

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])


def visualize_sort_mean_precision(models: list[nn.Module], data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    precisions = torch.zeros((400,))

    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        precisions[i] = mean_precision(models, mask)

    sort_idx = torch.argsort(precisions)
    masks_sorted = masks.numpy()[sort_idx]

    plt.figure()
    plt.plot(np.arange(len(precisions)), np.sort(precisions.numpy()))
    plt.xlabel("images")
    plt.ylabel("mean precision")
    plt.savefig("shape_complexity/results/mean_prec_plot.png")
    plt.close()

    return plot_samples(masks_sorted, precisions.numpy()[sort_idx])


def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    diffs = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        diff, mask_recon_grid = complexity_measure_diff(model_gb, model_lb, mask)
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff

    sort_idx = np.argsort(np.array(diffs))
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    plt.figure()
    plt.plot(np.arange(len(diffs)), np.sort(diffs))
    plt.xlabel("images")
    plt.ylabel("pixelwise difference of reconstructions")
    plt.savefig("shape_complexity/results/px_diff_plot.png")
    plt.close()

    return (
        plot_samples(masks_sorted, diffs[sort_idx]),
        plot_samples(recon_masks_sorted, diffs[sort_idx]),
    )


def visualize_sort_3dim(
    data_loader: DataLoader, model_gb: nn.Module, model_lb: nn.Module
):
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    measures = torch.zeros((400, 3))
    for i, (mask, _) in enumerate(data_loader, 0):
        c_compress = compression_complexity(mask)
        c_fft = fft_measure(mask)
        # TODO: maybe exchange by diff or mean measure instead of precision
        c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask
        )
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        measures[i] = torch.tensor([c_compress, c_fft, c_vae])

    measures[:] /= measures.max(dim=0).values
    measure_norm = torch.linalg.vector_norm(measures, dim=1)

    fig = plt.figure()
    fig.clf()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(measures[:, 0], measures[:, 1], measures[:, 2], marker="o")

    ax.set_xlabel("bz2 compression")
    ax.set_ylabel("FFT ratio")
    ax.set_zlabel(f"VAE ratio {model_gb.bottleneck}/{model_lb.bottleneck}")
    plt.savefig("shape_complexity/results/3d_plot.png")
    plt.close()

    sort_idx = np.argsort(np.array(measure_norm))
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    return (
        plot_samples(masks_sorted, measure_norm[sort_idx]),
        plot_samples(recon_masks_sorted, measure_norm[sort_idx]),
    )


def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    diffs = []
    for i, (mask, _) in enumerate(data_loader, 0):
        # complexity_measure returns 7 values; ignore the per-model scores here
        (
            complexity,
            _,
            _,
            diff,
            _,
            _,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        masks[i] = mask_recon_grid
        diffs.append(diff)
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]

    plt.figure()
    plt.plot(np.arange(len(diffs)), np.sort(diffs))
    plt.xlabel("images")
    plt.ylabel("prec difference (L-H)")
    plt.savefig("shape_complexity/results/diff_plot.png")
    plt.close()

    return plot_samples(masks_sorted, complexities[sort_idx])


def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    complexities_lb = torch.zeros((400,))
    complexities_gb = torch.zeros((400,))
    diffs = []
    prec_lbs = []
    prec_gbs = []
    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            lb,
            gb,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        masks[i] = mask_recon_grid
        diffs.append(diff)
        prec_lbs.append(prec_lb)
        prec_gbs.append(prec_gb)
        complexities[i] = complexity
        complexities_lb[i] = lb
        complexities_gb[i] = gb

    sort_idx = np.argsort(np.array(complexities))
    sort_idx_lb = np.argsort(np.array(complexities_lb))
    sort_idx_gb = np.argsort(np.array(complexities_gb))
    masks_sorted = masks.numpy()[sort_idx]
    masks_sorted_lb = masks.numpy()[sort_idx_lb]
    masks_sorted_gb = masks.numpy()[sort_idx_gb]

    diff_sort_idx = np.argsort(diffs)

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(
        np.arange(len(prec_lbs)),
        np.array(prec_lbs)[diff_sort_idx],
        label=f"bottleneck {model_lb.bottleneck}",
    )
    ax1.plot(
        np.arange(len(prec_gbs)),
        np.array(prec_gbs)[diff_sort_idx],
        label=f"bottleneck {model_gb.bottleneck}",
    )
    ax2.plot(
        np.arange(len(diffs)),
        np.sort(diffs),
        color="red",
        label="prec difference (H - L)",
    )
    ax1.legend(loc="lower left")
    ax2.legend(loc="lower right")
    ax1.set_ylabel("precision")
    ax2.set_ylabel("prec difference (H-L)")
    plt.savefig("shape_complexity/results/prec_plot.png")
    plt.clf()

    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted_lb, complexities_lb[sort_idx_lb])
    fig.savefig("shape_complexity/results/lb.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted_gb, complexities_gb[sort_idx_gb])
    fig.savefig("shape_complexity/results/gb.png")
    plt.close(fig)


def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    recon_masks = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    complexities = torch.zeros((400,))
    diffs = np.zeros((400,))
    prec_gbs = np.zeros((400,))
    prec_lbs = np.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            _,
            _,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        recon_masks[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff
        prec_gbs[i] = prec_gb
        prec_lbs[i] = prec_lb
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]
    recon_masks_sorted = recon_masks.numpy()[sort_idx]

    # group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
    # bin_edges = [-np.inf, 0.0, 0.05, np.inf]
    # bins = np.digitize(diffs, bins=bin_edges, right=True)

    # for i in range(bins.min(), bins.max() + 1):
    #     bin_idx = bins == i
    #     binned_prec_gb = prec_gbs[bin_idx]
    #     prec_mean = binned_prec_gb.mean()
    #     prec_idx = prec_gbs > prec_mean

    #     binned_masks_high = recon_masks[bin_idx & prec_idx]
    #     binned_masks_low = recon_masks[bin_idx & ~prec_idx]

    #     save_image(
    #         binned_masks_high,
    #         f"shape_complexity/results/diff_{group_labels[i-1]}_high.png",
    #         padding=10,
    #     )
    #     save_image(
    #         binned_masks_low,
    #         f"shape_complexity/results/diff_{group_labels[i-1]}_low.png",
    #         padding=10,
    #     )

    # diff_sort_idx = np.argsort(diffs)

    # fig, ax1 = plt.subplots()
    # ax2 = ax1.twinx()
    # ax1.plot(
    #     np.arange(len(prec_lbs)),
    #     np.array(prec_lbs)[diff_sort_idx],
    #     label=f"bottleneck {model_lb.bottleneck}",
    # )
    # ax1.plot(
    #     np.arange(len(prec_gbs)),
    #     np.array(prec_gbs)[diff_sort_idx],
    #     label=f"bottleneck {model_gb.bottleneck}",
    # )
    # ax2.plot(
    #     np.arange(len(diffs)),
    #     np.sort(diffs),
    #     color="red",
    #     label="prec difference (H - L)",
    # )
    # ax1.legend(loc="lower left")
    # ax2.legend(loc="lower right")
    # ax1.set_ylabel("precision")
    # ax2.set_ylabel("prec difference (H-L)")
    # ax1.set_xlabel("image")
    # plt.savefig("shape_complexity/results/prec_plot.png")
    # plt.tight_layout(pad=2)
    # plt.clf()

    fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs_recon.png")
    plt.close(fig)

    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)


LR = 1e-3
EPOCHS = 10
LOAD_PRETRAINED = True


def main():
    bottlenecks = [4, 8, 16, 32]
    models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
    optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}

    _, dataset = load_data()  # the full-dataset loader is not used below

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

    if LOAD_PRETRAINED:
        for i, model in models.items():
            model.load_state_dict(
                torch.load(f"shape_complexity/trained/CONVVAE_{i}_split_data.pth")
            )
    else:
        for epoch in range(EPOCHS):
            for i, model in models.items():
                train(
                    epoch,
                    model=model,
                    optimizer=optimizers[i],
                    data_loader=train_loader,
                )

            test(epoch, models=list(models.values()), dataset=test_dataset)

        for bn in bottlenecks:
            if not os.path.exists("shape_complexity/trained"):
                os.makedirs("shape_complexity/trained")

            torch.save(
                models[bn].state_dict(),
                f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth",
            )

    test(0, models=list(models.values()), dataset=test_dataset, save_results=True)

    bn_gt = 32
    bn_lt = 8

    # for i in range(10):
    #     figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
    #     figure.savefig(
    #         f"shape_complexity/results/this_{bn_gt}_to_{bn_lt}_sample{i}.png"
    #     )
    #     figure.clear()
    #     plt.close(figure)

    # figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
    # figure.savefig(f"shape_complexity/results/sort_{bn_gt}_to_{bn_lt}.png")

    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

    visualize_sort_group(data_loader, models[bn_gt], models[bn_lt])
    # visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])
    fig, fig_recon = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt])
    fig.savefig(f"shape_complexity/results/sort_comp_fft_prec.png")
    fig_recon.savefig(f"shape_complexity/results/recon_sort_comp_fft_prec.png")
    plt.close(fig)
    plt.close(fig_recon)
    fig = visualize_sort_mean_precision(list(models.values()), data_loader)
    fig.savefig(f"shape_complexity/results/sort_mean_prec.png")
    plt.close(fig)
    fig = visualize_sort_fft(data_loader)
    fig.savefig(f"shape_complexity/results/sort_fft.png")
    plt.close(fig)
    fig = visualize_sort_compression(data_loader)
    fig.savefig("shape_complexity/results/sort_compression.png")
    plt.close(fig)
    fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
    fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
    fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")
    plt.close(fig)
    plt.close(fig_recon)
    # fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt])
    # fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png")
    # fig_recon.savefig(
    #     f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png"
    # )
    # plt.close(fig)
    # plt.close(fig_recon)


if __name__ == "__main__":
    main()