# Skip to content
# Snippets Groups Projects
# shape_complexity.py 29.32 KiB
import os
from zlib import compress

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
import torch.nn.functional as F
from PIL import Image
from scipy.fft import fft
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.utils import make_grid, save_image

# Run on the GPU when CUDA is available and fall back to the CPU otherwise, so
# the module can still be imported and used on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Select the non-interactive Agg backend; figures are only written to files.
matplotlib.use("Agg")

# Row/column offsets of the four edge-adjacent (4-connected) grid neighbours.
dx = [+1, 0, -1, 0]
dy = [0, +1, 0, -1]


# perform depth first search for each candidate/unlabeled region
# reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation
def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int):
    """Flood-fill the 4-connected region of ``mask`` that contains ``(x, y)``.

    Every cell reachable from the seed through edge-adjacent, truthy ``mask``
    cells that is still unlabeled (``labels`` entry is 0) receives
    ``current_label``. ``labels`` is modified in place; nothing is returned.

    Args:
        mask: 2-D array whose truthy entries are foreground.
        x: Row index of the seed cell.
        y: Column index of the seed cell.
        labels: 2-D array with the same shape as ``mask``; 0 means unlabeled.
        current_label: Non-zero label written to the region.
    """
    n_rows, n_cols = mask.shape

    # An explicit stack replaces the recursion of the original implementation:
    # one recursive call per pixel exceeds Python's recursion limit as soon as
    # a region has more than roughly a thousand pixels.
    pending = [(x, y)]
    while pending:
        row, col = pending.pop()
        if row < 0 or row >= n_rows or col < 0 or col >= n_cols:
            continue
        if labels[row][col] or not mask[row][col]:
            continue  # already labeled or not marked with 1 in image

        # mark the current cell
        labels[row][col] = current_label

        # schedule the four edge-adjacent neighbors
        pending.extend(
            ((row + 1, col), (row, col + 1), (row - 1, col), (row, col - 1))
        )


def find_components(mask: npt.NDArray):
    """Label the 4-connected foreground components of a 2-D mask.

    The mask is scanned row by row; each truthy cell that is still unlabeled
    starts a new component, which is filled in by ``dfs``. Labels are assigned
    as 1, 2, ... in the order their first cell is met; 0 marks background.

    Args:
        mask: 2-D array whose truthy entries are foreground.

    Returns:
        Integer array with the shape of ``mask`` holding the component labels.
    """
    n_rows, n_cols = mask.shape
    # int32 instead of int8: an int8 label array cannot represent more than
    # 127 components.
    labels = np.zeros(mask.shape, dtype=np.int32)

    label = 0
    for i in range(n_rows):
        for j in range(n_cols):
            if mask[i][j] and not labels[i][j]:
                label += 1
                dfs(mask, i, j, labels, label)

    return labels


# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
def bbox(img):
    """Return the bounding box of the non-zero cells of a 2-D array.

    The tight box is widened by one cell on every side where that is possible
    without leaving the array.

    Args:
        img: 2-D array; truthy entries count as content.

    Returns:
        ``(rmin, rmax, cmin, cmax)`` as inclusive, always valid indices.

    Raises:
        IndexError: If ``img`` has no truthy entry.
    """
    n_rows, n_cols = img.shape
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

    rmin = rmin - 1 if rmin > 0 else rmin
    cmin = cmin - 1 if cmin > 0 else cmin
    # The last valid index is size - 1; comparing against the size itself (as
    # before) was always true and could return an index past the array.
    rmax = rmax + 1 if rmax < n_rows - 1 else rmax
    cmax = cmax + 1 if cmax < n_cols - 1 else cmax

    return rmin, rmax, cmin, cmax


def extract_single_masks(labels: npt.NDArray):
    """Cut one cropped binary mask per label value out of a label image.

    Entry ``k`` of the result is the ``int8`` mask ``labels == k`` cropped to
    the box reported by ``bbox``. Entry 0 therefore corresponds to the cells
    labeled 0.

    Args:
        labels: 2-D integer label array.

    Returns:
        List of cropped ``int8`` masks, indexed by label value.
    """
    # Convert to a Python int before adding 1: with a narrow label dtype such
    # as int8 the addition could otherwise overflow at the dtype's maximum.
    n_labels = int(labels.max()) + 1

    masks = []
    for label_value in range(n_labels):
        mask = (labels == label_value).astype(np.int8)
        rmin, rmax, cmin, cmax = bbox(mask)
        masks.append(mask[rmin : rmax + 1, cmin : cmax + 1])

    return masks


