Skip to content
Snippets Groups Projects
shape_complexity.py 32.6 KiB
Newer Older
  • Learn to ignore specific revisions
  • Markus Rothgänger's avatar
    Markus Rothgänger committed
# from zlib import compress
import os
from bz2 import compress

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
import torch.nn.functional as F
from kornia.morphology import closing
from PIL import Image
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.utils import make_grid, save_image
    
    
    device = torch.device("cuda")
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    matplotlib.use("Agg")
    
    
    dx = [+1, 0, -1, 0]
    dy = [0, +1, 0, -1]
    
    
    # perform depth first search for each candidate/unlabeled region
    # reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation
    def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int):
        n_rows, n_cols = mask.shape
        if x < 0 or x == n_rows:
            return
        if y < 0 or y == n_cols:
            return
        if labels[x][y] or not mask[x][y]:
            return  # already labeled or not marked with 1 in image
    
        # mark the current cell
        labels[x][y] = current_label
    
        # recursively mark the neighbors
        for direction in range(4):
            dfs(mask, x + dx[direction], y + dy[direction], labels, current_label)
    
    
    def find_components(mask: npt.NDArray):
        label = 0
    
        n_rows, n_cols = mask.shape
        labels = np.zeros(mask.shape, dtype=np.int8)
    
        for i in range(n_rows):
            for j in range(n_cols):
                if not labels[i][j] and mask[i][j]:
                    label += 1
                    dfs(mask, i, j, labels, label)
    
        return labels
    
    
    # https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
    def bbox(img):
        max_x, max_y = img.shape
        rows = np.any(img, axis=1)
        cols = np.any(img, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]
    
        rmin = rmin - 1 if rmin > 0 else rmin
        cmin = cmin - 1 if cmin > 0 else cmin
        rmax = rmax + 1 if rmax < max_x else rmax
        cmax = cmax + 1 if cmax < max_y else cmax
    
        return rmin, rmax, cmin, cmax
    
    
