import os
from typing import Union

# from zlib import compress
from bz2 import compress

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import torch
import torch.nn.functional as F
from kornia.morphology import closing
from PIL import Image
from torch import Tensor, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torchvision.utils import make_grid, save_image

device = torch.device("cuda")
matplotlib.use("Agg")

# neighbor offsets for 4-connectivity (right, down, left, up)
dx = [+1, 0, -1, 0]
dy = [0, +1, 0, -1]


def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int):
    """Depth-first search flood fill for one candidate/unlabeled region.

    Note: recursive, so very large regions can hit Python's recursion limit.
    reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation
    """
    n_rows, n_cols = mask.shape
    if x < 0 or x == n_rows:
        return
    if y < 0 or y == n_cols:
        return
    if labels[x][y] or not mask[x][y]:
        return  # already labeled or not marked with 1 in image

    # mark the current cell
    labels[x][y] = current_label

    # recursively mark the neighbors
    for direction in range(4):
        dfs(mask, x + dx[direction], y + dy[direction], labels, current_label)


def find_components(mask: npt.NDArray):
    """Label 4-connected components; 0 is background, labels start at 1."""
    label = 0
    n_rows, n_cols = mask.shape
    # int32 instead of int8 so more than 127 components do not overflow
    labels = np.zeros(mask.shape, dtype=np.int32)
    for i in range(n_rows):
        for j in range(n_cols):
            if not labels[i][j] and mask[i][j]:
                label += 1
                dfs(mask, i, j, labels, label)

    return labels


# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
def bbox(img):
    max_x, max_y = img.shape
    rows = np.any(img, axis=1)
    cols = np.any(img, axis=0)
    rmin, rmax = np.where(rows)[0][[0, -1]]
    cmin, cmax = np.where(cols)[0][[0, -1]]
    # pad the box by one pixel where possible (valid indices go up to max_x - 1)
    rmin = rmin - 1 if rmin > 0 else rmin
    cmin = cmin - 1 if cmin > 0 else cmin
    rmax = rmax + 1 if rmax < max_x - 1 else rmax
    cmax = cmax + 1 if cmax < max_y - 1 else cmax

    return rmin, rmax, cmin, cmax


def extract_single_masks(labels: npt.NDArray):
    """Crop each labeled component to its (padded) bounding box."""
    masks = []
    # start at 1: label 0 is the background, not a component
    for l in range(1, labels.max() + 1):
        mask = (labels == l).astype(np.int8)
        rmin, rmax, cmin, cmax = bbox(mask)
        masks.append(mask[rmin : rmax + 1, cmin : cmax + 1])

    return masks


class VAE(nn.Module):
    """
    https://github.com/pytorch/examples/blob/main/vae/main.py
    """

    def __init__(self, bottleneck=2, image_dim=4096):
        super(VAE, self).__init__()
        self.bottleneck = bottleneck
        self.image_dim = image_dim

        self.prelim_encode = nn.Sequential(
            nn.Flatten(), nn.Linear(image_dim, 400), nn.ReLU()
        )
        self.encode_mu = nn.Sequential(nn.Linear(400, bottleneck))
        self.encode_logvar = nn.Sequential(nn.Linear(400, bottleneck))
        self.decode = nn.Sequential(
            nn.Linear(bottleneck, 400),
            nn.ReLU(),
            nn.Linear(400, image_dim),
            nn.Sigmoid(),
        )

    def encode(self, x):
        # h1 = F.relu(self.encode(x))
        # return self.encode_mu(h1), self.encode_logvar(h1)
        x = self.prelim_encode(x)
        return self.encode_mu(x), self.encode_logvar(x)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    # Reconstruction + KL divergence losses summed over all elements and batch
    def loss(self, recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(
            recon_x, x.view(-1, self.image_dim), reduction="sum"
        )

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD
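
# The closed-form KL term in `VAE.loss` above can be cross-checked against
# torch.distributions. A minimal sketch, not part of the pipeline; the helper
# name `_kl_sanity_check` is ours and is never called by this script.
def _kl_sanity_check():
    from torch.distributions import Normal, kl_divergence

    mu = torch.randn(8, 2)
    logvar = torch.randn(8, 2)
    std = torch.exp(0.5 * logvar)

    # analytic term from Kingma & Welling, Appendix B
    kld_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # same quantity via the library: sum of KL(N(mu, std) || N(0, 1))
    kld_lib = kl_divergence(
        Normal(mu, std), Normal(torch.zeros_like(mu), torch.ones_like(std))
    ).sum()

    assert torch.allclose(kld_closed, kld_lib, atol=1e-5)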
class CONVVAE(nn.Module):
    def __init__(
        self,
        bottleneck=2,
    ):
        super(CONVVAE, self).__init__()
        self.bottleneck = bottleneck
        self.feature_dim = 6 * 6 * 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, 5),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 30x30x16
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 14x14x32
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> 6x6x64
        )
        # self.conv4 = nn.Sequential(
        #     nn.Conv2d(32, self.bottleneck, 5),
        #     nn.ReLU(),
        #     nn.MaxPool2d((2, 2), return_indices=True),  # -> 1x1xbottleneck
        # )
        self.encode_mu = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.feature_dim, self.bottleneck),
        )
        self.encode_logvar = nn.Sequential(
            nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
        )

        self.decode_linear = nn.Linear(self.bottleneck, self.feature_dim)
        # self.decode4 = nn.Sequential(
        #     nn.ConvTranspose2d(self.bottleneck, 32, 5),
        #     nn.ReLU(),
        # )
        self.decode3 = nn.Sequential(
            nn.ConvTranspose2d(64, 64, 2, stride=2),  # -> 12x12x64
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3),  # -> 14x14x32
            nn.ReLU(),
        )
        self.decode2 = nn.Sequential(
            nn.ConvTranspose2d(32, 32, 2, stride=2),  # -> 28x28x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 3),  # -> 30x30x16
            nn.ReLU(),
        )
        self.decode1 = nn.Sequential(
            nn.ConvTranspose2d(16, 16, 2, stride=2),  # -> 60x60x16
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 5),  # -> 64x64x1
            nn.Sigmoid(),
        )

    def encode(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # x, idx4 = self.conv4(x)
        mu = self.encode_mu(x)
        logvar = self.encode_logvar(x)
        return mu, logvar

    def decode(self, z: Tensor):
        z = self.decode_linear(z)
        z = z.view((-1, 64, 6, 6))
        # z = F.max_unpool2d(z, idx4, (2, 2))
        # z = self.decode4(z)
        z = self.decode3(z)
        z = self.decode2(z)
        z = self.decode1(z)
        # z = z.view(-1, 128, 1, 1)
        # return self.decode_conv(z)
        return z

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def loss(self, recon_x, x, mu, logvar):
        """https://github.com/pytorch/examples/blob/main/vae/main.py"""
        BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD
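
# Quick shape sanity check for the encoder/decoder stacks above (64x64 input,
# 6x6x64 feature map, 64x64 output). A minimal CPU-only sketch; the helper
# `_convvae_shapes` is ours and is not called anywhere in the pipeline.
def _convvae_shapes():
    model = CONVVAE(bottleneck=4)
    x = torch.rand(2, 1, 64, 64)
    recon, mu, logvar = model(x)
    # latent statistics have one entry per bottleneck dimension
    assert mu.shape == logvar.shape == (2, 4)
    # the decoder mirrors the encoder back to the input resolution
    assert recon.shape == (2, 1, 64, 64)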
class CloseTransform:
    kernel = torch.ones(5, 5)

    def __init__(self, kernel=None):
        if kernel is not None:
            self.kernel = kernel

    def __call__(self, x):
        x = transforms.F.to_tensor(x)
        if len(x.shape) < 4:
            return transforms.F.to_pil_image(
                closing(x.unsqueeze(dim=0), self.kernel).squeeze(dim=0)
            )
        return transforms.F.to_pil_image(closing(x, self.kernel))


def load_data():
    transform = transforms.Compose(
        [
            transforms.Grayscale(),
            transforms.RandomApply([CloseTransform()], p=0.25),
            transforms.Resize(
                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ToTensor(),
        ]
    )

    trajectories = [
        # "v3_subtle_iceberg_lettuce_nymph-6_203-2056",
        "v3_absolute_grape_changeling-16_2277-4441",
        "v3_content_squash_angel-3_16074-17640",
        "v3_smooth_kale_loch_ness_monster-1_4439-6272",
        "v3_cute_breadfruit_spirit-6_17090-19102",
        "v3_key_nectarine_spirit-2_7081-9747",
        "v3_subtle_iceberg_lettuce_nymph-6_3819-6049",
        "v3_juvenile_apple_angel-30_396415-398113",
        "v3_subtle_iceberg_lettuce_nymph-6_6100-8068",
    ]

    datasets = []
    for trj in trajectories:
        datasets.append(
            ImageFolder(
                f"activation_vis/out/critic/masks/{trj}/0/4", transform=transform
            )
        )

    datasets.append(
        ImageFolder("shape_complexity/data/simple_shapes", transform=transform)
    )

    dataset = torch.utils.data.ConcatDataset(datasets)
    data_loader = DataLoader(dataset, batch_size=128, shuffle=True)

    return data_loader, dataset


def train(epoch, model: Union[VAE, CONVVAE], optimizer, data_loader, log_interval=40):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(data_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = model.loss(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(data_loader.dataset),
                    100.0 * batch_idx / len(data_loader),
                    loss.item() / len(data),
                )
            )

    print(
        "====> Epoch: {} Average loss: {:.4f}".format(
            epoch, train_loss / len(data_loader.dataset)
        )
    )


def test(epoch, models: Union[list[CONVVAE], list[VAE]], dataset, save_results=False):
    for model in models:
        model.eval()

    test_loss = [0 for _ in models]
    test_batch_size = 32
    sampler = RandomSampler(dataset, replacement=True, num_samples=64)
    test_loader = DataLoader(dataset, batch_size=test_batch_size, sampler=sampler)

    comp_data = None
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            for j, model in enumerate(models):
                recon_batch, mu, logvar = model(data)
                test_loss[j] += model.loss(recon_batch, data, mu, logvar).item()
                if i == 0:
                    n = min(data.size(0), 20)
                    if comp_data is None:
                        comp_data = data[:n]
                    comp_data = torch.cat(
                        [comp_data, recon_batch.view(test_batch_size, 1, 64, 64)[:n]]
                    )

            if i == 0 and save_results:
                if not os.path.exists("results"):
                    os.makedirs("results")
                save_image(
                    comp_data.cpu(),
                    "results/reconstruction_" + str(epoch) + ".png",
                    nrow=min(data.size(0), 20),
                )

    for i, model in enumerate(models):
        # average over the sampled images, not the full dataset
        test_loss[i] /= len(sampler)
        print(f"====> Test set loss model {model.bottleneck}: {test_loss[i]:.4f}")

    if save_results:
        plt.figure()
        bns = [m.bottleneck for m in models]
        plt.plot(bns, test_loss)
        plt.xticks(bns)
        plt.savefig("shape_complexity/results/vae_graph.png")
        plt.clf()
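
# What `CloseTransform` above does to a mask: morphological closing fills
# small holes and gaps before resizing. A minimal sketch on a toy tensor; the
# 3x3 kernel is an assumption (the transform itself defaults to 5x5), and
# `_closing_demo` is ours, not part of the pipeline.
def _closing_demo():
    x = torch.zeros(1, 1, 7, 7)
    x[..., 2:5, 2:5] = 1.0
    x[..., 3, 3] = 0.0  # punch a one-pixel hole into the square
    closed = closing(x, torch.ones(3, 3))
    assert closed[0, 0, 3, 3] == 1.0  # the hole is filled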
def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
    model.eval()
    image = transforms.F.to_tensor(transforms.F.to_grayscale(Image.open(path)))
    labels = find_components(image[0])
    single_masks = extract_single_masks(labels)
    mask = transforms.F.to_tensor(
        transforms.F.resize(
            transforms.F.to_pil_image((single_masks[label] * 255).astype(np.uint8)),
            (64, 64),
        )
    )

    with torch.no_grad():
        mask = mask.to(device)
        recon_x, _, _ = model(mask)
        recon_bits = recon_x.view(64, 64).cpu().numpy() > epsilon
        mask_bits = mask.cpu().numpy() > 0

        TP = (mask_bits & recon_bits).sum()
        FP = (recon_bits & ~mask_bits).sum()
        FN = (mask_bits & ~recon_bits).sum()

        prec = TP / (TP + FP)
        rec = TP / (TP + FN)
        # loss = pixelwise_loss(recon_x, mask)
        # cast the boolean reconstruction to float so all cat inputs share a dtype
        comp_data = torch.cat(
            [
                mask[0].cpu(),
                recon_x.view(64, 64).cpu(),
                torch.from_numpy(recon_bits).float(),
            ]
        )

    # print(f"mask loss: {loss:.4f}")
    return prec, rec, comp_data


def distance_measure(model: VAE, img: Tensor):
    model.eval()
    with torch.no_grad():
        mask = img.to(device)
        recon, mean, _ = model(mask)
        # TODO: apply threshold here?!
        _, recon_mean, _ = model(recon)
        distance = torch.norm(mean - recon_mean, p=2)

    return (
        distance,
        make_grid(torch.stack([mask[0], recon[0]]).cpu(), nrow=2, padding=0),
    )


def compression_complexity(img: Tensor):
    # bz2-compressed byte length of the raw mask as a complexity proxy
    np_img = img[0].numpy()
    compressed = compress(np_img)

    return len(compressed)


def fft_measure(img: Tensor):
    np_img = img[0][0].numpy()
    fft = np.fft.fft2(np_img)
    fft_abs = np.abs(fft)

    n = fft.shape[0]
    pos_f_idx = n // 2
    df = np.fft.fftfreq(n=n)

    amplitude_sum = fft_abs[:pos_f_idx, :pos_f_idx].sum()
    mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
    mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum

    mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2))

    # mean frequency normalized to [0, 1] (the Nyquist frequency is 0.5)
    return mean_freq / 0.5


def complexity_measure(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
    epsilon=0.4,
    save_preliminary=False,
):
    model_gb.eval()
    model_lb.eval()
    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)
        recon_bits_gb = recon_gb.view(-1, 64, 64).cpu() > epsilon
        recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
        mask_bits = mask[0].cpu() > 0

        if save_preliminary:
            save_image(
                torch.stack(
                    [mask_bits.float(), recon_bits_gb.float(), recon_bits_lb.float()]
                ).cpu(),
                f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )
            save_image(
                torch.stack(
                    [
                        (mask_bits & recon_bits_gb).float(),
                        (recon_bits_gb & ~mask_bits).float(),
                        (mask_bits & recon_bits_lb).float(),
                        (recon_bits_lb & ~mask_bits).float(),
                    ]
                ).cpu(),
                f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
            )

        tp_gb = (mask_bits & recon_bits_gb).sum()
        fp_gb = (recon_bits_gb & ~mask_bits).sum()
        tp_lb = (mask_bits & recon_bits_lb).sum()
        fp_lb = (recon_bits_lb & ~mask_bits).sum()

        prec_gb = tp_gb / (tp_gb + fp_gb)
        prec_lb = tp_lb / (tp_lb + fp_lb)

        # e.g. prec_gb = 0.4, prec_lb = 0.7: 1 - (0.4 - abs(0.4 - 0.7)) = 0.9
        # e.g. prec_gb = prec_lb = 0.7:      1 - 0.7 = 0.3
        complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
        complexity_lb = 1 - prec_lb
        complexity_gb = 1 - prec_gb

    return (
        complexity,
        complexity_lb,
        complexity_gb,
        prec_gb - prec_lb,
        prec_lb,
        prec_gb,
        make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        ),
    )


def mean_precision(models: list[nn.Module], img: Tensor, epsilon=0.4):
    mask = img.to(device)
    mask_bits = mask[0].cpu() > 0
    precisions = np.zeros(len(models))
    with torch.no_grad():
        for i, model in enumerate(models):
            model.eval()
            recon_gb, _, _ = model(mask)
            recon_bits = recon_gb.view(-1, 64, 64).cpu() > epsilon

            tp = (mask_bits & recon_bits).sum()
            fp = (recon_bits & ~mask_bits).sum()
            prec = tp / (tp + fp)
            precisions[i] = prec

    return 1 - precisions.mean()
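
# Intuition check for `fft_measure` above: a coarse solid blob concentrates
# spectral energy at low frequencies and should score lower than a fine
# stripe pattern. A sketch; `_fft_measure_demo` is ours, not called anywhere.
def _fft_measure_demo():
    blob = torch.zeros(1, 1, 64, 64)
    blob[0, 0, 16:48, 16:48] = 1.0

    stripes = torch.zeros(1, 1, 64, 64)
    stripes[0, 0, :, ::4] = 1.0  # vertical stripes with period 4

    assert fft_measure(blob) < fft_measure(stripes)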
def complexity_measure_diff(
    model_gb: nn.Module,
    model_lb: nn.Module,
    img: Tensor,
):
    model_gb.eval()
    model_lb.eval()
    with torch.no_grad():
        mask = img.to(device)
        recon_gb, _, _ = model_gb(mask)
        recon_lb, _, _ = model_lb(mask)
        diff = torch.abs((recon_gb - recon_lb).cpu().sum())

    return (
        diff,
        make_grid(
            torch.stack(
                [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
            ).cpu(),
            nrow=3,
            padding=0,
        ),
    )


def plot_samples(masks: Tensor, complexities: npt.NDArray):
    dpi = 150
    rows = cols = 20
    total = rows * cols
    n_samples, _, y, x = masks.shape
    extent = (0, x - 1, 0, y - 1)
    if total != n_samples:
        raise Exception("shape mismatch")

    fig = plt.figure(figsize=(32, 16), dpi=dpi)
    for idx in np.arange(n_samples):
        ax = fig.add_subplot(rows, cols, idx + 1, xticks=[], yticks=[])
        plt.imshow(masks[idx][0], cmap=plt.cm.gray, extent=extent)
        ax.set_title(
            f"{complexities[idx]:.4f}",
            fontdict={"fontsize": 6, "color": "orange"},
            y=0.35,
        )

    fig.patch.set_facecolor("#292929")
    height_px = y * rows
    width_px = x * cols
    fig.set_size_inches(width_px / (dpi / 2), height_px / (dpi / 2), forward=True)
    fig.tight_layout(pad=0)

    return fig


def visualize_sort_mean(data_loader: DataLoader, model: VAE):
    recon_masks = torch.zeros((400, 3, 64, 128))
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        distance, mask_recon_grid = distance_measure(model, mask)
        masks[i] = mask[0]
        recon_masks[i] = mask_recon_grid
        distances[i] = distance

    sort_idx = torch.argsort(distances)
    recon_masks_sorted = recon_masks.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("latent mean L2 distance")
    plt.savefig("shape_complexity/results/distance_plot.png")
    plt.clf()  # reset the current figure so later line plots do not pile up

    return (
        plot_samples(masks_sorted, distances.numpy()[sort_idx]),
        plot_samples(recon_masks_sorted, distances.numpy()[sort_idx]),
    )


def visualize_sort_compression(data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = compression_complexity(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("compression length")
    plt.savefig("shape_complexity/results/compression_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])


def visualize_sort_fft(data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    distances = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        distances[i] = fft_measure(mask)

    sort_idx = torch.argsort(distances)
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(distances)), np.sort(distances.numpy()))
    plt.xlabel("images")
    plt.ylabel("mean unidirectional frequency")
    plt.savefig("shape_complexity/results/fft_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, distances.numpy()[sort_idx])


def visualize_sort_mean_precision(models: list[nn.Module], data_loader: DataLoader):
    masks = torch.zeros((400, 1, 64, 64))
    precisions = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        masks[i] = mask[0]
        precisions[i] = mean_precision(models, mask)

    sort_idx = torch.argsort(precisions)
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(precisions)), np.sort(precisions.numpy()))
    plt.xlabel("images")
    plt.ylabel("mean precision")
    plt.savefig("shape_complexity/results/mean_prec_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, precisions.numpy()[sort_idx])
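
# `compression_complexity` above uses the bz2-compressed byte length of the
# raw mask as a complexity proxy: regular shapes compress well, noisy ones do
# not. A sketch; `_compression_demo` is a hypothetical helper of ours.
def _compression_demo():
    square = torch.zeros(1, 1, 64, 64)
    square[0, 0, 16:48, 16:48] = 1.0  # one solid block, highly compressible

    noise = (torch.rand(1, 1, 64, 64) > 0.5).float()  # near-incompressible

    assert compression_complexity(square) < compression_complexity(noise)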
def visualize_sort_diff(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    diffs = torch.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        diff, mask_recon_grid = complexity_measure_diff(model_gb, model_lb, mask)
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff

    sort_idx = np.argsort(np.array(diffs))
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(diffs)), np.sort(diffs))
    plt.xlabel("images")
    plt.ylabel("pixelwise difference of reconstructions")
    plt.savefig("shape_complexity/results/px_diff_plot.png")
    plt.clf()

    return (
        plot_samples(masks_sorted, diffs[sort_idx]),
        plot_samples(recon_masks_sorted, diffs[sort_idx]),
    )


def visualize_sort_3dim(
    data_loader: DataLoader, model_gb: nn.Module, model_lb: nn.Module
):
    masks_recon = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    measures = torch.zeros((400, 3))
    for i, (mask, _) in enumerate(data_loader, 0):
        c_compress = compression_complexity(mask)
        c_fft = fft_measure(mask)
        # TODO: maybe exchange by diff or mean measure instead of precision
        c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask
        )
        masks_recon[i] = mask_recon_grid
        masks[i] = mask[0]
        measures[i] = torch.tensor([c_compress, c_fft, c_vae])

    measures[:] /= measures.max(dim=0).values
    measure_norm = torch.linalg.vector_norm(measures, dim=1)

    fig = plt.figure()
    fig.clf()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(measures[:, 0], measures[:, 1], measures[:, 2], marker="o")
    ax.set_xlabel("bz2 compression")  # label matches the bz2 import used above
    ax.set_ylabel("FFT ratio")
    ax.set_zlabel(f"VAE ratio {model_gb.bottleneck}/{model_lb.bottleneck}")
    plt.savefig("shape_complexity/results/3d_plot.png")
    plt.close()

    sort_idx = np.argsort(np.array(measure_norm))
    recon_masks_sorted = masks_recon.numpy()[sort_idx]
    masks_sorted = masks.numpy()[sort_idx]

    return (
        plot_samples(masks_sorted, measure_norm[sort_idx]),
        plot_samples(recon_masks_sorted, measure_norm[sort_idx]),
    )


def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    diffs = []
    for i, (mask, _) in enumerate(data_loader, 0):
        # complexity_measure returns a 7-tuple; unpack all of it
        complexity, _, _, diff, _, _, mask_recon_grid = complexity_measure(
            model_gb, model_lb, mask, save_preliminary=True
        )
        masks[i] = mask_recon_grid
        diffs.append(diff)
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]

    plt.plot(np.arange(len(diffs)), np.sort(diffs))
    plt.xlabel("images")
    plt.ylabel("prec difference (L-H)")
    plt.savefig("shape_complexity/results/diff_plot.png")
    plt.clf()

    return plot_samples(masks_sorted, complexities[sort_idx])
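
# The score in `complexity_measure` combines the high-bottleneck precision
# with the gap to the low-bottleneck one:
#     complexity = 1 - (prec_gb - |prec_gb - prec_lb|)
# Worked example (same numbers as the inline comment there): prec_gb = 0.4,
# prec_lb = 0.7 gives 1 - (0.4 - 0.3) = 0.9, while prec_gb = prec_lb = 0.7
# gives 1 - 0.7 = 0.3. A small illustrative helper of ours, unused elsewhere:
def _vae_complexity(prec_gb: float, prec_lb: float) -> float:
    return 1 - (prec_gb - abs(prec_gb - prec_lb))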
def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    masks = torch.zeros((400, 3, 64, 192))
    complexities = torch.zeros((400,))
    complexities_lb = torch.zeros((400,))
    complexities_gb = torch.zeros((400,))
    diffs = []
    prec_lbs = []
    prec_gbs = []
    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            lb,
            gb,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        masks[i] = mask_recon_grid
        diffs.append(diff)
        prec_lbs.append(prec_lb)
        prec_gbs.append(prec_gb)
        complexities[i] = complexity
        complexities_lb[i] = lb
        complexities_gb[i] = gb

    sort_idx = np.argsort(np.array(complexities))
    sort_idx_lb = np.argsort(np.array(complexities_lb))
    sort_idx_gb = np.argsort(np.array(complexities_gb))
    masks_sorted = masks.numpy()[sort_idx]
    masks_sorted_lb = masks.numpy()[sort_idx_lb]
    masks_sorted_gb = masks.numpy()[sort_idx_gb]

    diff_sort_idx = np.argsort(diffs)

    # plt.savefig("shape_complexity/results/diff_plot.png")
    # plt.clf

    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax1.plot(
        np.arange(len(prec_lbs)),
        np.array(prec_lbs)[diff_sort_idx],
        label=f"bottleneck {model_lb.bottleneck}",
    )
    ax1.plot(
        np.arange(len(prec_gbs)),
        np.array(prec_gbs)[diff_sort_idx],
        label=f"bottleneck {model_gb.bottleneck}",
    )
    ax2.plot(
        np.arange(len(diffs)),
        np.sort(diffs),
        color="red",
        label="prec difference (H - L)",
    )
    ax1.legend(loc="lower left")
    ax2.legend(loc="lower right")
    ax1.set_ylabel("precision")
    ax2.set_ylabel("prec difference (H-L)")
    plt.savefig("shape_complexity/results/prec_plot.png")
    plt.clf()

    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted_lb, complexities_lb[sort_idx_lb])
    fig.savefig("shape_complexity/results/lb.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted_gb, complexities_gb[sort_idx_gb])
    fig.savefig("shape_complexity/results/gb.png")
    plt.close(fig)
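
# The commented-out grouping in `visualize_sort_group` below bins the
# precision differences with np.digitize. A minimal sketch of how those bin
# edges behave (values <= 0, in (0, 0.05], and > 0.05) under the same
# right=True convention; `_digitize_demo` is ours and is never called.
def _digitize_demo():
    diffs = np.array([-0.02, 0.0, 0.03, 0.4])
    bins = np.digitize(diffs, bins=[-np.inf, 0.0, 0.05, np.inf], right=True)
    assert list(bins) == [1, 1, 2, 3]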
def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
    recon_masks = torch.zeros((400, 3, 64, 192))
    masks = torch.zeros((400, 1, 64, 64))
    complexities = torch.zeros((400,))
    diffs = np.zeros((400,))
    prec_gbs = np.zeros((400,))
    prec_lbs = np.zeros((400,))
    for i, (mask, _) in enumerate(data_loader, 0):
        (
            complexity,
            _,
            _,
            diff,
            prec_lb,
            prec_gb,
            mask_recon_grid,
        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
        recon_masks[i] = mask_recon_grid
        masks[i] = mask[0]
        diffs[i] = diff
        prec_gbs[i] = prec_gb
        prec_lbs[i] = prec_lb
        complexities[i] = complexity

    sort_idx = np.argsort(np.array(complexities))
    masks_sorted = masks.numpy()[sort_idx]
    recon_masks_sorted = recon_masks.numpy()[sort_idx]

    # group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
    # bin_edges = [-np.inf, 0.0, 0.05, np.inf]
    # bins = np.digitize(diffs, bins=bin_edges, right=True)
    # for i in range(bins.min(), bins.max() + 1):
    #     bin_idx = bins == i
    #     binned_prec_gb = prec_gbs[bin_idx]
    #     prec_mean = binned_prec_gb.mean()
    #     prec_idx = prec_gbs > prec_mean
    #     binned_masks_high = recon_masks[bin_idx & prec_idx]
    #     binned_masks_low = recon_masks[bin_idx & ~prec_idx]
    #     save_image(
    #         binned_masks_high,
    #         f"shape_complexity/results/diff_{group_labels[i-1]}_high.png",
    #         padding=10,
    #     )
    #     save_image(
    #         binned_masks_low,
    #         f"shape_complexity/results/diff_{group_labels[i-1]}_low.png",
    #         padding=10,
    #     )

    # diff_sort_idx = np.argsort(diffs)
    # fig, ax1 = plt.subplots()
    # ax2 = ax1.twinx()
    # ax1.plot(
    #     np.arange(len(prec_lbs)),
    #     np.array(prec_lbs)[diff_sort_idx],
    #     label=f"bottleneck {model_lb.bottleneck}",
    # )
    # ax1.plot(
    #     np.arange(len(prec_gbs)),
    #     np.array(prec_gbs)[diff_sort_idx],
    #     label=f"bottleneck {model_gb.bottleneck}",
    # )
    # ax2.plot(
    #     np.arange(len(diffs)),
    #     np.sort(diffs),
    #     color="red",
    #     label="prec difference (H - L)",
    # )
    # ax1.legend(loc="lower left")
    # ax2.legend(loc="lower right")
    # ax1.set_ylabel("precision")
    # ax2.set_ylabel("prec difference (H-L)")
    # ax1.set_xlabel("image")
    # plt.savefig("shape_complexity/results/prec_plot.png")
    # plt.tight_layout(pad=2)
    # plt.clf()

    fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs_recon.png")
    plt.close(fig)
    fig = plot_samples(masks_sorted, complexities[sort_idx])
    fig.savefig("shape_complexity/results/abs.png")
    plt.close(fig)


LR = 1e-3
EPOCHS = 10
LOAD_PRETRAINED = True


def main():
    bottlenecks = [4, 8, 16, 32]
    models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
    optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}

    data_loader, dataset = load_data()
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, test_size]
    )
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

    if LOAD_PRETRAINED:
        for i, model in models.items():
            model.load_state_dict(
                torch.load(f"shape_complexity/trained/CONVVAE_{i}_split_data.pth")
            )
    else:
        for epoch in range(EPOCHS):
            for i, model in models.items():
                train(
                    epoch,
                    model=model,
                    optimizer=optimizers[i],
                    data_loader=train_loader,
                )
            test(epoch, models=list(models.values()), dataset=test_dataset)

        for bn in bottlenecks:
            if not os.path.exists("shape_complexity/trained"):
                os.makedirs("shape_complexity/trained")
            torch.save(
                models[bn].state_dict(),
                f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth",
            )

    test(0, models=list(models.values()), dataset=test_dataset, save_results=True)

    bn_gt = 32
    bn_lt = 8

    # for i in range(10):
    #     figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
    #     figure.savefig(
    #         f"shape_complexity/results/this_{bn_gt}_to_{bn_lt}_sample{i}.png"
    #     )
    #     figure.clear()
    #     plt.close(figure)

    # figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
    # figure.savefig(f"shape_complexity/results/sort_{bn_gt}_to_{bn_lt}.png")

    sampler = RandomSampler(dataset, replacement=True, num_samples=400)
    data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)

    visualize_sort_group(data_loader, models[bn_gt], models[bn_lt])
    # visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])

    fig, fig_recon = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt])
    fig.savefig("shape_complexity/results/sort_comp_fft_prec.png")
    fig_recon.savefig("shape_complexity/results/recon_sort_comp_fft_prec.png")
    plt.close(fig)
    plt.close(fig_recon)

    fig = visualize_sort_mean_precision(list(models.values()), data_loader)
    fig.savefig("shape_complexity/results/sort_mean_prec.png")
    plt.close(fig)

    fig = visualize_sort_fft(data_loader)
    fig.savefig("shape_complexity/results/sort_fft.png")
    plt.close(fig)

    fig = visualize_sort_compression(data_loader)
    fig.savefig("shape_complexity/results/sort_compression.png")
    plt.close(fig)

    fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
    fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
    fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")
    plt.close(fig)
    plt.close(fig_recon)

    # fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt])
    # fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png")
    # fig_recon.savefig(
    #     f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png"
    # )
    # plt.close(fig)
    # plt.close(fig_recon)


if __name__ == "__main__":
    main()
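
# Smoke-test sketch (assumes a CUDA device, like the rest of this script):
# score one random mask with two untrained CONVVAEs. Untrained weights give
# meaningless scores; this only exercises the plumbing.
#
#     model_gb = CONVVAE(bottleneck=32).to(device)
#     model_lb = CONVVAE(bottleneck=8).to(device)
#     mask = (torch.rand(1, 1, 64, 64) > 0.5).float()
#     complexity, *_ = complexity_measure(model_gb, model_lb, mask)
#     print(f"complexity: {float(complexity):.4f}")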