class VAE(nn.Module):
    """Fully connected variational autoencoder.

    Adapted from https://github.com/pytorch/examples/blob/main/vae/main.py

    The encoder flattens the input, maps it to 400 hidden units and from there
    to the mean and log-variance of a ``bottleneck``-dimensional latent code.
    The decoder maps a latent sample back to ``image_dim`` values in (0, 1).

    Args:
        bottleneck: Size of the latent code.
        image_dim: Number of values per flattened input image.
    """

    def __init__(self, bottleneck=2, image_dim=4096):
        super().__init__()

        self.bottleneck = bottleneck
        self.image_dim = image_dim

        self.prelim_encode = nn.Sequential(
            nn.Flatten(), nn.Linear(image_dim, 400), nn.ReLU()
        )
        self.encode_mu = nn.Sequential(nn.Linear(400, bottleneck))
        self.encode_logvar = nn.Sequential(nn.Linear(400, bottleneck))

        self.decode = nn.Sequential(
            nn.Linear(bottleneck, 400),
            nn.ReLU(),
            nn.Linear(400, image_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return the latent mean and log-variance for the batch ``x``."""
        x = self.prelim_encode(x)
        return self.encode_mu(x), self.encode_logvar(x)

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` sampled from N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss(self, recon_x, x, mu, logvar):
        """Return the summed binary cross-entropy plus KL divergence."""
        # Flatten to the configured image size; this used to be a hard-coded
        # 4096, which only worked for the default ``image_dim``.
        BCE = F.binary_cross_entropy(
            recon_x, x.view(-1, self.image_dim), reduction="sum"
        )

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD


class CONVVAE(nn.Module):
    """Convolutional variational autoencoder with max-unpooling decoder.

    Three conv/ReLU/max-pool stages reduce a single-channel image to a
    64 x 6 x 6 feature map (the sizes noted below assume a 64 x 64 input).
    Two linear heads turn the flattened features into the latent mean and
    log-variance. The decoder expands a latent sample back to 64 x 6 x 6 and
    undoes each pooling step with the indices recorded while encoding.
    """

    def __init__(
        self,
        bottleneck=2,
    ):
        super().__init__()

        self.bottleneck = bottleneck
        self.feature_dim = 6 * 6 * 64

        # Encoder stages; every pool also returns its argmax indices.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), return_indices=True),  # -> 30x30x16
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), return_indices=True),  # -> 14x14x32
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), return_indices=True),  # -> 6x6x64
        )

        # Latent heads operating on the flattened feature map.
        self.encode_mu = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.feature_dim, self.bottleneck),
        )
        self.encode_logvar = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.feature_dim, self.bottleneck),
        )

        # Decoder: latent code -> feature map -> mirrored transposed convs.
        self.decode_linear = nn.Linear(self.bottleneck, self.feature_dim)
        self.decode3 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 3),
            nn.ReLU(),
        )
        self.decode2 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 3),
            nn.ReLU(),
        )
        self.decode1 = nn.Sequential(
            nn.ConvTranspose2d(16, 1, 5),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar, pooling_indices)`` for the batch ``x``."""
        pooling_indices = []
        features = x
        for stage in (self.conv1, self.conv2, self.conv3):
            features, indices = stage(features)
            pooling_indices.append(indices)

        mu = self.encode_mu(features)
        logvar = self.encode_logvar(features)

        return mu, logvar, tuple(pooling_indices)

    def decode(self, z: Tensor, indexes: tuple):
        """Reconstruct images from latent codes and recorded pool indices."""
        (idx1, idx2, idx3) = indexes
        out = self.decode_linear(z).view((-1, 64, 6, 6))

        # Walk the encoder backwards: unpool with the matching indices, then
        # apply the corresponding transposed convolution.
        for indices, layer in (
            (idx3, self.decode3),
            (idx2, self.decode2),
            (idx1, self.decode1),
        ):
            out = layer(F.max_unpool2d(out, indices, (2, 2)))

        return out

    def reparameterize(self, mu, logvar):
        """Sample ``mu + sigma * noise`` with standard normal ``noise``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar, indexes = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent, indexes), mu, logvar

    def loss(self, recon_x, x, mu, logvar):
        """https://github.com/pytorch/examples/blob/main/vae/main.py"""
        reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction="sum")

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        kl_term = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return reconstruction_term + kl_term


def load_data():
    """Build the concatenated mask dataset and a shuffled loader over it.

    Every image is converted to grayscale, resized to 64x64 with bilinear
    interpolation, randomly flipped horizontally and vertically, and turned
    into a tensor. One ``ImageFolder`` is created per trajectory directory.

    Returns:
        ``(data_loader, dataset)`` where the loader uses batches of 128.
    """
    preprocessing = transforms.Compose(
        [
            transforms.Grayscale(),
            transforms.Resize(
                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    )

    # "v3_subtle_iceberg_lettuce_nymph-6_203-2056" is deliberately left out.
    trajectory_names = (
        "v3_absolute_grape_changeling-16_2277-4441",
        "v3_content_squash_angel-3_16074-17640",
        "v3_smooth_kale_loch_ness_monster-1_4439-6272",
        "v3_cute_breadfruit_spirit-6_17090-19102",
        "v3_key_nectarine_spirit-2_7081-9747",
        "v3_subtle_iceberg_lettuce_nymph-6_3819-6049",
        "v3_juvenile_apple_angel-30_396415-398113",
        "v3_subtle_iceberg_lettuce_nymph-6_6100-8068",
    )

    mask_folders = [
        ImageFolder(
            f"activation_vis/out/critic/masks/{name}/0/4", transform=preprocessing
        )
        for name in trajectory_names
    ]
    dataset = torch.utils.data.ConcatDataset(mask_folders)

    return DataLoader(dataset, batch_size=128, shuffle=True), dataset


def train(epoch, model: nn.Module, optimizer, data_loader, log_interval=40):
    """Train ``model`` for one pass over ``data_loader`` and print progress.

    Args:
        epoch: Epoch number, used only in the printed messages.
        model: Network exposing ``forward`` -> ``(recon, mu, logvar)`` and a
            ``loss(recon, data, mu, logvar)`` method (``VAE`` or ``CONVVAE``).
            The previous annotation ``VAE or CONVVAE`` evaluated to ``VAE``.
        optimizer: Optimizer stepping the model's parameters.
        data_loader: Loader yielding ``(data, target)`` batches.
        log_interval: Print the running loss every this many batches.
    """
    model.train()
    train_loss = 0
    n_batches = len(data_loader)
    n_items = len(data_loader.dataset)

    for batch_idx, (data, _) in enumerate(data_loader):
        data = data.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % log_interval == 0:
            seen = batch_idx * len(data)
            percent = 100.0 * batch_idx / n_batches
            per_item_loss = loss.item() / len(data)
            print(
                f"Train Epoch: {epoch} [{seen}/{n_items} ({percent:.0f}%)]"
                f"\tLoss: {per_item_loss:.6f}"
            )

    print(f"====> Epoch: {epoch} Average loss: {train_loss / n_items:.4f}")


def test(epoch, models, dataset):
    """Evaluate ``models`` on 64 random samples and save a comparison image.

    For the first batch an image grid is written to
    ``results/reconstruction_<epoch>.png``: the first row shows up to 20
    inputs, each following row the reconstructions of one model. The average
    per-sample loss of every model is printed.

    Args:
        epoch: Epoch number, used in the output file name.
        models: Sequence of models exposing ``forward`` and ``loss``.
        dataset: Dataset to sample the evaluation images from.
    """
    for model in models:
        model.eval()
    test_loss = [0.0 for _ in models]

    test_batch_size = 32
    sampler = RandomSampler(dataset, replacement=True, num_samples=64)
    test_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=sampler)
    comp_data = None
    n_evaluated = 0

    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            n_evaluated += data.size(0)

            for j, model in enumerate(models):
                recon_batch, mu, logvar = model(data)
                test_loss[j] += model.loss(recon_batch, data, mu, logvar).item()

                if i == 0:
                    n = min(data.size(0), 20)
                    if comp_data is None:
                        comp_data = data[:n]
                    # -1 instead of the nominal batch size so that a smaller
                    # batch cannot break the reshape.
                    comp_data = torch.cat(
                        [comp_data, recon_batch.view(-1, 1, 64, 64)[:n]]
                    )

            if i == 0:
                os.makedirs("results", exist_ok=True)
                save_image(
                    comp_data.cpu(),
                    "results/reconstruction_" + str(epoch) + ".png",
                    nrow=min(data.size(0), 20),
                )

    for i, _ in enumerate(models):
        # Average over the samples actually evaluated; the sampler draws only
        # 64 items, so dividing by the full dataset length was misleading.
        test_loss[i] /= max(n_evaluated, 1)
        print(f"====> Test set loss model {i}: {test_loss[i]:.4f}")


def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
    """Reconstruct one connected component of an image and score the result.

    The image at ``path`` is converted to grayscale, split into connected
    components, and the component with index ``label`` is resized to 64x64 and
    passed through ``model``. The reconstruction is binarized at ``epsilon``
    and compared pixel-wise with the mask.

    Args:
        model: Model whose forward pass returns ``(recon, mu, logvar)``.
        path: Path of the image file.
        label: Index into the list returned by ``extract_single_masks``.
        epsilon: Threshold applied to the reconstruction.

    Returns:
        ``(precision, recall, comp_data)`` where ``comp_data`` stacks the
        mask, the raw reconstruction and the binarized reconstruction
        vertically. Precision or recall is NaN when its denominator is zero.
    """
    model.eval()
    # Read the file inside a context manager so the handle is released.
    with Image.open(path) as source:
        image = transforms.F.to_tensor(transforms.F.to_grayscale(source))
    labels = find_components(image[0])
    single_masks = extract_single_masks(labels)
    # Cast to uint8 before scaling: the masks are int8, and multiplying an
    # int8 array by 255 overflows (NumPy 2 raises instead of promoting).
    scaled_mask = single_masks[label].astype(np.uint8) * 255
    mask = transforms.F.to_tensor(
        transforms.F.resize(
            transforms.F.to_pil_image(scaled_mask),
            (64, 64),
        )
    )

    with torch.no_grad():
        mask = mask.to(device)
        # Add a batch dimension: the flattened input of VAE is unchanged by
        # it, and the linear heads of CONVVAE require it.
        recon_x, _, _ = model(mask.unsqueeze(0))
        recon_bits = recon_x.view(64, 64).cpu().numpy() > epsilon
        mask_bits = mask.cpu().numpy() > 0

        TP = (mask_bits & recon_bits).sum()
        FP = (recon_bits & ~mask_bits).sum()
        FN = (mask_bits & ~recon_bits).sum()

        prec = TP / (TP + FP)
        rec = TP / (TP + FN)
        comp_data = torch.cat(
            [mask[0].cpu(), recon_x.view(64, 64).cpu(), torch.from_numpy(recon_bits)]
        )

        return prec, rec, comp_data


def distance_measure(model: VAE, img: Tensor):
    """Measure how far the latent mean moves after one reconstruction.

    The image batch is encoded and reconstructed, and the reconstruction is
    fed through the model a second time. The result is the L2 norm of the
    difference between the two latent means.

    Returns:
        ``(distance, grid)`` where ``grid`` shows the first input next to its
        reconstruction.
    """
    model.eval()

    with torch.no_grad():
        source = img.to(device)
        reconstruction, latent_mean, _ = model(source)
        # TODO: apply threshold here?!
        _, second_pass_mean, _ = model(reconstruction)

        latent_shift = torch.norm(latent_mean - second_pass_mean, p=2)
        side_by_side = make_grid(
            torch.stack([source[0], reconstruction[0]]).cpu(), nrow=2, padding=0
        )

    return latent_shift, side_by_side


def compression_complexity(img: Tensor):
    """Return the zlib-compressed size in bytes of the first item of ``img``.

    The raw buffer of ``img[0]`` (as a NumPy array) is handed to
    ``zlib.compress``; the length of the compressed output is the score.
    """
    return len(compress(img[0].numpy()))


def fft_measure(img: Tensor):
    """Return the share of log-spectrum mass outside the central block.

    The 2-D FFT magnitude of ``img[0][0]`` is shifted so the zero frequency
    sits in the middle and compressed with ``log(1 + m)``. The value returned
    is ``(total - centre) / total``, where ``centre`` sums the middle third of
    the spectrum along both axes.
    """
    pixels = img[0][0].numpy()
    height, width = pixels.shape

    magnitude = np.abs(np.fft.fft2(pixels))
    log_spectrum = np.log(1 + np.fft.fftshift(magnitude))

    row_third = height // 3
    col_third = width // 3
    centre = log_spectrum[row_third : 2 * row_third, col_third : 2 * col_third]

    overall = log_spectrum.sum()
    return (overall - centre.sum()) / overall


def complexity_measure(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
    epsilon=0.4,
    save_preliminary=False,
):
    """Score an image by the reconstruction precision of two models.

    Both models reconstruct ``img``; each reconstruction is binarized at
    ``epsilon`` and compared with the binarized input (``> 0``). Precision is
    ``tp / (tp + fp)`` per model.

    Args:
        model_gb: Model referred to as "gb" (main() passes the larger
            bottleneck here).
        model_lb: Model referred to as "lb" (main() passes the smaller
            bottleneck here).
        img: Image batch; only ``img[0]`` is used as the reference mask.
        epsilon: Threshold applied to both reconstructions.
        save_preliminary: When true, write the binarized masks and the
            true/false-positive maps to ``shape_complexity/results``.

    Returns:
        ``(complexity, complexity_lb, complexity_gb, prec_gb - prec_lb,
        prec_lb, prec_gb, grid)`` where ``grid`` shows the input next to the
        lb and gb reconstructions.
    """
    for net in (model_gb, model_lb):
        net.eval()

    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        mask_bits = mask[0].cpu() > 0
        bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
        bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon

        # True positives / false positives of each binarized reconstruction.
        hits_gb = mask_bits & bits_gb
        extra_gb = bits_gb & ~mask_bits
        hits_lb = mask_bits & bits_lb
        extra_lb = bits_lb & ~mask_bits

        if save_preliminary:
            binarized = torch.stack(
                [mask_bits.float(), bits_gb.float(), bits_lb.float()]
            )
            save_image(
                binarized.cpu(),
                f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )
            overlap_maps = torch.stack(
                [
                    hits_gb.float(),
                    extra_gb.float(),
                    hits_lb.float(),
                    extra_lb.float(),
                ]
            )
            save_image(
                overlap_maps.cpu(),
                f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )

        tp_gb = hits_gb.sum()
        fp_gb = extra_gb.sum()
        tp_lb = hits_lb.sum()
        fp_lb = extra_lb.sum()

        prec_gb = tp_gb / (tp_gb + fp_gb)
        prec_lb = tp_lb / (tp_lb + fp_lb)

        # Worked example with prec_gb = 0.4 and prec_lb = 0.7:
        #   complexity    = 1 - (0.4 - abs(0.4 - 0.7)) = 0.9
        #   complexity_lb = 1 - 0.7                    = 0.3
        complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
        complexity_lb = 1 - prec_lb
        complexity_gb = 1 - prec_gb

        comparison_grid = make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        )

        return (
            complexity,
            complexity_lb,
            complexity_gb,
            prec_gb - prec_lb,
            prec_lb,
            prec_gb,
            comparison_grid,
        )


def complexity_measure_diff(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
):
    """Compare the raw reconstructions of two models for one image batch.

    Returns:
        ``(diff, grid)`` where ``diff`` is the absolute value of the summed
        difference ``recon_gb - recon_lb`` and ``grid`` shows the input next
        to the lb and gb reconstructions.
    """
    for net in (model_gb, model_lb):
        net.eval()

    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        # NOTE(review): this is |sum(a - b)|, so positive and negative pixel
        # differences cancel; confirm that sum(|a - b|) was not intended.
        summed_difference = (recon_gb - recon_lb).cpu().sum()
        diff = torch.abs(summed_difference)

        comparison_grid = make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        )

        return diff, comparison_grid