def extract_single_masks(labels: npt.NDArray):
    """Crop each labeled region of ``labels`` to its (padded) bounding box.

    Returns a list of 0/1 int8 arrays, one per label value in
    ``0..labels.max()``.  NOTE(review): label 0 is the background produced by
    ``find_components`` — including it appears intentional so that the list
    index matches the label id used by callers (``single_masks[label]`` in
    ``test_mask``); confirm before narrowing the range to start at 1.
    """
    masks = []
    for l in range(labels.max() + 1):
        # binary mask of the current label, cropped to its padded bbox
        mask = (labels == l).astype(np.int8)
        rmin, rmax, cmin, cmax = bbox(mask)
        masks.append(mask[rmin : rmax + 1, cmin : cmax + 1])

    return masks
    
    
    class VAE(nn.Module):
        """
        https://github.com/pytorch/examples/blob/main/vae/main.py
        """
    
        def __init__(self, bottleneck=2, image_dim=4096):
            super(VAE, self).__init__()
    
            self.bottleneck = bottleneck
            self.image_dim = image_dim
    
            self.prelim_encode = nn.Sequential(
                nn.Flatten(), nn.Linear(image_dim, 400), nn.ReLU()
            )
            self.encode_mu = nn.Sequential(nn.Linear(400, bottleneck))
            self.encode_logvar = nn.Sequential(nn.Linear(400, bottleneck))
    
            self.decode = nn.Sequential(
                nn.Linear(bottleneck, 400),
                nn.ReLU(),
                nn.Linear(400, image_dim),
                nn.Sigmoid(),
            )
    
        def encode(self, x):
            # h1 = F.relu(self.encode(x))
            # return self.encode_mu(h1), self.encode_logvar(h1)
            x = self.prelim_encode(x)
            return self.encode_mu(x), self.encode_logvar(x)
    
        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
    
        def forward(self, x):
            mu, logvar = self.encode(x)
            z = self.reparameterize(mu, logvar)
            return self.decode(z), mu, logvar
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        # Reconstruction + KL divergence losses summed over all elements and batch
        def loss(self, recon_x, x, mu, logvar):
            BCE = F.binary_cross_entropy(recon_x, x.view(-1, 4096), reduction="sum")
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            # see Appendix B from VAE paper:
            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
            # https://arxiv.org/abs/1312.6114
            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            return BCE + KLD
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    class CONVVAE(nn.Module):
        def __init__(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            self,
            bottleneck=2,
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        ):
            super(CONVVAE, self).__init__()
    
            self.bottleneck = bottleneck
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            self.feature_dim = 6 * 6 * 64
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            self.conv1 = nn.Sequential(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.Conv2d(1, 16, 5),
                nn.ReLU(),
                nn.MaxPool2d((2, 2)),  # -> 30x30x16
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            )
            self.conv2 = nn.Sequential(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.Conv2d(16, 32, 3),
                nn.ReLU(),
                nn.MaxPool2d((2, 2)),  # -> 14x14x32
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            )
            self.conv3 = nn.Sequential(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.Conv2d(32, 64, 3),
                nn.ReLU(),
                nn.MaxPool2d((2, 2)),  # -> 6x6x64
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            )
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            # self.conv4 = nn.Sequential(
            #     nn.Conv2d(32, self.bottleneck, 5),
            #     nn.ReLU(),
            #     nn.MaxPool2d((2, 2), return_indices=True),  # -> 1x1xbottleneck
            # )
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            self.encode_mu = nn.Sequential(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.Flatten(),
                nn.Linear(self.feature_dim, self.bottleneck),
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            )
            self.encode_logvar = nn.Sequential(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            )
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            self.decode_linear = nn.Linear(self.bottleneck, self.feature_dim)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            # self.decode4 = nn.Sequential(
            #     nn.ConvTranspose2d(self.bottleneck, 32, 5),
            #     nn.ReLU(),
            # )
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            self.decode3 = nn.Sequential(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.ConvTranspose2d(64, 64, 2, stride=2),  # -> 12x12x64
                nn.ReLU(),
                nn.ConvTranspose2d(64, 32, 3),  # -> 14x14x32
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.ReLU(),
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            )
            self.decode2 = nn.Sequential(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.ConvTranspose2d(32, 32, 2, stride=2),  # -> 28x28x32
                nn.ReLU(),
                nn.ConvTranspose2d(32, 16, 3),  # -> 30x30x16
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.ReLU(),
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            )
            self.decode1 = nn.Sequential(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.ConvTranspose2d(16, 16, 2, stride=2),  # -> 60x60x16
                nn.ReLU(),
                nn.ConvTranspose2d(16, 1, 5),  # -> 64x64x1
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                nn.Sigmoid(),
            )
    
        def encode(self, x):
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            x = self.conv1(x)
            x = self.conv2(x)
            x = self.conv3(x)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            # x, idx4 = self.conv4(x)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            mu = self.encode_mu(x)
            logvar = self.encode_logvar(x)
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            return mu, logvar
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        def decode(self, z: Tensor):
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            z = self.decode_linear(z)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            z = z.view((-1, 64, 6, 6))
            # z = F.max_unpool2d(z, idx4, (2, 2))
            # z = self.decode4(z)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            z = self.decode3(z)
            z = self.decode2(z)
            z = self.decode1(z)
            # z = z.view(-1, 128, 1, 1)
            # return self.decode_conv(z)
            return z
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
        def reparameterize(self, mu, logvar):
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
    
        def forward(self, x):
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            mu, logvar = self.encode(x)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            z = self.reparameterize(mu, logvar)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            return self.decode(z), mu, logvar
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
        def loss(self, recon_x, x, mu, logvar):
            """https://github.com/pytorch/examples/blob/main/vae/main.py"""
            BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")
    
            # see Appendix B from VAE paper:
            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
            # https://arxiv.org/abs/1312.6114
            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
            return BCE + KLD
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    class CloseTransform:
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        kernel = torch.ones(5, 5)
    
        def __init__(self, kernel=None):
            if kernel is not None:
                self.kernel = kernel
    
        def __call__(self, x):
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            x = transforms.F.to_tensor(x)
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            if len(x.shape) < 4:
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                return transforms.F.to_pil_image(
                    closing(x.unsqueeze(dim=0), self.kernel).squeeze(dim=0)
                )
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            return transforms.F.to_pil_image(closing(x, self.kernel))
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    
    
    def load_data():
        transform = transforms.Compose(
            [
                transforms.Grayscale(),
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                transforms.RandomApply([CloseTransform()], p=0.25),
    
                transforms.Resize(
                    (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
                ),
                transforms.RandomHorizontalFlip(),
                transforms.RandomVerticalFlip(),
                transforms.ToTensor(),
            ]
        )
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        trajectories = (
            []
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            if False
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            else [
                # "v3_subtle_iceberg_lettuce_nymph-6_203-2056",
                "v3_absolute_grape_changeling-16_2277-4441",
                "v3_content_squash_angel-3_16074-17640",
                "v3_smooth_kale_loch_ness_monster-1_4439-6272",
                "v3_cute_breadfruit_spirit-6_17090-19102",
                "v3_key_nectarine_spirit-2_7081-9747",
                "v3_subtle_iceberg_lettuce_nymph-6_3819-6049",
                "v3_juvenile_apple_angel-30_396415-398113",
                "v3_subtle_iceberg_lettuce_nymph-6_6100-8068",
            ]
        )
    
    
        datasets = []
        for trj in trajectories:
            datasets.append(
                ImageFolder(
                    f"activation_vis/out/critic/masks/{trj}/0/4", transform=transform
                )
            )
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        datasets.append(
            ImageFolder("shape_complexity/data/simple_shapes", transform=transform)
        )
    
    
        dataset = torch.utils.data.ConcatDataset(datasets)
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def train(epoch, model: VAE or CONVVAE, optimizer, data_loader, log_interval=40):
    
        model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(data_loader):
            data = data.to(device)
    
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            loss = model.loss(recon_batch, data, mu, logvar)
    
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
    
            if batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(data_loader.dataset),
                        100.0 * batch_idx / len(data_loader),
                        loss.item() / len(data),
                    )
                )
    
        print(
            "====> Epoch: {} Average loss: {:.4f}".format(
                epoch, train_loss / len(data_loader.dataset)
            )
        )
    
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def test(epoch, models: list[CONVVAE] or list[VAE], dataset, save_results=False):
    
        for model in models:
            model.eval()
        test_loss = [0 for _ in models]
    
        test_batch_size = 32
        sampler = RandomSampler(dataset, replacement=True, num_samples=64)
        test_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=sampler)
        comp_data = None
    
        with torch.no_grad():
            for i, (data, _) in enumerate(test_loader):
                data = data.to(device)
    
                for j, model in enumerate(models):
                    recon_batch, mu, logvar = model(data)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                    test_loss[j] += model.loss(recon_batch, data, mu, logvar).item()
    
    
                    if i == 0:
                        n = min(data.size(0), 20)
                        if comp_data == None:
                            comp_data = data[:n]
                        comp_data = torch.cat(
                            [comp_data, recon_batch.view(test_batch_size, 1, 64, 64)[:n]]
                        )
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
                if i == 0 and save_results:
    
                    if not os.path.exists("results"):
                        os.makedirs("results")
                    save_image(
                        comp_data.cpu(),
                        "results/reconstruction_" + str(epoch) + ".png",
                        nrow=min(data.size(0), 20),
                    )
    
        for i, model in enumerate(models):
            test_loss[i] /= len(test_loader.dataset)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            print(f"====> Test set loss model {model.bottleneck}: {test_loss[i]:.4f}")
    
        if save_results:
            plt.figure()
            bns = [m.bottleneck for m in models]
            plt.plot(bns, test_loss)
            plt.xticks(bns)
            plt.savefig("shape_complexity/results/vae_graph.png")
            plt.clf()
    
    
    
