# from zlib import compress
import os
from bz2 import compress
from typing import Union

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
import torch.nn.functional as F
from kornia.morphology import closing
from PIL import Image
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.utils import make_grid, save_image

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
matplotlib.use("Agg")

dx = [+1, 0, -1, 0]
dy = [0, +1, 0, -1]


# perform depth first search for each candidate/unlabeled region
# reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation
def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int):
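    """Recursively flood-fill the 4-connected region of `mask` containing (x, y),
    writing `current_label` into `labels`."""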
    n_rows, n_cols = mask.shape
    if x < 0 or x == n_rows:
        return
    if y < 0 or y == n_cols:
        return
    if labels[x][y] or not mask[x][y]:
        return  # already labeled or not marked with 1 in image

    # mark the current cell
    labels[x][y] = current_label

    # recursively mark the neighbors
    for direction in range(4):
        dfs(mask, x + dx[direction], y + dy[direction], labels, current_label)


def find_components(mask: npt.NDArray):
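    """Label the 4-connected foreground components of a binary mask.

    Returns an integer image where 0 is background and each component
    gets a distinct positive label.
    """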
    label = 0

    n_rows, n_cols = mask.shape
    labels = np.zeros(mask.shape, dtype=np.int8)

    for i in range(n_rows):
        for j in range(n_cols):
            if not labels[i][j] and mask[i][j]:
                label += 1
                dfs(mask, i, j, labels, label)

    return labels


# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
def bbox(img):
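    """Return the (slightly padded) bounding box rmin, rmax, cmin, cmax of the
    non-zero pixels in `img`."""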
    max_x, max_y = img.shape
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]

    rmin = rmin - 1 if rmin > 0 else rmin
    cmin = cmin - 1 if cmin > 0 else cmin
    rmax = rmax + 1 if rmax < max_x else rmax
    cmax = cmax + 1 if cmax < max_y else cmax

    return rmin, rmax, cmin, cmax


def extract_single_masks(labels: npt.NDArray):
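    """Crop each label value of `labels` to its bounding box and return the
    binary crops as a list indexed by label value."""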
    masks = []
    for l in range(labels.max() + 1):
        mask = (labels == l).astype(np.int8)
        rmin, rmax, cmin, cmax = bbox(mask)
        masks.append(mask[rmin : rmax + 1, cmin : cmax + 1])

    return masks


class VAE(nn.Module):
    """
    https://github.com/pytorch/examples/blob/main/vae/main.py
    """

    def __init__(self, bottleneck=2, image_dim=4096):
        super(VAE, self).__init__()

        self.bottleneck = bottleneck
        self.image_dim = image_dim

        self.prelim_encode = nn.Sequential(
            nn.Flatten(), nn.Linear(image_dim, 400), nn.ReLU()
        )
        self.encode_mu = nn.Sequential(nn.Linear(400, bottleneck))
        self.encode_logvar = nn.Sequential(nn.Linear(400, bottleneck))

        self.decode = nn.Sequential(
            nn.Linear(bottleneck, 400),
            nn.ReLU(),
            nn.Linear(400, image_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        # h1 = F.relu(self.encode(x))
        # return self.encode_mu(h1), self.encode_logvar(h1)
        x = self.prelim_encode(x)
        return self.encode_mu(x), self.encode_logvar(x)

    def reparameterize(self, mu, logvar):
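        # reparameterization trick: z = mu + std * eps, std = exp(0.5 * logvar), eps ~ N(0, I)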
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss(self, recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(
            recon_x, x.view(-1, self.image_dim), reduction="sum"
        )

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD


class CONVVAE(nn.Module):
    def __init__(
        self,
        bottleneck=2,
    ):
        super(CONVVAE, self).__init__()

        self.bottleneck = bottleneck
        self.feature_dim = 6 * 6 * 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 30x30x16
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 14x14x32
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 6x6x64
        )
        # self.conv4 = nn.Sequential(
        #     nn.Conv2d(32, self.bottleneck, 5),
        #     nn.ReLU(),
        #     nn.MaxPool2d((2, 2), return_indices=True),  # -> 1x1xbottleneck
        # )

        self.encode_mu = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.feature_dim, self.bottleneck),
        )
        self.encode_logvar = nn.Sequential(
            nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
        )

        self.decode_linear = nn.Linear(self.bottleneck, self.feature_dim)

        # self.decode4 = nn.Sequential(
        #     nn.ConvTranspose2d(self.bottleneck, 32, 5),
        #     nn.ReLU(),
        # )
        self.decode3 = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 2, stride=2),  # -> 12x12x64
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3),  # -> 14x14x32
            nn.ReLU(),
        )
        self.decode2 = nn.Sequential(
            nn.ConvTranspose2d(32, 32, 2, stride=2),  # -> 28x28x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3),  # -> 30x30x16
            nn.ReLU(),
        )
        self.decode1 = nn.Sequential(
            nn.ConvTranspose2d(16, 16, 2, stride=2),  # -> 60x60x16
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 5),  # -> 64x64x1
            nn.Sigmoid(),
        )

    def encode(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # x, idx4 = self.conv4(x)
        mu = self.encode_mu(x)
        logvar = self.encode_logvar(x)

        return mu, logvar

    def decode(self, z: Tensor):
        z = self.decode_linear(z)
        z = z.view((-1, 64, 6, 6))
        # z = F.max_unpool2d(z, idx4, (2, 2))
        # z = self.decode4(z)
        z = self.decode3(z)
        z = self.decode2(z)
        z = self.decode1(z)
        # z = z.view(-1, 128, 1, 1)
        # return self.decode_conv(z)
        return z

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def loss(self, recon_x, x, mu, logvar):
        """https://github.com/pytorch/examples/blob/main/vae/main.py"""
        BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD


class CloseTransform:
    kernel = torch.ones(5, 5)

    def __init__(self, kernel=None):
        if kernel is not None:
            self.kernel = kernel

    def __call__(self, x):
        x = transforms.F.to_tensor(x)

        if len(x.shape) < 4:
            return transforms.F.to_pil_image(
                closing(x.unsqueeze(dim=0), self.kernel).squeeze(dim=0)
            )

        return transforms.F.to_pil_image(closing(x, self.kernel))


def load_data():
    transform = transforms.Compose(
        [
            transforms.Grayscale(),
            transforms.RandomApply([CloseTransform()], p=0.25),
            transforms.Resize(
                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    )

    trajectories = [
        # "v3_subtle_iceberg_lettuce_nymph-6_203-2056",
        "v3_absolute_grape_changeling-16_2277-4441",
        "v3_content_squash_angel-3_16074-17640",
        "v3_smooth_kale_loch_ness_monster-1_4439-6272",
        "v3_cute_breadfruit_spirit-6_17090-19102",
        "v3_key_nectarine_spirit-2_7081-9747",
        "v3_subtle_iceberg_lettuce_nymph-6_3819-6049",
        "v3_juvenile_apple_angel-30_396415-398113",
        "v3_subtle_iceberg_lettuce_nymph-6_6100-8068",
    ]

    datasets = []
    for trj in trajectories:
        datasets.append(
            ImageFolder(
                f"activation_vis/out/critic/masks/{trj}/0/4", transform=transform
            )
        )

    datasets.append(
        ImageFolder("shape_complexity/data/simple_shapes", transform=transform)
    )

    dataset = torch.utils.data.ConcatDataset(datasets)

    data_loader = DataLoader(dataset, batch_size=128, shuffle=True)

    return data_loader, dataset


def train(epoch, model: Union[VAE, CONVVAE], optimizer, data_loader, log_interval=40):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(data_loader):
        data = data.to(device)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(data_loader.dataset),
                    100.0 * batch_idx / len(data_loader),
                    loss.item() / len(data),
                )
            )

    print(
        "====> Epoch: {} Average loss: {:.4f}".format(
            epoch, train_loss / len(data_loader.dataset)
        )
    )


def test(epoch, models: list[Union[CONVVAE, VAE]], dataset, save_results=False):
    for model in models:
        model.eval()
    test_loss = [0 for _ in models]

    test_batch_size = 32
    sampler = RandomSampler(dataset, replacement=True, num_samples=64)
    test_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=sampler)
    comp_data = None

    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)

            for j, model in enumerate(models):
                recon_batch, mu, logvar = model(data)
                test_loss[j] += model.loss(recon_batch, data, mu, logvar).item()

                if i == 0:
                    n = min(data.size(0), 20)
                    if comp_data is None:
                        comp_data = data[:n]
                    comp_data = torch.cat(
                        [comp_data, recon_batch.view(test_batch_size, 1, 64, 64)[:n]]
                    )

            if i == 0 and save_results:
                if not os.path.exists("results"):
                    os.makedirs("results")
                save_image(
                    comp_data.cpu(),
                    "results/reconstruction_" + str(epoch) + ".png",
                    nrow=min(data.size(0), 20),
                )

    for i, model in enumerate(models):
        test_loss[i] /= len(test_loader.dataset)
        print(f"====> Test set loss model {model.bottleneck}: {test_loss[i]:.4f}")

    if save_results:
        plt.figure()
        bns = [m.bottleneck for m in models]
        plt.plot(bns, test_loss)
        plt.xticks(bns)
        plt.savefig("shape_complexity/results/vae_graph.png")
        plt.clf()


def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
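    """Extract the connected component `label` from the mask image at `path`,
    reconstruct it with `model`, and return precision, recall (reconstruction
    thresholded at `epsilon`) and a comparison tensor."""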
    model.eval()
    image = transforms.F.to_tensor(transforms.F.to_grayscale(Image.open(path)))
    labels = find_components(image[0])
    single_masks = extract_single_masks(labels)
    mask = transforms.F.to_tensor(
        transforms.F.resize(
            transforms.F.to_pil_image((single_masks[label] * 255).astype(np.uint8)),
            (64, 64),
        )
    )

    with torch.no_grad():
        mask = mask.to(device)
        recon_x, _, _ = model(mask)
        recon_bits = recon_x.view(64, 64).cpu().numpy() > epsilon
        mask_bits = mask.cpu().numpy() > 0

        TP = (mask_bits & recon_bits).sum()
        FP = (recon_bits & ~mask_bits).sum()
        FN = (mask_bits & ~recon_bits).sum()

        prec = TP / (TP + FP)
        rec = TP / (TP + FN)
        # loss = pixelwise_loss(recon_x, mask)
        comp_data = torch.cat(
            [
                mask[0].cpu(),
                recon_x.view(64, 64).cpu(),
                torch.from_numpy(recon_bits).float(),
            ]
        )
        # print(f"mask loss: {loss:.4f}")

        return prec, rec, comp_data


def distance_measure(model: VAE, img: Tensor):
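    """Encode `img`, reconstruct it, re-encode the reconstruction, and return the
    L2 distance between the two latent means plus an input/reconstruction grid."""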
    model.eval()

    with torch.no_grad():
        mask = img.to(device)
        recon, mean, _ = model(mask)
        # TODO: apply threshold here?!
        _, recon_mean, _ = model(recon)

        distance = torch.norm(mean - recon_mean, p=2)

        return (
            distance,
            make_grid(torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0),
        )


def compression_complexity(img: Tensor):
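    """Complexity proxy: length of the bz2-compressed mask array."""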
    np_img = img[0].numpy()
    compressed = compress(np_img)

    return len(compressed)


def fft_measure(img: Tensor):
    """Complexity proxy: mean spatial frequency of the mask spectrum,
    normalised by the Nyquist frequency (0.5)."""
    np_img = img[0][0].numpy()
    fft = np.fft.fft2(np_img)

    fft_abs = np.abs(fft)

    n = fft.shape[0]
    pos_f_idx = n // 2
    df = np.fft.fftfreq(n=n)

    amplitude_sum = fft_abs[:pos_f_idx, :pos_f_idx].sum()
    mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
    mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum

    mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2))

    # mean frequency in range 0 to 0.5
    return mean_freq / 0.5


def complexity_measure(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
    epsilon=0.4,
    save_preliminary=False,
):
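    """Reconstruct the mask with a large- and a small-bottleneck VAE, threshold
    both reconstructions at `epsilon`, and derive complexity scores from the
    resulting precisions."""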
    model_gb.eval()
    model_lb.eval()

    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        recon_bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
        recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
        mask_bits = mask[0].cpu() > 0

        if save_preliminary:
            save_image(
                torch.stack(
                    [mask_bits.float(), recon_bits_gb.float(), recon_bits_lb.float()]
                ).cpu(),
                f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )
            save_image(
                torch.stack(
                    [
                        (mask_bits & recon_bits_gb).float(),
                        (recon_bits_gb & ~mask_bits).float(),
                        (mask_bits & recon_bits_lb).float(),
                        (recon_bits_lb & ~mask_bits).float(),
                    ]
                ).cpu(),
                f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )

        tp_gb = (mask_bits & recon_bits_gb).sum()
        fp_gb = (recon_bits_gb & ~mask_bits).sum()
        tp_lb = (mask_bits & recon_bits_lb).sum()
        fp_lb = (recon_bits_lb & ~mask_bits).sum()

        prec_gb = tp_gb / (tp_gb + fp_gb)
        prec_lb = tp_lb / (tp_lb + fp_lb)
        complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
        complexity_lb = 1 - prec_lb
        complexity_gb = 1 - prec_gb
        # example: prec_gb = 0.4, prec_lb = 0.7
        #   complexity    = 1 - (0.4 - abs(0.4 - 0.7)) = 0.9
        #   complexity_lb = 1 - 0.7                    = 0.3

        return (
            complexity,
            complexity_lb,
            complexity_gb,
            prec_gb - prec_lb,
            prec_lb,
            prec_gb,
            make_grid(
                torch.stack(
                    [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
                ).cpu(),
                nrow=3,
                padding=0,
            ),
        )


def mean_precision(models: list[nn.Module], img: Tensor, epsilon=0.4):
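    """Return 1 minus the mean reconstruction precision of `img` over all
    `models` (reconstructions thresholded at `epsilon`)."""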
    mask = img.to(device)
    mask_bits = mask[0].cpu() > 0

    precisions = np.zeros(len(models))

    for i, model in enumerate(models):
        recon_gb, _, _ = model(mask)

        recon_bits = recon_gb.view(-1, 64, 64).cpu() > epsilon

        tp = (mask_bits & recon_bits).sum()
        fp = (recon_bits & ~mask_bits).sum()

        prec = tp / (tp + fp)
        precisions[i] = prec

    return 1 - precisions.mean()


def complexity_measure_diff(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
):
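    """Complexity as the absolute value of the summed pixelwise difference
    between the two models' reconstructions, plus a comparison grid."""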
    model_gb.eval()
    model_lb.eval()

    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)

        diff = torch.abs((recon_gb - recon_lb).cpu().sum())

        return (
            diff,
            make_grid(
                torch.stack(
                    [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
                ).cpu(),
                nrow=3,
                padding=0,
            ),
        )


def plot_samples(masks: Tensor, complexities: npt.NDArray):
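    """Plot a 20x20 grid of masks annotated with their complexity values and
    return the figure (expects exactly 400 samples)."""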
    dpi = 150
    rows = cols = 20
    total = rows * cols
    n_samples, _, y, x = masks.shape

    extent = (0, x - 1, 0, y - 1)

    if total != n_samples:
        raise Exception("shape mismatch")

    fig = plt.figure(figsize=(32, 16), dpi=dpi)
    for idx in np.arange(n_samples):
        ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])

        plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
        ax.set_title(
            f"{complexities[idx]:.4f}",
            fontdict={"fontsize": 6, "color": "orange"},
            y=0.35,
        )

    fig.patch.set_facecolor("#292929")
    height_px = y * rows
    width_px = x * cols
    fig.set_size_inches(width_px / (dpi / 2), height_px / (dpi / 2), forward=True)
    fig.tight_layout(pad=0)

    return fig


def visualize_sort_mean(data_loader: DataLoader, model: VAE):
    recon_masks = torch.zeros((400, 3, 64, 128))
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        distance, mask_recon_grid = distance_measure(model, mask)
        masks[i] = mask[0]
        recon_masks[i] = mask_recon_grid
        distances[i] = distance

    sort_idx = torch.argsort(distances)
    recon_masks_sorted = recon_masks.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("latent mean L2 distance")
    plt.savefig("shape_complexity/results/distance_plot.png")

    return (
        plot_samples(masks_sorted, distances.numpy()[sort_idx]),
        plot_samples(recon_masks_sorted, distances.numpy()[sort_idx]),
    )


def visualize_sort_compression(data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = compression_complexity(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("compression length")
    plt.savefig("shape_complexity/results/compression_plot.png")

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])


def visualize_sort_fft(data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = fft_measure(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("mean unidirectional frequency")
    plt.savefig("shape_complexity/results/fft_plot.png")

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])


def visualize_sort_mean_precision(models: list[nn.Module], data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    precisions = torch.zeros((400,))

    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        precisions[i] = mean_precision(models, mask)

    sort_idx = torch.argsort(precisions)
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(precisions)), np.sort(precisions.numpy()))
    plt.xlabel("images")
    plt.ylabel("mean precision")
    plt.savefig("shape_complexity/results/mean_prec_plot.png")

    return plot_samples(masks_sorted, precisions.numpy()[sort_idx])


def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    diffs = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        diff, mask_recon_grid = complexity_measure_diff(model_gb, model_lb, mask)
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff

    sort_idx = np.argsort(np.array(diffs))
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(diffs)), np.sort(diffs))
    plt.xlabel("images")
    plt.ylabel("pixelwise difference of reconstructions")
    plt.savefig("shape_complexity/results/px_diff_plot.png")

    return (
        plot_samples(masks_sorted, diffs[sort_idx]),
        plot_samples(recon_masks_sorted, diffs[sort_idx]),
    )


def visualize_sort_3dim(
    data_loader: DataLoader, model_gb: nn.Module, model_lb: nn.Module
):
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    measures = torch.zeros((400, 3))
    for i, (mask, _) in enumerate(data_loader, 0):
        c_compress = compression_complexity(mask)
        c_fft = fft_measure(mask)
        # TODO: maybe exchange by diff or mean measure instead of precision
        c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask
        )
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        measures[i] = torch.tensor([c_compress, c_fft, c_vae])

    measures[:] /= measures.max(dim=0).values
    measure_norm = torch.linalg.vector_norm(measures, dim=1)

    fig = plt.figure()
    fig.clf()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(measures[:, 0], measures[:, 1], measures[:, 2], marker="o")

    ax.set_xlabel("bz2 compression length")
    ax.set_ylabel("FFT ratio")
    ax.set_zlabel(f"VAE ratio {model_gb.bottleneck}/{model_lb.bottleneck}")
    plt.savefig("shape_complexity/results/3d_plot.png")
    plt.close()

    sort_idx = np.argsort(np.array(measure_norm))
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    return (
        plot_samples(masks_sorted, measure_norm[sort_idx]),
        plot_samples(recon_masks_sorted, measure_norm[sort_idx]),
    )


def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    diffs = []
    for i, (mask, _) in enumerate(data_loader, 0):
        complexity, _, _, diff, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask, save_preliminary=True
        )
        masks[i] = mask_recon_grid
        diffs.append(diff)
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(diffs)), np.sort(diffs))
    plt.xlabel("images")
    plt.ylabel("prec difference (H-L)")
    plt.savefig("shape_complexity/results/diff_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, complexities[sort_idx])


def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    complexities_lb = torch.zeros((400,))
    complexities_gb = torch.zeros((400,))
    diffs = []
    prec_lbs = []
    prec_gbs = []
    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            lb,
            gb,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        masks[i] = mask_recon_grid
        diffs.append(diff)
        prec_lbs.append(prec_lb)
        prec_gbs.append(prec_gb)
        complexities[i] = complexity
        complexities_lb[i] = lb
        complexities_gb[i] = gb

    sort_idx = np.argsort(np.array(complexities))
    sort_idx_lb = np.argsort(np.array(complexities_lb))
    sort_idx_gb = np.argsort(np.array(complexities_gb))
    masks_sorted = masks.numpy()[sort_idx]
    masks_sorted_lb = masks.numpy()[sort_idx_lb]
    masks_sorted_gb = masks.numpy()[sort_idx_gb]

    diff_sort_idx = np.argsort(diffs)
    # plt.savefig("shape_complexity/results/diff_plot.png")
    # plt.clf()

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(
        np.arange(len(prec_lbs)),
        np.array(prec_lbs)[diff_sort_idx],
        label=f"bottleneck {model_lb.bottleneck}",
    )
    ax1.plot(
        np.arange(len(prec_gbs)),
        np.array(prec_gbs)[diff_sort_idx],
        label=f"bottleneck {model_gb.bottleneck}",
    )
    ax2.plot(
        np.arange(len(diffs)),
        np.sort(diffs),
        color="red",
        label="prec difference (H - L)",
    )
    ax1.legend(loc="lower left")
    ax2.legend(loc="lower right")
    ax1.set_ylabel("precision")
    ax2.set_ylabel("prec difference (H-L)")
    plt.savefig("shape_complexity/results/prec_plot.png")
    plt.clf()

    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted_lb, complexities_lb[sort_idx_lb])
    fig.savefig("shape_complexity/results/lb.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted_gb, complexities_gb[sort_idx_gb])
    fig.savefig("shape_complexity/results/gb.png")
    plt.close(fig)


def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    recon_masks = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    complexities = torch.zeros((400,))
    diffs = np.zeros((400,))
    prec_gbs = np.zeros((400,))
    prec_lbs = np.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            _,
            _,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        recon_masks[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff
        prec_gbs[i] = prec_gb
        prec_lbs[i] = prec_lb
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]
    recon_masks_sorted = recon_masks.numpy()[sort_idx]

    # group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
    # bin_edges = [-np.inf, 0.0, 0.05, np.inf]
    # bins = np.digitize(diffs, bins=bin_edges, right=True)

    # for i in range(bins.min(), bins.max() + 1):
    #     bin_idx = bins == i
    #     binned_prec_gb = prec_gbs[bin_idx]
    #     prec_mean = binned_prec_gb.mean()
    #     prec_idx = prec_gbs > prec_mean

    #     binned_masks_high = recon_masks[bin_idx & prec_idx]
    #     binned_masks_low = recon_masks[bin_idx & ~prec_idx]

    #     save_image(
    #         binned_masks_high,
    #         f"shape_complexity/results/diff_{group_labels[i-1]}_high.png",
    #         padding=10,
    #     )
    #     save_image(
    #         binned_masks_low,
    #         f"shape_complexity/results/diff_{group_labels[i-1]}_low.png",
    #         padding=10,
    #     )

    # diff_sort_idx = np.argsort(diffs)

    # fig, ax1 = plt.subplots()
    # ax2 = ax1.twinx()
    # ax1.plot(
    #     np.arange(len(prec_lbs)),
    #     np.array(prec_lbs)[diff_sort_idx],
    #     label=f"bottleneck {model_lb.bottleneck}",
    # )
    # ax1.plot(
    #     np.arange(len(prec_gbs)),
    #     np.array(prec_gbs)[diff_sort_idx],
    #     label=f"bottleneck {model_gb.bottleneck}",
    # )
    # ax2.plot(
    #     np.arange(len(diffs)),
    #     np.sort(diffs),
    #     color="red",
    #     label="prec difference (H - L)",
    # )
    # ax1.legend(loc="lower left")
    # ax2.legend(loc="lower right")
    # ax1.set_ylabel("precision")
    # ax2.set_ylabel("prec difference (H-L)")
    # ax1.set_xlabel("image")
    # plt.savefig("shape_complexity/results/prec_plot.png")
    # plt.tight_layout(pad=2)
    # plt.clf()

    fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs_recon.png")
    plt.close(fig)

    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)


EPOCHS = 10
LOAD_PRETRAINED = True
LR = 1e-3  # assumed value: the original learning-rate constant was lost in extraction

if __name__ == "__main__":  # assumed entry point; the original guard line was lost
    bottlenecks = [4, 8, 16, 32]
    models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
    optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}

    data_loader, dataset = load_data()

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

    if LOAD_PRETRAINED:
        for i, model in models.items():
            model.load_state_dict(
                torch.load(f"shape_complexity/trained/CONVVAE_{i}_split_data.pth")
            )
    else:
        for epoch in range(EPOCHS):
            for i, model in models.items():
                train(
                    epoch,
                    model=model,
                    optimizer=optimizers[i],
                    data_loader=train_loader,
                )

            test(epoch, models=list(models.values()), dataset=test_dataset)

        for bn in bottlenecks:
            if not os.path.exists("shape_complexity/trained"):
                os.makedirs("shape_complexity/trained")

            torch.save(
                models[bn].state_dict(),
                f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth",
            )

    test(0, models=list(models.values()), dataset=test_dataset, save_results=True)

    bn_gt = 32
    bn_lt = 8

    # for i in range(10):
    #     figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
    #     figure.savefig(
    #         f"shape_complexity/results/this_{bn_gt}_to_{bn_lt}_sample{i}.png"
    #     )
    #     figure.clear()
    #     plt.close(figure)

    # figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
    # figure.savefig(f"shape_complexity/results/sort_{bn_gt}_to_{bn_lt}.png")

    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

    visualize_sort_group(data_loader, models[bn_gt], models[bn_lt])
    # visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])
    fig, fig_recon = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt])
    fig.savefig(f"shape_complexity/results/sort_comp_fft_prec.png")
    fig_recon.savefig(f"shape_complexity/results/recon_sort_comp_fft_prec.png")
    plt.close(fig)
    plt.close(fig_recon)
    fig = visualize_sort_mean_precision(list(models.values()), data_loader)
    fig.savefig(f"shape_complexity/results/sort_mean_prec.png")
    plt.close(fig)
    fig = visualize_sort_fft(data_loader)
    fig.savefig(f"shape_complexity/results/sort_fft.png")
    plt.close(fig)
    fig = visualize_sort_compression(data_loader)
    fig.savefig(f"shape_complexity/results/sort_compression.png")
    plt.close(fig)
    fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
    fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
    fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")
    plt.close(fig)
    plt.close(fig_recon)
    # fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt])
    # fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png")
    # fig_recon.savefig(
    #     f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png"
    # )
    # plt.close(fig)
    # plt.close(fig_recon)