def plot_samples(masks: Tensor, complexities: npt.NDArray):
    """Draw a 20 x 20 grid of images, each captioned with its score.

    Only the first channel of every image is shown. The figure is sized so
    that each image gets two figure pixels per image pixel at the chosen dpi.

    Args:
        masks: Array of shape ``(400, channels, height, width)``.
        complexities: One score per image, printed with four decimals.

    Returns:
        The matplotlib figure.

    Raises:
        Exception: If ``masks`` does not hold exactly 400 images.
    """
    dpi = 150
    rows = cols = 20
    n_samples, _, height, width = masks.shape

    if n_samples != rows * cols:
        raise Exception("shape mismatch")

    extent = (0, width - 1, 0, height - 1)
    caption_style = {"fontsize": 6, "color": "orange"}

    fig = plt.figure(figsize=(32, 16), dpi=dpi)
    for idx in range(n_samples):
        ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])

        plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
        ax.set_title(f"{complexities[idx]:.4f}", fontdict=caption_style, y=0.35)

    fig.patch.set_facecolor("#292929")
    half_dpi = dpi / 2
    fig.set_size_inches(
        (width * cols) / half_dpi, (height * rows) / half_dpi, forward=True
    )
    fig.tight_layout(pad=0)

    return fig


def visualize_sort_mean(data_loader: DataLoader, model: VAE):
    """Rank 400 images by latent-mean distance and plot them in that order.

    Each image's score comes from ``distance_measure``. The sorted scores are
    written to ``shape_complexity/results/distance_plot.png``.

    Args:
        data_loader: Loader yielding 400 single-image batches.
        model: Model passed on to ``distance_measure``.

    Returns:
        ``(mask_figure, recon_figure)``: sample grids of the inputs and of
        the input/reconstruction pairs, both sorted by distance.
    """
    recon_masks = torch.zeros((400, 3, 64, 128))
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader):
        distance, mask_recon_grid = distance_measure(model, mask)
        masks[i] = mask[0]
        recon_masks[i] = mask_recon_grid
        distances[i] = distance

    sort_idx = torch.argsort(distances)
    recon_masks_sorted = recon_masks.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    # Draw the curve on its own figure. A bare plt.plot() would paint into
    # whatever figure is current, e.g. a sample grid from an earlier call.
    curve_fig, curve_ax = plt.subplots()
    curve_ax.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    curve_ax.set_xlabel("images")
    curve_ax.set_ylabel("latent mean L2 distance")
    curve_fig.savefig("shape_complexity/results/distance_plot.png")
    plt.close(curve_fig)

    return plot_samples(masks_sorted, distances.numpy()[sort_idx]), plot_samples(
        recon_masks_sorted, distances.numpy()[sort_idx]
    )


