Skip to content
Snippets Groups Projects
shape_complexity.py 26.9 KiB
Newer Older
  • Learn to ignore specific revisions
  • Markus Rothgänger's avatar
    Markus Rothgänger committed
    import os

    import matplotlib
    import matplotlib.pyplot as plt
    import numpy as np
    import numpy.typing as npt
    import torch
    import torch.nn.functional as F
    from PIL import Image
    from torch import Tensor, conv2d, nn
    from torch.optim import Adam
    from torch.utils.data import DataLoader, RandomSampler
    from torchvision.datasets import ImageFolder
    from torchvision.transforms import transforms
    from torchvision.utils import save_image, make_grid
    
    # Use the GPU when one is available; otherwise fall back to the CPU so the
    # script also runs on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    # Select the non-interactive Agg backend; this script only saves figures to
    # files and never shows them.
    matplotlib.use("Agg")
    
    
    # Parallel row/column offsets of the four edge-adjacent neighbours of a cell.
    dx = [+1, 0, -1, 0]
    dy = [0, +1, 0, -1]


    # flood-fill one candidate/unlabeled region
    # reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation
    def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int):
        """Label the 4-connected region of ``mask`` that contains cell ``(x, y)``.

        The traversal uses an explicit stack instead of recursion, so a large
        region cannot exceed Python's recursion limit.

        Args:
            mask: 2-D array; truthy cells are foreground.
            x: row index of the start cell.
            y: column index of the start cell.
            labels: 2-D array of the same shape as ``mask``; modified in place.
                Cells with a non-zero entry count as already labelled.
            current_label: value written into ``labels`` for every reached cell.
        """
        n_rows, n_cols = mask.shape
        pending = [(x, y)]

        while pending:
            row, col = pending.pop()
            if row < 0 or row >= n_rows or col < 0 or col >= n_cols:
                continue  # outside the image
            if labels[row][col] or not mask[row][col]:
                continue  # already labeled or not marked with 1 in image

            # mark the current cell
            labels[row][col] = current_label

            # schedule the four neighbours
            for step_row, step_col in zip(dx, dy):
                pending.append((row + step_row, col + step_col))
    
    
    def find_components(mask: npt.NDArray):
        """Assign a distinct positive label to every 4-connected foreground region.

        Args:
            mask: 2-D array; truthy cells are foreground.

        Returns:
            Integer array of the same shape as ``mask``. Background cells are 0
            and the regions are numbered 1, 2, ... in row-major discovery order.
        """
        n_rows, n_cols = mask.shape
        # int32 instead of int8: an int8 label array overflows as soon as the
        # image holds more than 127 regions.
        labels = np.zeros(mask.shape, dtype=np.int32)

        label = 0
        for i in range(n_rows):
            for j in range(n_cols):
                if mask[i][j] and not labels[i][j]:
                    label += 1
                    dfs(mask, i, j, labels, label)

        return labels
    
    
    # https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
    def bbox(img):
        """Return the bounding box of the non-zero cells of ``img``, padded by one.

        Each side is grown by one cell where the image leaves room, and every
        returned value stays a valid index into ``img``.

        Args:
            img: 2-D array with at least one non-zero cell.

        Returns:
            ``(rmin, rmax, cmin, cmax)`` as inclusive row/column indices.

        Raises:
            IndexError: if ``img`` has no non-zero cell.
        """
        n_rows, n_cols = img.shape
        rows = np.any(img, axis=1)
        cols = np.any(img, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]

        rmin = rmin - 1 if rmin > 0 else rmin
        cmin = cmin - 1 if cmin > 0 else cmin
        # The largest valid index is size - 1, so only pad below that bound.
        rmax = rmax + 1 if rmax < n_rows - 1 else rmax
        cmax = cmax + 1 if cmax < n_cols - 1 else cmax

        return rmin, rmax, cmin, cmax
    
    
    def extract_single_masks(labels: npt.NDArray):
        """Cut one cropped binary mask per label value out of a label image.

        Args:
            labels: 2-D integer label array.

        Returns:
            A list with one int8 array per value in ``0 .. labels.max()``; entry
            ``k`` marks the cells equal to ``k`` and is cropped to the box
            reported by ``bbox``. Entry 0 therefore covers the cells labelled 0.
        """

        def crop(value):
            binary = (labels == value).astype(np.int8)
            top, bottom, left, right = bbox(binary)
            return binary[top : bottom + 1, left : right + 1]

        return [crop(value) for value in range(labels.max() + 1)]
    
    
    class VAE(nn.Module):
        """
        Fully connected variational autoencoder over flattened images.

        https://github.com/pytorch/examples/blob/main/vae/main.py

        Args:
            bottleneck: size of the latent vector.
            image_dim: number of values per (flattened) input image.
        """

        def __init__(self, bottleneck=2, image_dim=4096):
            super().__init__()

            self.bottleneck = bottleneck
            self.image_dim = image_dim

            self.prelim_encode = nn.Sequential(
                nn.Flatten(), nn.Linear(image_dim, 400), nn.ReLU()
            )
            self.encode_mu = nn.Sequential(nn.Linear(400, bottleneck))
            self.encode_logvar = nn.Sequential(nn.Linear(400, bottleneck))

            self.decode = nn.Sequential(
                nn.Linear(bottleneck, 400),
                nn.ReLU(),
                nn.Linear(400, image_dim),
                nn.Sigmoid(),
            )

        def encode(self, x):
            """Return ``(mu, logvar)`` of the latent distribution for ``x``."""
            x = self.prelim_encode(x)
            return self.encode_mu(x), self.encode_logvar(x)

        def reparameterize(self, mu, logvar):
            """Draw a latent sample ``mu + eps * std`` with ``eps ~ N(0, I)``."""
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std

        def forward(self, x):
            """Return ``(reconstruction, mu, logvar)``; the reconstruction is flat."""
            mu, logvar = self.encode(x)
            z = self.reparameterize(mu, logvar)
            return self.decode(z), mu, logvar

        # Reconstruction + KL divergence losses summed over all elements and batch
        def loss(self, recon_x, x, mu, logvar):
            """Return the summed binary cross-entropy plus KL divergence."""
            # Flatten with the configured size rather than a fixed 4096 so the
            # loss also works for models built with a different image_dim.
            BCE = F.binary_cross_entropy(
                recon_x, x.view(-1, self.image_dim), reduction="sum"
            )

            # see Appendix B from VAE paper:
            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
            # https://arxiv.org/abs/1312.6114
            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            return BCE + KLD
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    class CONVVAE(nn.Module):
        """Convolutional variational autoencoder with max-pool/unpool stages.

        The encoder is four conv + ReLU + max-pool stages followed by two linear
        heads (mean and log-variance). The decoder mirrors it and reuses the
        encoder's pooling indices for unpooling. The size comments on the layers
        correspond to single-channel 64x64 inputs.

        Args:
            bottleneck: size of the latent vector.
        """

        def __init__(
            self,
            bottleneck=2,
        ):
            super().__init__()

            self.bottleneck = bottleneck
            self.feature_dim = 32

            self.conv1 = nn.Sequential(
                nn.Conv2d(1, 16, 5),
                nn.ReLU(),
                nn.MaxPool2d((2, 2), return_indices=True),  # -> 30x30x16
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(16, 32, 3),
                nn.ReLU(),
                nn.MaxPool2d((2, 2), return_indices=True),  # -> 14x14x32
            )
            self.conv3 = nn.Sequential(
                nn.Conv2d(32, 64, 3),
                nn.ReLU(),
                nn.MaxPool2d((2, 2), return_indices=True),  # -> 6x6x64
            )
            self.conv4 = nn.Sequential(
                nn.Conv2d(64, 2 * self.bottleneck, 5),
                nn.ReLU(),
                nn.MaxPool2d((2, 2), return_indices=True),  # -> 1x1x2*bottleneck
            )

            self.encode_mu = nn.Sequential(
                nn.Flatten(),
                nn.Linear(
                    2 * self.bottleneck, self.bottleneck
                ),  # TODO: maybe only FC from bn x bn
            )
            self.encode_logvar = nn.Sequential(
                nn.Flatten(), nn.Linear(2 * self.bottleneck, self.bottleneck)
            )

            self.decode_linear = nn.Linear(self.bottleneck, 2 * self.bottleneck)

            self.decode4 = nn.Sequential(
                nn.ConvTranspose2d(2 * self.bottleneck, 64, 5),
                nn.ReLU(),
            )
            self.decode3 = nn.Sequential(
                nn.ConvTranspose2d(64, 32, 3),
                nn.ReLU(),
            )
            self.decode2 = nn.Sequential(
                nn.ConvTranspose2d(32, 16, 3),
                nn.ReLU(),
            )
            self.decode1 = nn.Sequential(
                nn.ConvTranspose2d(16, 1, 5),
                nn.Sigmoid(),
            )

        def encode(self, x):
            """Return ``(mu, logvar, pooling_indices)`` for the batch ``x``."""
            # Each stage ends in a max-pool that also returns its indices; the
            # decoder needs them for unpooling.
            x, idx1 = self.conv1(x)
            x, idx2 = self.conv2(x)
            x, idx3 = self.conv3(x)
            x, idx4 = self.conv4(x)

            mu = self.encode_mu(x)
            logvar = self.encode_logvar(x)

            return mu, logvar, (idx1, idx2, idx3, idx4)

        def decode(self, z: Tensor, indexes: tuple):
            """Decode latent vectors ``z`` using the encoder's pooling indices."""
            (idx1, idx2, idx3, idx4) = indexes
            z = self.decode_linear(z)
            z = z.view((-1, 2 * self.bottleneck, 1, 1))
            # Undo the stages in reverse order: unpool, then transposed conv.
            z = F.max_unpool2d(z, idx4, (2, 2))
            z = self.decode4(z)
            z = F.max_unpool2d(z, idx3, (2, 2))
            z = self.decode3(z)
            z = F.max_unpool2d(z, idx2, (2, 2))
            z = self.decode2(z)
            z = F.max_unpool2d(z, idx1, (2, 2))
            z = self.decode1(z)
            return z

        def reparameterize(self, mu, logvar):
            """Draw a latent sample ``mu + eps * std`` with ``eps ~ N(0, I)``."""
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std

        def forward(self, x):
            """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
            mu, logvar, indexes = self.encode(x)
            z = self.reparameterize(mu, logvar)
            return self.decode(z, indexes), mu, logvar

        def loss(self, recon_x, x, mu, logvar):
            """https://github.com/pytorch/examples/blob/main/vae/main.py"""
            BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")

            # see Appendix B from VAE paper:
            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
            # https://arxiv.org/abs/1312.6114
            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            return BCE + KLD
    
    
    
    def load_data():
        """Build the training data from the hard-coded trajectory mask folders.

        Every image is converted to grayscale, resized to 64x64, randomly
        flipped horizontally and vertically, and turned into a tensor.

        Returns:
            ``(data_loader, dataset)``: a shuffling loader with batch size 64
            and the concatenated dataset it draws from.
        """
        trajectories = [
            # "v3_subtle_iceberg_lettuce_nymph-6_203-2056",
            "v3_absolute_grape_changeling-16_2277-4441",
            "v3_content_squash_angel-3_16074-17640",
            "v3_smooth_kale_loch_ness_monster-1_4439-6272",
            "v3_cute_breadfruit_spirit-6_17090-19102",
            "v3_key_nectarine_spirit-2_7081-9747",
            "v3_subtle_iceberg_lettuce_nymph-6_3819-6049",
            "v3_juvenile_apple_angel-30_396415-398113",
            "v3_subtle_iceberg_lettuce_nymph-6_6100-8068",
        ]

        steps = [
            transforms.Grayscale(),
            transforms.Resize(
                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
        transform = transforms.Compose(steps)

        # One ImageFolder per trajectory, concatenated into a single dataset.
        per_trajectory = [
            ImageFolder(
                f"activation_vis/out/critic/masks/{trj}/0/4", transform=transform
            )
            for trj in trajectories
        ]
        dataset = torch.utils.data.ConcatDataset(per_trajectory)

        return DataLoader(dataset, batch_size=64, shuffle=True), dataset
    
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def train(epoch, model: nn.Module, optimizer, data_loader, log_interval=40):
        """Train ``model`` for one pass over ``data_loader`` and print progress.

        Args:
            epoch: epoch number; only used in the log output.
            model: module whose forward pass returns ``(recon, mu, logvar)`` and
                which provides ``loss(recon, x, mu, logvar)``.
            optimizer: optimizer over the model's parameters.
            data_loader: yields ``(data, target)`` batches; the target is ignored.
            log_interval: print the batch loss every ``log_interval`` batches.
        """
        model.train()
        train_loss = 0
        for batch_idx, (data, _) in enumerate(data_loader):
            data = data.to(device)

            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            loss = model.loss(recon_batch, data, mu, logvar)

            loss.backward()
            train_loss += loss.item()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print(
                    "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                        epoch,
                        batch_idx * len(data),
                        len(data_loader.dataset),
                        100.0 * batch_idx / len(data_loader),
                        loss.item() / len(data),
                    )
                )

        print(
            "====> Epoch: {} Average loss: {:.4f}".format(
                epoch, train_loss / len(data_loader.dataset)
            )
        )
    
    
    def test(epoch, models, dataset):
        """Evaluate ``models`` on 64 randomly drawn samples and save a comparison.

        For the first batch, the inputs and each model's reconstructions are
        written to ``results/reconstruction_<epoch>.png``. The mean loss per
        evaluated sample is printed for every model.

        Args:
            epoch: epoch number; only used in the output file name.
            models: sequence of modules whose forward pass returns
                ``(recon, mu, logvar)`` and which provide a ``loss`` method.
            dataset: dataset yielding ``(image, target)`` pairs.
        """
        for model in models:
            model.eval()
        test_loss = [0 for _ in models]

        test_batch_size = 32
        sampler = RandomSampler(dataset, replacement=True, num_samples=64)
        test_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=sampler)
        comp_data = None
        n_evaluated = 0

        with torch.no_grad():
            for i, (data, _) in enumerate(test_loader):
                data = data.to(device)
                n_evaluated += data.size(0)

                for j, model in enumerate(models):
                    recon_batch, mu, logvar = model(data)
                    test_loss[j] += model.loss(recon_batch, data, mu, logvar).item()

                    if i == 0:
                        n = min(data.size(0), 20)
                        if comp_data is None:
                            comp_data = data[:n]
                        # -1 keeps this valid when the batch is smaller than
                        # test_batch_size.
                        comp_data = torch.cat(
                            [comp_data, recon_batch.view(-1, 1, 64, 64)[:n]]
                        )

                if i == 0:
                    os.makedirs("results", exist_ok=True)
                    save_image(
                        comp_data.cpu(),
                        "results/reconstruction_" + str(epoch) + ".png",
                        nrow=min(data.size(0), 20),
                    )

        for i, model in enumerate(models):
            # Average over the samples actually evaluated, not over the whole
            # dataset: only the sampled subset contributed to the sum.
            test_loss[i] /= max(n_evaluated, 1)
            print(f"====> Test set loss model {i}: {test_loss[i]:.4f}")
    
    
    def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
        """Reconstruct one connected component of an image and score the result.

        Args:
            model: module whose forward pass returns ``(recon, mu, logvar)``.
            path: path of the image file to load.
            label: index into the list returned by ``extract_single_masks``.
            epsilon: threshold above which a reconstructed pixel counts as set.

        Returns:
            ``(precision, recall, comp_data)`` where ``comp_data`` concatenates
            the input mask, the raw reconstruction and the thresholded
            reconstruction.
        """
        model.eval()

        grayscale = transforms.F.to_grayscale(Image.open(path))
        image = transforms.F.to_tensor(grayscale)
        single_masks = extract_single_masks(find_components(image[0]))

        # Scale the selected 0/1 component to 0/255 and resize it to 64x64.
        component = (single_masks[label] * 255).astype(np.uint8)
        resized = transforms.F.resize(transforms.F.to_pil_image(component), (64, 64))
        mask = transforms.F.to_tensor(resized)

        with torch.no_grad():
            mask = mask.to(device)
            recon_x, _, _ = model(mask)

            recon_image = recon_x.view(64, 64).cpu()
            recon_bits = recon_image.numpy() > epsilon
            mask_bits = mask.cpu().numpy() > 0

            TP = (mask_bits & recon_bits).sum()
            FP = (recon_bits & ~mask_bits).sum()
            FN = (mask_bits & ~recon_bits).sum()

            prec = TP / (TP + FP)
            rec = TP / (TP + FN)

            panels = [mask[0].cpu(), recon_image, torch.from_numpy(recon_bits)]
            comp_data = torch.cat(panels)

            return prec, rec, comp_data
    
    
    def distance_measure(model: nn.Module, img: Tensor):
        """Measure how far the latent mean moves after one encode/decode round trip.

        The image is reconstructed, the reconstruction is encoded again, and the
        L2 distance between the two latent means is returned.

        Args:
            model: module whose forward pass returns ``(recon, mu, logvar)``.
            img: batch of images; only the first one appears in the grid.

        Returns:
            ``(distance, grid)`` where ``grid`` shows the first input next to its
            reconstruction.
        """
        model.eval()

        with torch.no_grad():
            mask = img.to(device)
            recon, mean, _ = model(mask)
            # Give the reconstruction the input's shape so models that return a
            # flat reconstruction can be re-encoded and shown in the grid too.
            recon = recon.reshape(mask.shape)

            # TODO: apply threshold here?!
            _, recon_mean, _ = model(recon)

            distance = torch.norm(mean - recon_mean, p=2)
            grid = make_grid(
                torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0
            )

            return distance, grid
    
    
    def complexity_measure(
        model_gb: nn.Module,
        model_lb: nn.Module,
        img: Tensor,
        epsilon=0.4,
        save_preliminary=False,
    ):
        """Score an image by comparing the reconstruction precision of two models.

        Both reconstructions are thresholded at ``epsilon`` and compared with the
        binarised input. Precision is TP / (TP + FP) per model.
        NOTE(review): ``gb``/``lb`` presumably mean the model with the greater /
        lower bottleneck - confirm against the callers.

        Args:
            model_gb: first model; forward returns ``(recon, mu, logvar)``.
            model_lb: second model with the same forward contract.
            img: image batch; the first element is used as the reference mask.
            epsilon: threshold above which a reconstructed pixel counts as set.
            save_preliminary: also write the bit maps and TP/FP maps to
                ``shape_complexity/results``.

        Returns:
            ``(complexity, complexity_lb, complexity_gb, prec_gb - prec_lb,
            prec_lb, prec_gb, grid)`` where ``grid`` shows the input and both
            reconstructions side by side.
        """
        model_gb.eval()
        model_lb.eval()

        with torch.no_grad():
            mask = img.to(device)
            recon_gb, _, _ = model_gb(mask)
            recon_lb, _, _ = model_lb(mask)

            bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
            bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
            mask_bits = mask[0].cpu() > 0

            # Per-pixel true-positive / false-positive maps for both models.
            tp_map_gb = mask_bits & bits_gb
            fp_map_gb = bits_gb & ~mask_bits
            tp_map_lb = mask_bits & bits_lb
            fp_map_lb = bits_lb & ~mask_bits

            if save_preliminary:
                bit_maps = [mask_bits.float(), bits_gb.float(), bits_lb.float()]
                save_image(
                    torch.stack(bit_maps).cpu(),
                    f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
                )
                error_maps = [
                    tp_map_gb.float(),
                    fp_map_gb.float(),
                    tp_map_lb.float(),
                    fp_map_lb.float(),
                ]
                save_image(
                    torch.stack(error_maps).cpu(),
                    f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
                )

            tp_gb = tp_map_gb.sum()
            fp_gb = fp_map_gb.sum()
            tp_lb = tp_map_lb.sum()
            fp_lb = fp_map_lb.sum()

            prec_gb = tp_gb / (tp_gb + fp_gb)
            prec_lb = tp_lb / (tp_lb + fp_lb)

            # Example: prec_gb = 0.4, prec_lb = 0.7
            #   complexity    = 1 - (0.4 - abs(0.4 - 0.7)) = 0.9
            #   complexity_lb = 1 - 0.7                    = 0.3
            complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
            complexity_lb = 1 - prec_lb
            complexity_gb = 1 - prec_gb

            panels = [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            grid = make_grid(torch.stack(panels).cpu(), nrow=3, padding=0)

            return (
                complexity,
                complexity_lb,
                complexity_gb,
                prec_gb - prec_lb,
                prec_lb,
                prec_gb,
                grid,
            )
    
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    def complexity_measure_diff(
        model_gb: nn.Module,
        model_lb: nn.Module,
        img: Tensor,
    ):
        """Compare two models' reconstructions of ``img`` by their summed difference.

        Args:
            model_gb: first model; forward returns ``(recon, mu, logvar)``.
            model_lb: second model with the same forward contract.
            img: image batch; the first element is shown in the grid.

        Returns:
            ``(diff, grid)`` where ``diff`` is the absolute value of the summed
            difference between the two reconstructions and ``grid`` shows the
            input and both reconstructions side by side.
        """
        model_gb.eval()
        model_lb.eval()

        with torch.no_grad():
            mask = img.to(device)
            recon_gb, _, _ = model_gb(mask)
            recon_lb, _, _ = model_lb(mask)

            # NOTE(review): this is |sum(gb - lb)|, so differences of opposite
            # sign cancel out; confirm that sum(|gb - lb|) is not what is wanted.
            signed_total = (recon_gb - recon_lb).cpu().sum()
            diff = signed_total.abs()

            panels = [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            grid = make_grid(torch.stack(panels).cpu(), nrow=3, padding=0)

            return (diff, grid)
    
    
    
    def alt_complexity_measure(
        model_gb: nn.Module, model_lb: nn.Module, img: Tensor, epsilon=0.4
    ):
        """Return the precision-based complexity of ``img`` for two models.

        Both reconstructions are thresholded at ``epsilon`` and compared with the
        binarised input; precision is TP / (TP + FP) per model.

        Args:
            model_gb: first model; forward returns ``(recon, mu, logvar)``.
            model_lb: second model with the same forward contract.
            img: image batch of 64x64 images.
            epsilon: threshold above which a reconstructed pixel counts as set.

        Returns:
            ``1 - (prec_gb - |prec_gb - prec_lb|)``.
        """
        model_gb.eval()
        model_lb.eval()

        with torch.no_grad():
            mask = img.to(device)
            recon_gb, _, _ = model_gb(mask)
            recon_lb, _, _ = model_lb(mask)

            # The former BCE terms were never used and raised for reconstructions
            # that are not flat (batch, 4096) tensors, so they were removed.

            recon_bits_gb = recon_gb.view(-1, 64, 64).cpu().numpy() > epsilon
            recon_bits_lb = recon_lb.view(-1, 64, 64).cpu().numpy() > epsilon
            # Same (batch, 64, 64) layout as the reconstructions so the bitwise
            # ops compare pixel by pixel and never broadcast across the batch.
            mask_bits = mask.reshape(-1, 64, 64).cpu().numpy() > 0

            tp_gb = (mask_bits & recon_bits_gb).sum()
            fp_gb = (recon_bits_gb & ~mask_bits).sum()
            tp_lb = (mask_bits & recon_bits_lb).sum()
            fp_lb = (recon_bits_lb & ~mask_bits).sum()

            prec_gb = tp_gb / (tp_gb + fp_gb)
            prec_lb = tp_lb / (tp_lb + fp_lb)
            complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))

            return complexity
    
    
    def plot_samples(masks: Tensor, complexities: npt.NDArray):
        """Draw a 20x20 grid of images, each titled with its score.

        Args:
            masks: array of shape ``(400, channels, height, width)``; only the
                first channel of every image is drawn.
            complexities: one score per image, printed with four decimals.

        Returns:
            The matplotlib figure holding the grid.

        Raises:
            ValueError: if ``masks`` does not hold exactly 400 images.
        """
        dpi = 150
        rows = cols = 20
        total = rows * cols
        n_samples, _, y, x = masks.shape

        extent = (0, x - 1, 0, y - 1)

        if total != n_samples:
            raise ValueError(
                f"shape mismatch: expected {total} samples, got {n_samples}"
            )

        fig = plt.figure(figsize=(32, 16), dpi=dpi)
        for idx in np.arange(n_samples):
            ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])

            # Draw on the subplot explicitly rather than on pyplot's current axes.
            ax.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
            ax.set_title(
                f"{complexities[idx]:.4f}",
                fontdict={"fontsize": 6, "color": "orange"},
                y=0.35,
            )

        fig.patch.set_facecolor("#292929")
        # Size the figure from the image dimensions so the cells keep their
        # aspect ratio.
        height_px = y * rows
        width_px = x * cols
        fig.set_size_inches(width_px / (dpi / 2), height_px / (dpi / 2), forward=True)
        fig.tight_layout(pad=0)

        return fig
    
    
    def visualize_sort_mean(data_loader: DataLoader, model: nn.Module):
        """Sort 400 samples by ``distance_measure`` and plot them.

        The sorted distance curve is saved to
        ``shape_complexity/results/distance_plot.png``.

        Args:
            data_loader: yields 400 ``(mask, target)`` batches; the first image
                of each batch is used.
            model: model passed on to ``distance_measure``.

        Returns:
            ``(fig, fig_recon)``: sample grids of the inputs and of the
            input/reconstruction pairs, both ordered by increasing distance.
        """
        recon_masks = torch.zeros((400, 3, 64, 128))
        masks = torch.zeros((400, 1, 64, 64))
        distances = torch.zeros((400,))

        for i, (mask, _) in enumerate(data_loader):
            distance, mask_recon_grid = distance_measure(model, mask)
            masks[i] = mask[0]
            recon_masks[i] = mask_recon_grid
            distances[i] = distance

        sort_idx = torch.argsort(distances).numpy()
        sorted_distances = distances.numpy()[sort_idx]
        recon_masks_sorted = recon_masks.numpy()[sort_idx]
        masks_sorted = masks.numpy()[sort_idx]

        # Plot on a dedicated figure and close it afterwards; drawing on
        # pyplot's implicit current figure leaves the curve behind for whatever
        # is plotted next.
        curve_fig, curve_ax = plt.subplots()
        curve_ax.plot(np.arange(len(sorted_distances)), sorted_distances)
        curve_ax.set_xlabel("images")
        curve_ax.set_ylabel("latent mean L2 distance")
        curve_fig.savefig("shape_complexity/results/distance_plot.png")
        plt.close(curve_fig)

        return plot_samples(masks_sorted, sorted_distances), plot_samples(
            recon_masks_sorted, sorted_distances
        )
    
    
    def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
        """Sort 400 samples by ``complexity_measure_diff`` and plot them.

        The sorted difference curve is saved to
        ``shape_complexity/results/px_diff_plot.png``.

        Args:
            data_loader: yields 400 ``(mask, target)`` batches; the first image
                of each batch is used.
            model_gb: first model passed to ``complexity_measure_diff``.
            model_lb: second model passed to ``complexity_measure_diff``.

        Returns:
            ``(fig, fig_recon)``: sample grids of the inputs and of the
            input/reconstruction triples, both ordered by increasing difference.
        """
        masks_recon = torch.zeros((400, 3, 64, 192))
        masks = torch.zeros((400, 1, 64, 64))
        diffs = torch.zeros((400,))

        for i, (mask, _) in enumerate(data_loader):
            diff, mask_recon_grid = complexity_measure_diff(model_gb, model_lb, mask)
            masks_recon[i] = mask_recon_grid
            masks[i] = mask[0]
            diffs[i] = diff

        diff_values = diffs.numpy()
        sort_idx = np.argsort(diff_values)
        sorted_diffs = diff_values[sort_idx]
        recon_masks_sorted = masks_recon.numpy()[sort_idx]
        masks_sorted = masks.numpy()[sort_idx]

        # Plot on a dedicated figure and close it afterwards; drawing on
        # pyplot's implicit current figure would overlay this curve on any
        # figure that is still open.
        curve_fig, curve_ax = plt.subplots()
        curve_ax.plot(np.arange(len(sorted_diffs)), sorted_diffs)
        curve_ax.set_xlabel("images")
        curve_ax.set_ylabel("pixelwise difference of reconstructions")
        curve_fig.savefig("shape_complexity/results/px_diff_plot.png")
        plt.close(curve_fig)

        return plot_samples(masks_sorted, sorted_diffs), plot_samples(
            recon_masks_sorted, sorted_diffs
        )
    
    
    
    def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
        """Sort 400 randomly drawn samples by ``complexity_measure`` and plot them.

        The sorted precision-difference curve is saved to
        ``shape_complexity/results/diff_plot.png``.

        Args:
            dataset: dataset yielding ``(image, target)`` pairs.
            model_gb: first model passed to ``complexity_measure``.
            model_lb: second model passed to ``complexity_measure``.

        Returns:
            The sample grid figure, ordered by increasing complexity.
        """
        sampler = RandomSampler(dataset, replacement=True, num_samples=400)
        data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

        masks = torch.zeros((400, 3, 64, 192))
        complexities = torch.zeros((400,))
        diffs = []
        for i, (mask, _) in enumerate(data_loader):
            # complexity_measure returns seven values; only the combined
            # complexity, the precision difference and the grid are used here.
            complexity, _, _, diff, _, _, mask_recon_grid = complexity_measure(
                model_gb, model_lb, mask, save_preliminary=True
            )
            masks[i] = mask_recon_grid
            diffs.append(float(diff))
            complexities[i] = complexity

        complexity_values = complexities.numpy()
        sort_idx = np.argsort(complexity_values)
        masks_sorted = masks.numpy()[sort_idx]

        # Plot on a dedicated figure and close it afterwards so the curve does
        # not leak into later plots.
        curve_fig, curve_ax = plt.subplots()
        curve_ax.plot(np.arange(len(diffs)), np.sort(diffs))
        curve_ax.set_xlabel("images")
        # complexity_measure reports prec_gb - prec_lb, which the other plots in
        # this file label as H-L.
        curve_ax.set_ylabel("prec difference (H-L)")
        curve_fig.savefig("shape_complexity/results/diff_plot.png")
        plt.close(curve_fig)

        return plot_samples(masks_sorted, complexity_values[sort_idx])
    
    
    def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
        """Score 400 samples with ``complexity_measure`` and save three rankings.

        Written to ``shape_complexity/results``:
        ``prec_plot.png`` (both precisions and their difference, ordered by the
        difference), ``abs.png`` (grid sorted by the combined complexity),
        ``lb.png`` and ``gb.png`` (grids sorted by each model's own complexity).

        Args:
            data_loader: yields 400 ``(mask, target)`` batches.
            model_gb: first model passed to ``complexity_measure``.
            model_lb: second model passed to ``complexity_measure``.
        """
        masks = torch.zeros((400, 3, 64, 192))
        complexities = torch.zeros((400,))
        complexities_lb = torch.zeros((400,))
        complexities_gb = torch.zeros((400,))
        diffs = []
        prec_lbs = []
        prec_gbs = []
        for i, (mask, _) in enumerate(data_loader):
            (
                complexity,
                lb,
                gb,
                diff,
                prec_lb,
                prec_gb,
                mask_recon_grid,
            ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
            masks[i] = mask_recon_grid
            # Store plain floats so the lists convert cleanly to numpy arrays.
            diffs.append(float(diff))
            prec_lbs.append(float(prec_lb))
            prec_gbs.append(float(prec_gb))
            complexities[i] = complexity
            complexities_lb[i] = lb
            complexities_gb[i] = gb

        sort_idx = np.argsort(np.array(complexities))
        sort_idx_lb = np.argsort(np.array(complexities_lb))
        sort_idx_gb = np.argsort(np.array(complexities_gb))
        masks_sorted = masks.numpy()[sort_idx]
        masks_sorted_lb = masks.numpy()[sort_idx_lb]
        masks_sorted_gb = masks.numpy()[sort_idx_gb]

        diff_sort_idx = np.argsort(diffs)

        fig, ax1 = plt.subplots()
        ax2 = ax1.twinx()
        ax1.plot(
            np.arange(len(prec_lbs)),
            np.array(prec_lbs)[diff_sort_idx],
            label=f"bottleneck {model_lb.bottleneck}",
        )
        ax1.plot(
            np.arange(len(prec_gbs)),
            np.array(prec_gbs)[diff_sort_idx],
            label=f"bottleneck {model_gb.bottleneck}",
        )
        ax2.plot(
            np.arange(len(diffs)),
            np.sort(diffs),
            color="red",
            label="prec difference (H - L)",
        )
        ax1.legend(loc="lower left")
        ax2.legend(loc="lower right")
        ax1.set_ylabel("precision")
        ax2.set_ylabel("prec difference (H-L)")

        fig.savefig("shape_complexity/results/prec_plot.png")
        # Close instead of only clearing so the figure is released.
        plt.close(fig)

        fig = plot_samples(masks_sorted, complexities[sort_idx])
        fig.savefig("shape_complexity/results/abs.png")
        plt.close(fig)
        fig = plot_samples(masks_sorted_lb, complexities_lb[sort_idx_lb])
        fig.savefig("shape_complexity/results/lb.png")
        plt.close(fig)
        fig = plot_samples(masks_sorted_gb, complexities_gb[sort_idx_gb])
        fig.savefig("shape_complexity/results/gb.png")
        plt.close(fig)
    
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    
    def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
        """Score 400 samples, group them by precision difference, and save plots.

        Samples are binned by ``prec_gb - prec_lb`` into ``<= 0``,
        ``(0, 0.05]`` and ``> 0.05``. Each bin is split at its mean ``prec_gb``
        and both halves are saved as image grids. In addition, ``prec_plot.png``,
        ``abs_recon.png`` and ``abs.png`` are written. All files go to
        ``shape_complexity/results``.

        Args:
            data_loader: yields 400 ``(mask, target)`` batches; the first image
                of each batch is used.
            model_gb: first model passed to ``complexity_measure``.
            model_lb: second model passed to ``complexity_measure``.
        """
        recon_masks = torch.zeros((400, 3, 64, 192))
        masks = torch.zeros((400, 1, 64, 64))
        complexities = torch.zeros((400,))
        diffs = np.zeros((400,))
        prec_gbs = np.zeros((400,))
        prec_lbs = np.zeros((400,))
        for i, (mask, _) in enumerate(data_loader):
            (
                complexity,
                _,
                _,
                diff,
                prec_lb,
                prec_gb,
                mask_recon_grid,
            ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
            recon_masks[i] = mask_recon_grid
            masks[i] = mask[0]
            diffs[i] = diff
            prec_gbs[i] = prec_gb
            prec_lbs[i] = prec_lb
            complexities[i] = complexity

        sort_idx = np.argsort(np.array(complexities))
        masks_sorted = masks.numpy()[sort_idx]
        recon_masks_sorted = recon_masks.numpy()[sort_idx]

        group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
        bin_edges = [-np.inf, 0.0, 0.05, np.inf]
        bins = np.digitize(diffs, bins=bin_edges, right=True)

        # With right=True the finite values land in bins 1..3, matching
        # group_labels; anything else (e.g. NaN differences) is left out.
        for group, group_label in enumerate(group_labels, start=1):
            in_group = bins == group
            if not in_group.any():
                continue  # nothing to average or save for this group

            prec_mean = prec_gbs[in_group].mean()
            above_mean = prec_gbs > prec_mean

            _save_nonempty_grid(
                recon_masks[torch.from_numpy(in_group & above_mean)],
                f"shape_complexity/results/diff_{group_label}_high.png",
            )
            _save_nonempty_grid(
                recon_masks[torch.from_numpy(in_group & ~above_mean)],
                f"shape_complexity/results/diff_{group_label}_low.png",
            )

        diff_sort_idx = np.argsort(diffs)

        fig, ax1 = plt.subplots()
        ax2 = ax1.twinx()
        ax1.plot(
            np.arange(len(prec_lbs)),
            np.array(prec_lbs)[diff_sort_idx],
            label=f"bottleneck {model_lb.bottleneck}",
        )
        ax1.plot(
            np.arange(len(prec_gbs)),
            np.array(prec_gbs)[diff_sort_idx],
            label=f"bottleneck {model_gb.bottleneck}",
        )
        ax2.plot(
            np.arange(len(diffs)),
            np.sort(diffs),
            color="red",
            label="prec difference (H - L)",
        )
        ax1.legend(loc="lower left")
        ax2.legend(loc="lower right")
        ax1.set_ylabel("precision")
        ax2.set_ylabel("prec difference (H-L)")
        ax1.set_xlabel("image")
        # Lay out before saving; applied afterwards it has no effect on the file.
        fig.tight_layout(pad=2)
        fig.savefig("shape_complexity/results/prec_plot.png")
        # Close instead of only clearing so the figure is released.
        plt.close(fig)

        fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
        fig.savefig("shape_complexity/results/abs_recon.png")
        plt.close(fig)

        fig = plot_samples(masks_sorted, complexities[sort_idx])
        fig.savefig("shape_complexity/results/abs.png")
        plt.close(fig)


    def _save_nonempty_grid(images, path):
        """Save ``images`` as a padded grid at ``path`` unless there are none."""
        if images.shape[0] == 0:
            return
        save_image(images, path, padding=10)
    
    Markus Rothgänger's avatar
    Markus Rothgänger committed
    # When True, main() loads saved model weights instead of training the models.
    LOAD_PRETRAINED = False
    
    
    
    def main():
        """Train (or load) one CONVVAE per bottleneck size and write the plots.

        Models with bottlenecks 2, 4, 8 and 16 are either loaded from
        ``shape_complexity/trained`` or trained and saved there. The models with
        bottlenecks 16 and 8 are then compared on 400 randomly drawn samples and
        the resulting figures are saved to ``shape_complexity/results``.
        """
        bottlenecks = [2, 4, 8, 16]
        models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
        # NOTE(review): LR and EPOCHS are expected to be module-level constants;
        # they are not defined in the visible part of this file - confirm.
        optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}

        data_loader, dataset = load_data()

        if LOAD_PRETRAINED:
            for i, model in models.items():
                # map_location lets weights saved on another device load here.
                model.load_state_dict(
                    torch.load(
                        f"shape_complexity/trained/CONVVAE_{i}_split_data.pth",
                        map_location=device,
                    )
                )
        else:
            for epoch in range(EPOCHS):
                for i, model in models.items():
                    train(
                        epoch, model=model, optimizer=optimizers[i], data_loader=data_loader
                    )

                test(epoch, models=list(models.values()), dataset=dataset)

            os.makedirs("shape_complexity/trained", exist_ok=True)
            for bn in bottlenecks:
                torch.save(
                    models[bn].state_dict(),
                    f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth",
                )

        bn_gt = 16
        bn_lt = 8

        # Every figure below is saved into this directory, so make sure it exists.
        os.makedirs("shape_complexity/results", exist_ok=True)

        sampler = RandomSampler(dataset, replacement=True, num_samples=400)
        data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

        visualize_sort_group(data_loader, models[bn_gt], models[bn_lt])

        fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
        fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
        fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")
        plt.close(fig)
        plt.close(fig_recon)

        fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt])
        fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png")
        fig_recon.savefig(
            f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png"
        )
        plt.close(fig)
        plt.close(fig_recon)