def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
    """Reconstruct one connected component of the image at ``path`` and
    report pixelwise precision and recall of the reconstruction.

    NOTE(review): `Image` (PIL) is used here but not imported at the top of
    this file — confirm the import exists elsewhere.

    Args:
        model: VAE variant returning ``(recon, mu, logvar)``.
        path: image file whose components are extracted.
        label: index into the extracted component list (0 is background as
            produced by ``extract_single_masks``).
        epsilon: threshold for binarizing the reconstruction.

    Returns:
        (precision, recall, comparison tensor of mask/recon/binarized recon).
    """
    model.eval()
    image = transforms.F.to_tensor(transforms.F.to_grayscale(Image.open(path)))
    labels = find_components(image[0])
    single_masks = extract_single_masks(labels)
    # scale the selected 0/1 component to 0..255, resize to the model input
    # size, and convert back to a float tensor
    mask = transforms.F.to_tensor(
        transforms.F.resize(
            transforms.F.to_pil_image((single_masks[label] * 255).astype(np.uint8)),
            (64, 64),
        )
    )

    with torch.no_grad():
        mask = mask.to(device)
        recon_x, _, _ = model(mask)
        # binarize: reconstruction at epsilon, input at zero
        recon_bits = recon_x.view(64, 64).cpu().numpy() > epsilon
        mask_bits = mask.cpu().numpy() > 0

        TP = (mask_bits & recon_bits).sum()
        FP = (recon_bits & ~mask_bits).sum()
        FN = (mask_bits & ~recon_bits).sum()

        # NOTE(review): divides by zero if the binarized reconstruction (or
        # the mask) is empty — presumably never happens in practice; verify
        prec = TP / (TP + FP)
        rec = TP / (TP + FN)
        # loss = pixelwise_loss(recon_x, mask)
        comp_data = torch.cat(
            [mask[0].cpu(), recon_x.view(64, 64).cpu(), torch.from_numpy(recon_bits)]
        )
        # print(f"mask loss: {loss:.4f}")

        return prec, rec, comp_data
    
    
    def distance_measure(model: VAE, img: Tensor):
        model.eval()
    
        with torch.no_grad():
            mask = img.to(device)
            recon, mean, _ = model(mask)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            # TODO: apply threshold here?!
            _, recon_mean, _ = model(recon)
    
    
            distance = torch.norm(mean - recon_mean, p=2)
    
    
            return (
                distance,
                make_grid(torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0),
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def compression_complexity(img: Tensor):
        np_img = img[0].numpy()
        compressed = compress(np_img)
    
        return len(compressed)
    
    
    def fft_measure(img: Tensor):
        np_img = img[0][0].numpy()
        fft = np.fft.fft2(np_img)
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        fft_abs = np.abs(fft)
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        n = fft.shape[0]
        pos_f_idx = n // 2
        df = np.fft.fftfreq(n=n)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        amplitude_sum = fft_abs[:pos_f_idx, :pos_f_idx].sum()
        mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
        mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2))
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
        # mean frequency in range 0 to 0.5
        return mean_freq / 0.5
    
    
    def complexity_measure(
        model_gb: nn.Module,
        model_lb: nn.Module,
        img: Tensor,
        epsilon=0.4,
        save_preliminary=False,
    ):
        model_gb.eval()
        model_lb.eval()
    
        with torch.no_grad():
            mask = img.to(device)
            recon_gb, _, _ = model_gb(mask)
            recon_lb, _, _ = model_lb(mask)
    
            recon_bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
            recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
            mask_bits = mask[0].cpu() > 0
    
            if save_preliminary:
                save_image(
                    torch.stack(
                        [mask_bits.float(), recon_bits_gb.float(), recon_bits_lb.float()]
                    ).cpu(),
                    f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
                )
                save_image(
                    torch.stack(
                        [
                            (mask_bits & recon_bits_gb).float(),
                            (recon_bits_gb & ~mask_bits).float(),
                            (mask_bits & recon_bits_lb).float(),
                            (recon_bits_lb & ~mask_bits).float(),
                        ]
                    ).cpu(),
                    f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
                )
    
            tp_gb = (mask_bits & recon_bits_gb).sum()
            fp_gb = (recon_bits_gb & ~mask_bits).sum()
            tp_lb = (mask_bits & recon_bits_lb).sum()
            fp_lb = (recon_bits_lb & ~mask_bits).sum()
    
            prec_gb = tp_gb / (tp_gb + fp_gb)
            prec_lb = tp_lb / (tp_lb + fp_lb)
            complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
            complexity_lb = 1 - prec_lb
            complexity_gb = 1 - prec_gb
            #           1 - (0.4 - abs(0.4 - 0.7))    = 0.9
            #           1 - 0.7                       = 0.3
    
            return (
                complexity,
                complexity_lb,
                complexity_gb,
                prec_gb - prec_lb,
                prec_lb,
                prec_gb,
                make_grid(
                    torch.stack(
                        [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
                    ).cpu(),
                    nrow=3,
                    padding=0,
                ),
            )
    
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def mean_precision(models: list[nn.Module], img: Tensor, epsilon=0.4):
        mask = img.to(device)
        mask_bits = mask[0].cpu() > 0
    
        precisions = np.zeros(len(models))
    
        for i, model in enumerate(models):
            recon_gb, _, _ = model(mask)
    
            recon_bits = recon_gb.view(-1, 64, 64).cpu() > epsilon
    
            tp = (mask_bits & recon_bits).sum()
            fp = (recon_bits & ~mask_bits).sum()
    
            prec = tp / (tp + fp)
            precisions[i] = prec
    
        return 1 - precisions.mean()
    
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def complexity_measure_diff(
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        model_gb: nn.Module,
        model_lb: nn.Module,
        img: Tensor,
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    ):
        model_gb.eval()
        model_lb.eval()
    
        with torch.no_grad():
            mask = img.to(device)
            recon_gb, _, _ = model_gb(mask)
            recon_lb, _, _ = model_lb(mask)
    
            diff = torch.abs((recon_gb - recon_lb).cpu().sum())
    
            return (
                diff,
                make_grid(
                    torch.stack(
                        [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
                    ).cpu(),
                    nrow=3,
                    padding=0,
                ),
            )
    
    
    
    def plot_samples(masks: Tensor, complexities: npt.NDArray):
        dpi = 150
        rows = cols = 20
        total = rows * cols
        n_samples, _, y, x = masks.shape
    
        extent = (0, x - 1, 0, y - 1)
    
        if total != n_samples:
            raise Exception("shape mismatch")
    
        fig = plt.figure(figsize=(32, 16), dpi=dpi)
        for idx in np.arange(n_samples):
            ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])
    
            plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
            ax.set_title(
                f"{complexities[idx]:.4f}",
                fontdict={"fontsize": 6, "color": "orange"},
                y=0.35,
            )
    
        fig.patch.set_facecolor("#292929")
        height_px = y * rows
        width_px = x * cols
        fig.set_size_inches(width_px / (dpi / 2), height_px / (dpi / 2), forward=True)
        fig.tight_layout(pad=0)
    
        return fig
    
    
    def visualize_sort_mean(data_loader: DataLoader, model: VAE):
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        recon_masks = torch.zeros((400, 3, 64, 128))
        masks = torch.zeros((400, 1, 64, 64))
    
        distances = torch.zeros((400,))
        for i, (mask, _) in enumerate(data_loader, 0):
            distance, mask_recon_grid = distance_measure(model, mask)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
            masks[i] = mask[0]
            recon_masks[i] = mask_recon_grid
    
            distances[i] = distance
    
        sort_idx = torch.argsort(distances)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        recon_masks_sorted = recon_masks.numpy()[sort_idx]
    
        masks_sorted = masks.numpy()[sort_idx]
    
        plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        plt.xlabel("images")
        plt.ylabel("latent mean L2 distance")
    
        plt.savefig("shape_complexity/results/distance_plot.png")
    
    
        return (
            plot_samples(masks_sorted, distances.numpy()[sort_idx]),
            plot_samples(recon_masks_sorted, distances.numpy()[sort_idx]),
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def visualize_sort_compression(data_loader: DataLoader):
        masks = torch.zeros((400, 1, 64, 64))
        distances = torch.zeros((400,))
        for i, (mask, _) in enumerate(data_loader, 0):
            masks[i] = mask[0]
            distances[i] = compression_complexity(mask)
    
        sort_idx = torch.argsort(distances)
        masks_sorted = masks.numpy()[sort_idx]
    
        plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
        plt.xlabel("images")
        plt.ylabel("compression length")
        plt.savefig("shape_complexity/results/compression_plot.png")
    
        return plot_samples(masks_sorted, distances.numpy()[sort_idx])
    
    
    def visualize_sort_fft(data_loader: DataLoader):
        masks = torch.zeros((400, 1, 64, 64))
        distances = torch.zeros((400,))
        for i, (mask, _) in enumerate(data_loader, 0):
            masks[i] = mask[0]
            distances[i] = fft_measure(mask)
    
        sort_idx = torch.argsort(distances)
        masks_sorted = masks.numpy()[sort_idx]
    
        plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
        plt.xlabel("images")
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        plt.ylabel("mean unidirectional frequency")
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        plt.savefig("shape_complexity/results/fft_plot.png")
    
        return plot_samples(masks_sorted, distances.numpy()[sort_idx])
    
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def visualize_sort_mean_precision(models: list[nn.Module], data_loader: DataLoader):
        masks = torch.zeros((400, 1, 64, 64))
        precisions = torch.zeros((400,))
    
        for i, (mask, _) in enumerate(data_loader, 0):
            masks[i] = mask[0]
            precisions[i] = mean_precision(models, mask)
    
        sort_idx = torch.argsort(precisions)
        masks_sorted = masks.numpy()[sort_idx]
    
        plt.plot(np.arange(len(precisions)), np.sort(precisions.numpy()))
        plt.xlabel("images")
        plt.ylabel("mean precision")
        plt.savefig("shape_complexity/results/mean_prec_plot.png")
    
        return plot_samples(masks_sorted, precisions.numpy()[sort_idx])
    
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
        masks_recon = torch.zeros((400, 3, 64, 192))
        masks = torch.zeros((400, 1, 64, 64))
        diffs = torch.zeros((400,))
        for i, (mask, _) in enumerate(data_loader, 0):
            diff, mask_recon_grid = complexity_measure_diff(model_gb, model_lb, mask)
            masks_recon[i] = mask_recon_grid
            masks[i] = mask[0]
            diffs[i] = diff
    
        sort_idx = np.argsort(np.array(diffs))
        recon_masks_sorted = masks_recon.numpy()[sort_idx]
        masks_sorted = masks.numpy()[sort_idx]
    
        plt.plot(np.arange(len(diffs)), np.sort(diffs))
        plt.xlabel("images")
        plt.ylabel("pixelwise difference of reconstructions")
        plt.savefig("shape_complexity/results/px_diff_plot.png")
    
    
        return (
            plot_samples(masks_sorted, diffs[sort_idx]),
            plot_samples(recon_masks_sorted, diffs[sort_idx]),
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
        )
    
    def visualize_sort_3dim(
        data_loader: DataLoader, model_gb: nn.Module, model_lb: nn.Module
    ):
        """Combine three complexity measures per mask and sort by their norm.

        For every mask the compression-, FFT- and VAE-based measures are
        collected, each measure column is scaled by its maximum, and the masks
        are ordered by the Euclidean norm of the resulting 3-vector.  A 3D
        scatter of the measures is saved to ``results/3d_plot.png``; returns
        sample grids of the raw masks and the reconstruction grids.

        NOTE(review): assumes exactly 400 batches of size 1 — confirm caller.
        """
        masks_recon = torch.zeros((400, 3, 64, 192))
        masks = torch.zeros((400, 1, 64, 64))
        measures = torch.zeros((400, 3))
        for i, (mask, _) in enumerate(data_loader, 0):
            c_compress = compression_complexity(mask)
            c_fft = fft_measure(mask)
            # TODO: maybe exchange by diff or mean measure instead of precision
            c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
                model_gb, model_lb, mask
            )
            masks_recon[i] = mask_recon_grid
            masks[i] = mask[0]
            measures[i] = torch.tensor([c_compress, c_fft, c_vae])

        # normalize each measure column to [0, 1] so no measure dominates the norm
        measures[:] /= measures.max(dim=0).values
        measure_norm = torch.linalg.vector_norm(measures, dim=1)

        fig = plt.figure()
        fig.clf()
        ax = fig.add_subplot(projection="3d")
        ax.scatter(measures[:, 0], measures[:, 1], measures[:, 2], marker="o")
        ax.set_xlabel("zlib compression")
        ax.set_ylabel("FFT ratio")
        ax.set_zlabel(f"VAE ratio {model_gb.bottleneck}/{model_lb.bottleneck}")
        plt.savefig("shape_complexity/results/3d_plot.png")
        plt.close()

        sort_idx = np.argsort(np.array(measure_norm))
        recon_masks_sorted = masks_recon.numpy()[sort_idx]
        masks_sorted = masks.numpy()[sort_idx]

        # closing paren of this tuple was lost in extraction; restored
        return (
            plot_samples(masks_sorted, measure_norm[sort_idx]),
            plot_samples(recon_masks_sorted, measure_norm[sort_idx]),
        )
    
    def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
        """Sample 400 masks (with replacement) and sort them by VAE complexity.

        Saves a plot of the sorted precision differences to
        ``results/diff_plot.png`` and returns a sample grid of the
        reconstruction grids ordered by increasing complexity.
        """
        sampler = RandomSampler(dataset, replacement=True, num_samples=400)
        data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

        masks = torch.zeros((400, 3, 64, 192))
        complexities = torch.zeros((400,))
        diffs = []
        for i, (mask, _) in enumerate(data_loader, 0):
            # complexity_measure(..., save_preliminary=True) returns 7 values
            # (see visualize_sort_fixed / visualize_sort_group); the previous
            # 5-value unpacking here would raise a ValueError. `diff` stays in
            # position 4 as in the sibling call sites.
            complexity, _, _, diff, _, _, mask_recon_grid = complexity_measure(
                model_gb, model_lb, mask, save_preliminary=True
            )
            masks[i] = mask_recon_grid
            diffs.append(diff)
            complexities[i] = complexity

        sort_idx = np.argsort(np.array(complexities))
        masks_sorted = masks.numpy()[sort_idx]

        plt.plot(np.arange(len(diffs)), np.sort(diffs))
        plt.xlabel("images")
        plt.ylabel("prec difference (L-H)")
        plt.savefig("shape_complexity/results/diff_plot.png")
        plt.clf()

        return plot_samples(masks_sorted, complexities[sort_idx])
    
    
    def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
        """Sort masks by the absolute VAE complexity and by each bottleneck alone.

        Produces a twin-axis plot of both models' precisions over the
        diff-sorted images (``results/prec_plot.png``) and three sorted sample
        grids: combined (``abs.png``), low bottleneck (``lb.png``) and high
        bottleneck (``gb.png``).

        NOTE(review): assumes exactly 400 batches of size 1 — confirm caller.
        """
        masks = torch.zeros((400, 3, 64, 192))
        complexities = torch.zeros((400,))
        complexities_lb = torch.zeros((400,))
        complexities_gb = torch.zeros((400,))
        diffs = []
        prec_lbs = []
        prec_gbs = []
        for i, (mask, _) in enumerate(data_loader, 0):
            (
                complexity,
                lb,
                gb,
                diff,
                prec_lb,
                prec_gb,
                mask_recon_grid,
            ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
            masks[i] = mask_recon_grid
            diffs.append(diff)
            prec_lbs.append(prec_lb)
            prec_gbs.append(prec_gb)
            complexities[i] = complexity
            complexities_lb[i] = lb
            complexities_gb[i] = gb

        sort_idx = np.argsort(np.array(complexities))
        sort_idx_lb = np.argsort(np.array(complexities_lb))
        sort_idx_gb = np.argsort(np.array(complexities_gb))
        masks_sorted = masks.numpy()[sort_idx]
        masks_sorted_lb = masks.numpy()[sort_idx_lb]
        masks_sorted_gb = masks.numpy()[sort_idx_gb]

        diff_sort_idx = np.argsort(diffs)

        # precision curves of both models over the diff-sorted images, with the
        # sorted differences on a secondary y-axis
        fig, ax1 = plt.subplots()
        ax2 = ax1.twinx()
        ax1.plot(
            np.arange(len(prec_lbs)),
            np.array(prec_lbs)[diff_sort_idx],
            label=f"bottleneck {model_lb.bottleneck}",
        )
        ax1.plot(
            np.arange(len(prec_gbs)),
            np.array(prec_gbs)[diff_sort_idx],
            label=f"bottleneck {model_gb.bottleneck}",
        )
        # closing paren of this call was lost in extraction; restored
        ax2.plot(
            np.arange(len(diffs)),
            np.sort(diffs),
            color="red",
            label="prec difference (H - L)",
        )
        ax1.legend(loc="lower left")
        ax2.legend(loc="lower right")
        ax1.set_ylabel("precision")
        ax2.set_ylabel("prec difference (H-L)")

        plt.savefig("shape_complexity/results/prec_plot.png")
        plt.clf()

        fig = plot_samples(masks_sorted, complexities[sort_idx])
        fig.savefig("shape_complexity/results/abs.png")
        plt.close(fig)
        fig = plot_samples(masks_sorted_lb, complexities_lb[sort_idx_lb])
        fig.savefig("shape_complexity/results/lb.png")
        plt.close(fig)
        fig = plot_samples(masks_sorted_gb, complexities_gb[sort_idx_gb])
        fig.savefig("shape_complexity/results/gb.png")
        plt.close(fig)
    
    
    def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
        """Sort masks by absolute VAE complexity and save raw + recon grids.

        Saves the complexity-sorted reconstruction grids to
        ``results/abs_recon.png`` and the raw masks to ``results/abs.png``.
        The commented-out section below is a retained experiment that binned
        images by precision difference — kept for reference, not executed.

        NOTE(review): assumes exactly 400 batches of size 1 — confirm caller.
        """
        recon_masks = torch.zeros((400, 3, 64, 192))
        masks = torch.zeros((400, 1, 64, 64))
        complexities = torch.zeros((400,))
        diffs = np.zeros((400,))
        prec_gbs = np.zeros((400,))
        prec_lbs = np.zeros((400,))
        for i, (mask, _) in enumerate(data_loader, 0):
            (
                complexity,
                _,
                _,
                diff,
                prec_lb,
                prec_gb,
                mask_recon_grid,
            ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
            recon_masks[i] = mask_recon_grid
            masks[i] = mask[0]
            diffs[i] = diff
            prec_gbs[i] = prec_gb
            prec_lbs[i] = prec_lb
            complexities[i] = complexity

        sort_idx = np.argsort(np.array(complexities))
        masks_sorted = masks.numpy()[sort_idx]
        recon_masks_sorted = recon_masks.numpy()[sort_idx]

        # group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
        # bin_edges = [-np.inf, 0.0, 0.05, np.inf]
        # bins = np.digitize(diffs, bins=bin_edges, right=True)

        # for i in range(bins.min(), bins.max() + 1):
        #     bin_idx = bins == i
        #     binned_prec_gb = prec_gbs[bin_idx]
        #     prec_mean = binned_prec_gb.mean()
        #     prec_idx = prec_gbs > prec_mean

        #     binned_masks_high = recon_masks[bin_idx & prec_idx]
        #     binned_masks_low = recon_masks[bin_idx & ~prec_idx]

        #     save_image(
        #         binned_masks_high,
        #         f"shape_complexity/results/diff_{group_labels[i-1]}_high.png",
        #         padding=10,
        #     )
        #     save_image(
        #         binned_masks_low,
        #         f"shape_complexity/results/diff_{group_labels[i-1]}_low.png",
        #         padding=10,
        #     )

        # diff_sort_idx = np.argsort(diffs)

        # fig, ax1 = plt.subplots()
        # ax2 = ax1.twinx()
        # ax1.plot(
        #     np.arange(len(prec_lbs)),
        #     np.array(prec_lbs)[diff_sort_idx],
        #     label=f"bottleneck {model_lb.bottleneck}",
        # )
        # ax1.plot(
        #     np.arange(len(prec_gbs)),
        #     np.array(prec_gbs)[diff_sort_idx],
        #     label=f"bottleneck {model_gb.bottleneck}",
        # )
        # ax2.plot(
        #     np.arange(len(diffs)),
        #     np.sort(diffs),
        #     color="red",
        #     label="prec difference (H - L)",
        # )
        # ax1.legend(loc="lower left")
        # ax2.legend(loc="lower right")
        # ax1.set_ylabel("precision")
        # ax2.set_ylabel("prec difference (H-L)")
        # ax1.set_xlabel("image")
        # plt.savefig("shape_complexity/results/prec_plot.png")
        # plt.tight_layout(pad=2)
        # plt.clf()

        fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
        fig.savefig("shape_complexity/results/abs_recon.png")
        plt.close(fig)

        fig = plot_samples(masks_sorted, complexities[sort_idx])
        fig.savefig("shape_complexity/results/abs.png")
        plt.close(fig)
    
    EPOCHS = 10
    LOAD_PRETRAINED = True

    # NOTE(review): the scope header for the indented code below was lost in
    # extraction; an `if __name__ == "__main__":` guard is the conventional
    # reconstruction — confirm against the repository. `LR` is used below but
    # not defined in the visible part of the file; presumably declared near
    # EPOCHS earlier — verify.
    if __name__ == "__main__":
        # `os` is not imported in the visible import block at the top of the
        # file; imported locally so the weight-saving path below works.
        import os

        bottlenecks = [4, 8, 16, 32]
        # one VAE per bottleneck size, each with its own optimizer
        models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
        optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}

        data_loader, dataset = load_data()

        # 80/20 train/test split
        train_size = int(0.8 * len(dataset))
        test_size = len(dataset) - train_size
        train_dataset, test_dataset = torch.utils.data.random_split(
            dataset, [train_size, test_size]
        )
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

        if LOAD_PRETRAINED:
            for i, model in models.items():
                model.load_state_dict(
                    torch.load(f"shape_complexity/trained/CONVVAE_{i}_split_data.pth")
                )
        else:
            for epoch in range(EPOCHS):
                for i, model in models.items():
                    train(
                        epoch,
                        model=model,
                        optimizer=optimizers[i],
                        data_loader=train_loader,
                    )

                test(epoch, models=list(models.values()), dataset=test_dataset)

            # ensure the output directory exists once, instead of re-checking
            # inside the per-bottleneck loop
            os.makedirs("shape_complexity/trained", exist_ok=True)
            for bn in bottlenecks:
                torch.save(
                    models[bn].state_dict(),
                    f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth",
                )

        test(0, models=list(models.values()), dataset=test_dataset, save_results=True)
    Markus Rothgänger committed