def visualize_sort_compression(data_loader: DataLoader):
    """Rank 400 images by compressed size and plot them in that order.

    Each image's score comes from ``compression_complexity``. The sorted
    scores are written to ``shape_complexity/results/compression_plot.png``.

    Args:
        data_loader: Loader yielding 400 single-image batches.

    Returns:
        Sample-grid figure of the inputs sorted by compressed size.
    """
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader):
        masks[i] = mask[0]
        distances[i] = compression_complexity(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    # Draw the curve on its own figure. A bare plt.plot() would paint into
    # whatever figure is current, e.g. a sample grid from an earlier call.
    curve_fig, curve_ax = plt.subplots()
    curve_ax.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    curve_ax.set_xlabel("images")
    curve_ax.set_ylabel("compression length")
    curve_fig.savefig("shape_complexity/results/compression_plot.png")
    plt.close(curve_fig)

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])


def visualize_sort_fft(data_loader: DataLoader):
    """Rank 400 images by their FFT measure and plot them in that order.

    Each image's score comes from ``fft_measure``. The sorted scores are
    written to ``shape_complexity/results/fft_plot.png``.

    Args:
        data_loader: Loader yielding 400 single-image batches.

    Returns:
        Sample-grid figure of the inputs sorted by the FFT measure.
    """
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader):
        masks[i] = mask[0]
        distances[i] = fft_measure(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    # Draw the curve on its own figure. A bare plt.plot() would paint into
    # whatever figure is current, e.g. a sample grid from an earlier call.
    curve_fig, curve_ax = plt.subplots()
    curve_ax.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    curve_ax.set_xlabel("images")
    # The label used to read "compression length", copied from the
    # compression variant; the plotted value is the fft_measure ratio.
    curve_ax.set_ylabel("share of log spectrum outside the centre block")
    curve_fig.savefig("shape_complexity/results/fft_plot.png")
    plt.close(curve_fig)

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])


def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    """Rank 400 images by the reconstruction difference of two models.

    Each image's score comes from ``complexity_measure_diff``. The sorted
    scores are written to ``shape_complexity/results/px_diff_plot.png``.

    Args:
        data_loader: Loader yielding 400 single-image batches.
        model_gb: First model passed to ``complexity_measure_diff``.
        model_lb: Second model passed to ``complexity_measure_diff``.

    Returns:
        ``(mask_figure, recon_figure)``: sample grids of the inputs and of
        the input/reconstruction triples, both sorted by the difference.
    """
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    diffs = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader):
        diff, mask_recon_grid = complexity_measure_diff(model_gb, model_lb, mask)
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff

    sort_idx = np.argsort(diffs.numpy())
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    # Draw the curve on its own figure. A bare plt.plot() would paint into
    # whatever figure is current, e.g. a sample grid from an earlier call.
    curve_fig, curve_ax = plt.subplots()
    curve_ax.plot(np.arange(len(diffs)), np.sort(diffs.numpy()))
    curve_ax.set_xlabel("images")
    curve_ax.set_ylabel("pixelwise difference of reconstructions")
    curve_fig.savefig("shape_complexity/results/px_diff_plot.png")
    plt.close(curve_fig)

    return plot_samples(masks_sorted, diffs[sort_idx]), plot_samples(
        recon_masks_sorted, diffs[sort_idx]
    )


def visualize_sort_3dim(
    data_loader: DataLoader, model_gb: nn.Module, model_lb: nn.Module
):
    """Rank 400 images by the norm of three normalized complexity scores.

    For every image the compression length, the FFT measure and the
    precision-based VAE complexity are collected. Each score column is divided
    by its maximum, and images are ordered by the vector norm of their row.

    Returns:
        ``(mask_figure, recon_figure)``: sample grids of the inputs and of
        the input/reconstruction triples, both in the combined order.
    """
    recon_grids = torch.zeros((400, 3, 64, 192))
    inputs = torch.zeros((400, 1, 64, 64))
    scores = torch.zeros((400, 3))

    for i, (mask, _) in enumerate(data_loader):
        c_compress = compression_complexity(mask)
        c_fft = fft_measure(mask)
        # TODO: maybe exchange by diff measure instead of precision
        vae_result = complexity_measure(model_gb, model_lb, mask)
        c_vae, mask_recon_grid = vae_result[0], vae_result[6]

        recon_grids[i] = mask_recon_grid
        inputs[i] = mask[0]
        scores[i] = torch.tensor([c_compress, c_fft, c_vae])

    # Bring the three measures onto a common scale before combining them.
    scores /= scores.max(dim=0).values
    combined = torch.linalg.vector_norm(scores, dim=1)

    order = np.argsort(np.array(combined))
    recon_grids_sorted = recon_grids.numpy()[order]
    inputs_sorted = inputs.numpy()[order]

    # TODO: add 3d plot of measures

    mask_figure = plot_samples(inputs_sorted, combined[order])
    recon_figure = plot_samples(recon_grids_sorted, combined[order])
    return mask_figure, recon_figure


def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
    """Rank 400 randomly drawn images by the precision-based complexity.

    The sorted precision differences (``prec_gb - prec_lb``) are written to
    ``shape_complexity/results/diff_plot.png``.

    Args:
        dataset: Dataset to draw the 400 samples from (with replacement).
        model_gb: Model passed as ``model_gb`` to ``complexity_measure``.
        model_lb: Model passed as ``model_lb`` to ``complexity_measure``.

    Returns:
        Sample-grid figure of the input/reconstruction triples sorted by
        complexity.
    """
    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    diffs = []
    for i, (mask, _) in enumerate(data_loader):
        # complexity_measure returns seven values; unpacking only five (as
        # before) raised ValueError on the first image.
        complexity, _, _, diff, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask, save_preliminary=True
        )
        masks[i] = mask_recon_grid
        diffs.append(float(diff))
        complexities[i] = complexity

    sort_idx = np.argsort(complexities.numpy())
    masks_sorted = masks.numpy()[sort_idx]

    # Draw the curve on its own figure. A bare plt.plot() would paint into
    # whatever figure is current, e.g. a sample grid from an earlier call.
    curve_fig, curve_ax = plt.subplots()
    curve_ax.plot(np.arange(len(diffs)), np.sort(diffs))
    curve_ax.set_xlabel("images")
    # diff is prec_gb - prec_lb, i.e. higher minus lower bottleneck.
    curve_ax.set_ylabel("prec difference (H-L)")
    curve_fig.savefig("shape_complexity/results/diff_plot.png")
    plt.close(curve_fig)

    return plot_samples(masks_sorted, complexities[sort_idx])


def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    """Plot precision curves and three sorted sample grids for a fixed loader.

    Every image is scored with ``complexity_measure``. The precisions of both
    models, ordered by their difference, go to
    ``shape_complexity/results/prec_plot.png``. The sample grids sorted by the
    combined, lb-only and gb-only complexity go to ``abs.png``, ``lb.png`` and
    ``gb.png`` in the same directory.
    """
    grids = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    complexities_lb = torch.zeros((400,))
    complexities_gb = torch.zeros((400,))
    diffs, prec_lbs, prec_gbs = [], [], []

    for i, (mask, _) in enumerate(data_loader):
        result = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        complexity, lb, gb, diff, prec_lb, prec_gb, mask_recon_grid = result

        grids[i] = mask_recon_grid
        diffs.append(diff)
        prec_lbs.append(prec_lb)
        prec_gbs.append(prec_gb)
        complexities[i] = complexity
        complexities_lb[i] = lb
        complexities_gb[i] = gb

    diff_sort_idx = np.argsort(diffs)

    # Precision of both models on the left axis, their difference on the right.
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    for precisions, net in ((prec_lbs, model_lb), (prec_gbs, model_gb)):
        ax1.plot(
            np.arange(len(precisions)),
            np.array(precisions)[diff_sort_idx],
            label=f"bottleneck {net.bottleneck}",
        )
    ax2.plot(
        np.arange(len(diffs)),
        np.sort(diffs),
        color="red",
        label="prec difference (H - L)",
    )
    ax1.legend(loc="lower left")
    ax2.legend(loc="lower right")
    ax1.set_ylabel("precision")
    ax2.set_ylabel("prec difference (H-L)")
    plt.savefig("shape_complexity/results/prec_plot.png")
    plt.clf()

    # One sorted sample grid per complexity variant.
    outputs = (
        (complexities, "shape_complexity/results/abs.png"),
        (complexities_lb, "shape_complexity/results/lb.png"),
        (complexities_gb, "shape_complexity/results/gb.png"),
    )
    for scores, target in outputs:
        order = np.argsort(np.array(scores))
        fig = plot_samples(grids.numpy()[order], scores[order])
        fig.savefig(target)
        plt.close(fig)


def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    """Group images by precision difference and plot precision curves.

    Every image is scored with ``complexity_measure``. Images are binned by
    ``prec_gb - prec_lb`` into ``<= 0``, ``(0, 0.05]`` and ``> 0.05``; each
    bin is split at its mean ``prec_gb`` and the two halves are saved as
    ``diff_<group>_high.png`` / ``diff_<group>_low.png``. The precision curves
    go to ``prec_plot.png`` and the sample grids sorted by complexity to
    ``abs_recon.png`` and ``abs.png``, all under ``shape_complexity/results``.

    Args:
        data_loader: Loader yielding 400 single-image batches.
        model_gb: Model passed as ``model_gb`` to ``complexity_measure``.
        model_lb: Model passed as ``model_lb`` to ``complexity_measure``.
    """
    recon_masks = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    complexities = torch.zeros((400,))
    diffs = np.zeros((400,))
    prec_gbs = np.zeros((400,))
    prec_lbs = np.zeros((400,))
    for i, (mask, _) in enumerate(data_loader):
        (
            complexity,
            _,
            _,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        recon_masks[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff
        prec_gbs[i] = prec_gb
        prec_lbs[i] = prec_lb
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]
    recon_masks_sorted = recon_masks.numpy()[sort_idx]

    group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
    bin_edges = [-np.inf, 0.0, 0.05, np.inf]
    bins = np.digitize(diffs, bins=bin_edges, right=True)

    for i in range(bins.min(), bins.max() + 1):
        # Bin numbers outside 1..3 have no label (np.digitize yields them for
        # values that fall outside the edges, e.g. NaN).
        if not 1 <= i <= len(group_labels):
            continue
        bin_idx = bins == i
        if not bin_idx.any():
            continue
        prec_mean = prec_gbs[bin_idx].mean()
        prec_idx = prec_gbs > prec_mean

        halves = (("high", bin_idx & prec_idx), ("low", bin_idx & ~prec_idx))
        for half, selection in halves:
            # save_image cannot lay out zero images, and a half is empty
            # whenever all precisions in the bin are equal (e.g. one image).
            if not selection.any():
                continue
            save_image(
                recon_masks[selection],
                f"shape_complexity/results/diff_{group_labels[i-1]}_{half}.png",
                padding=10,
            )

    diff_sort_idx = np.argsort(diffs)

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(
        np.arange(len(prec_lbs)),
        np.array(prec_lbs)[diff_sort_idx],
        label=f"bottleneck {model_lb.bottleneck}",
    )
    ax1.plot(
        np.arange(len(prec_gbs)),
        np.array(prec_gbs)[diff_sort_idx],
        label=f"bottleneck {model_gb.bottleneck}",
    )
    ax2.plot(
        np.arange(len(diffs)),
        np.sort(diffs),
        color="red",
        label="prec difference (H - L)",
    )
    ax1.legend(loc="lower left")
    ax2.legend(loc="lower right")
    ax1.set_ylabel("precision")
    ax2.set_ylabel("prec difference (H-L)")
    ax1.set_xlabel("image")
    # Apply the layout before saving; it used to run after savefig and so had
    # no effect on the written file.
    fig.tight_layout(pad=2)
    fig.savefig("shape_complexity/results/prec_plot.png")
    plt.close(fig)

    fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs_recon.png")
    plt.close(fig)

    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)


# Learning rate handed to the Adam optimizers created in main().
LR = 1e-3
# Number of training epochs main() runs when no pretrained weights are loaded.
EPOCHS = 10
# When True, main() loads saved weights from disk instead of training.
LOAD_PRETRAINED = True


def main():
    """Train or load the two CONVVAE models and write all visualizations.

    Models with bottleneck sizes 1 and 16 are either loaded from
    ``shape_complexity/trained`` (``LOAD_PRETRAINED``) or trained for
    ``EPOCHS`` epochs on an 80/20 split and then saved there. Afterwards the
    sorting visualizations are run on 400 random samples and their figures are
    written to ``shape_complexity/results``.
    """
    bottlenecks = [1, 16]
    models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
    optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}

    # Only the dataset is needed; loaders are built below where required.
    _, dataset = load_data()

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

    # The visualizations save into this directory without creating it.
    os.makedirs("shape_complexity/results", exist_ok=True)

    if LOAD_PRETRAINED:
        for i, model in models.items():
            # map_location lets weights saved on a GPU load on this device.
            model.load_state_dict(
                torch.load(
                    f"shape_complexity/trained/CONVVAE_{i}_split_data.pth",
                    map_location=device,
                )
            )
    else:
        for epoch in range(EPOCHS):
            for i, model in models.items():
                train(
                    epoch,
                    model=model,
                    optimizer=optimizers[i],
                    data_loader=train_loader,
                )

            test(epoch, models=list(models.values()), dataset=test_dataset)

        os.makedirs("shape_complexity/trained", exist_ok=True)
        for bn in bottlenecks:
            torch.save(
                models[bn].state_dict(),
                f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth",
            )

    bn_gt = 16
    bn_lt = 1

    # Alternative entry points kept for reference:
    # figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
    # figure.savefig(f"shape_complexity/results/sort_{bn_gt}_to_{bn_lt}.png")
    # visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])

    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

    visualize_sort_group(data_loader, models[bn_gt], models[bn_lt])

    # Every figure is closed once saved so that 400-axes grids do not pile up
    # in memory and later pyplot calls never see a stale current figure.
    fig, fig_recon = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt])
    fig.savefig("shape_complexity/results/sort_comp_fft_prec.png")
    plt.close(fig)
    plt.close(fig_recon)

    fig = visualize_sort_fft(data_loader)
    fig.savefig("shape_complexity/results/sort_fft.png")
    plt.close(fig)

    fig = visualize_sort_compression(data_loader)
    fig.savefig("shape_complexity/results/sort_compression.png")
    plt.close(fig)

    fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
    fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
    fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")
    plt.close(fig)
    plt.close(fig_recon)

    fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt])
    fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png")
    fig_recon.savefig(
        f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png"
    )
    plt.close(fig)
    plt.close(fig_recon)


# Run the pipeline only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()