From 293b9e2e50ec6d0b2f81ad870b460fb37c06bc57 Mon Sep 17 00:00:00 2001 From: markus rothgaenger <mrothgaenger@techfak.uni-bielefeld.de> Date: Thu, 9 Jun 2022 01:16:44 +0200 Subject: [PATCH] wip, rework --- shape_complexity/shape_complexity.py | 198 +++++++++++++------- shape_complexity/shape_complexity_VAE.ipynb | 4 +- shape_complexity/transform_dino_data.py | 114 +++++++++++ visualize_results/complexity.py | 0 visualize_results/data.py | 99 ++++++++++ visualize_results/main.py | 119 ++++++++++++ visualize_results/models.py | 113 +++++++++++ visualize_results/utils.py | 92 +++++++++ 8 files changed, 667 insertions(+), 72 deletions(-) create mode 100644 shape_complexity/transform_dino_data.py create mode 100644 visualize_results/complexity.py create mode 100644 visualize_results/data.py create mode 100644 visualize_results/main.py create mode 100644 visualize_results/models.py create mode 100644 visualize_results/utils.py diff --git a/shape_complexity/shape_complexity.py b/shape_complexity/shape_complexity.py index 93a3b01..e914793 100644 --- a/shape_complexity/shape_complexity.py +++ b/shape_complexity/shape_complexity.py @@ -62,7 +62,8 @@ def find_components(mask: npt.NDArray): # https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array -def bbox(img): +def bbox(img: Tensor): + img = img.numpy() max_x, max_y = img.shape rows = np.any(img, axis=1) cols = np.any(img, axis=0) @@ -249,6 +250,64 @@ class CONVVAE(nn.Module): return BCE + KLD +class BBoxTransform: + squared: bool = False + + def __init__(self, squared: bool = None) -> None: + if squared is not None: + self.squared = squared + + def __call__(self, img: any) -> Tensor: + img = transforms.F.to_tensor(img).squeeze(dim=0) + ymin, ymax, xmin, xmax = bbox(img) + if not self.squared: + return transforms.F.to_pil_image(img[ymin:ymax, xmin:xmax].unsqueeze(dim=0)) + + max_dim = (ymin, ymax) if ymax - ymin > xmax - xmin else (xmin, xmax) + n = max_dim[1] - max_dim[0] + if n % 2 != 0: + n += 1 + + n_med = np.round(n / 2) + + ymedian = np.round(ymin + (ymax - ymin) / 2) + xmedian = np.round(xmin + (xmax - xmin) / 2) + + M, N = img.shape + + ycutmin, ycutmax = int(ymedian - n_med if ymedian >= n_med else 0), int( + ymedian + n_med if ymedian + n_med <= M else M + ) + + xcutmin, xcutmax = int(xmedian - n_med if xmedian >= n_med else 0), int( + xmedian + n_med if xmedian + n_med <= N else N, + ) + + if (ycutmax - ycutmin) % 2 != 0: + ycutmin += 1 + if (xcutmax - xcutmin) % 2 != 0: + xcutmin += 1 + + squared_x = np.zeros((n, n)) + squared_cut_y = np.round((ycutmax - ycutmin) / 2) + squared_cut_x = np.round((xcutmax - xcutmin) / 2) + + dest_ymin, dest_ymax = int(n_med - squared_cut_y), int(n_med + squared_cut_y) + dest_xmin, dest_xmax = int(n_med - squared_cut_x), int(n_med + squared_cut_x) + + # print(ycutmin, ycutmax, ycutmax - ycutmin) + # print(dest_ymin, dest_ymax, squared_cut_y + squared_cut_y) + # print(xcutmin, xcutmax, xcutmax - xcutmin) + # print(dest_xmin, dest_xmax, squared_cut_x + squared_cut_x) + + squared_x[ + dest_ymin:dest_ymax, + dest_xmin:dest_xmax, + ] = img[ycutmin:ycutmax, xcutmin:xcutmax] + + return transforms.F.to_pil_image(torch.from_numpy(squared_x).unsqueeze(dim=0)) + + class CloseTransform: kernel = torch.ones(5, 5) @@ -311,8 +370,9 @@ def load_mpeg7_data(): [ transforms.Grayscale(), # transforms.RandomApply([CloseTransform()], p=0.25), + BBoxTransform(squared=True), transforms.Resize( - (64, 64), interpolation=transforms.InterpolationMode.BILINEAR + (64, 64), 
interpolation=transforms.InterpolationMode.NEAREST ), transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip(), @@ -326,6 +386,27 @@ def load_mpeg7_data(): return data_loader, dataset +def load_dino_data(): + transform = transforms.Compose( + [ + transforms.Grayscale(), + # transforms.RandomApply([CloseTransform()], p=0.25), + BBoxTransform(squared=True), + transforms.Resize( + (64, 64), interpolation=transforms.InterpolationMode.NEAREST + ), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.ToTensor(), + ] + ) + + dataset = ImageFolder("shape_complexity/data/dino", transform=transform) + + data_loader = DataLoader(dataset, batch_size=128, shuffle=True) + return data_loader, dataset + + def load_data(): transform = transforms.Compose( [ @@ -453,39 +534,6 @@ def test(epoch, models: list[CONVVAE] or list[VAE], dataset, save_results=False) plt.clf() -def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4): - model.eval() - image = transforms.F.to_tensor(transforms.F.to_grayscale(Image.open(path))) - labels = find_components(image[0]) - single_masks = extract_single_masks(labels) - mask = transforms.F.to_tensor( - transforms.F.resize( - transforms.F.to_pil_image((single_masks[label] * 255).astype(np.uint8)), - (64, 64), - ) - ) - - with torch.no_grad(): - mask = mask.to(device) - recon_x, _, _ = model(mask) - recon_bits = recon_x.view(64, 64).cpu().numpy() > epsilon - mask_bits = mask.cpu().numpy() > 0 - - TP = (mask_bits & recon_bits).sum() - FP = (recon_bits & ~mask_bits).sum() - FN = (mask_bits & ~recon_bits).sum() - - prec = TP / (TP + FP) - rec = TP / (TP + FN) - # loss = pixelwise_loss(recon_x, mask) - comp_data = torch.cat( - [mask[0].cpu(), recon_x.view(64, 64).cpu(), torch.from_numpy(recon_bits)] - ) - # print(f"mask loss: {loss:.4f}") - - return prec, rec, comp_data - - def l2_distance_measure(img: Tensor, model: VAE): model.eval() @@ -792,7 +840,6 @@ def visualize_sort_3dim( for i, (mask, _) in enumerate(data_loader, 0): c_compress, _ = compression_measure(mask, fill_ratio_norm=True) c_fft, _ = fft_measure(mask) - # TODO: maybe exchange by diff or mean measure instead of precision c_vae, mask_recon_grid = complexity_measure( mask, model_gb, model_lb, fill_ratio_norm=True ) @@ -825,11 +872,8 @@ def visualize_sort_3dim( LR = 1.5e-3 -EPOCHS = 100 -LOAD_PRETRAINED = True - -# TODO: might be a good idea to implement a bbox cut preprocessing transform thingy -# TODO: try 2dim rep (fft, comp) (fft, vae) (comp, vae) +EPOCHS = 80 +LOAD_PRETRAINED = False def main(): @@ -839,6 +883,7 @@ def main(): # data_loader, dataset = load_data() data_loader, dataset = load_mpeg7_data() + data_loader, dataset = load_dino_data() train_size = int(0.8 * len(dataset)) test_size = len(dataset) - train_size @@ -894,6 +939,19 @@ def main(): ], max_norm=True, ) + visualize_sort_multidim( + data_loader, + [ + ( + "px_recon_comp", + pixelwise_complexity_measure, + [models[bn_gt], models[bn_lt], True], + ), + ("fft", fft_measure, []), + ("compression", compression_measure, [True]), + ], + max_norm=True, + ) visualize_sort_multidim( data_loader, [ @@ -927,35 +985,35 @@ def main(): max_norm=True, ) - # visualize_sort( - # data_loader, - # complexity_measure, - # "recon_complexity", - # models[bn_gt], - # models[bn_lt], - # fill_ratio_norm=True, - # ) - # visualize_sort( - # data_loader, - # pixelwise_complexity_measure, - # "px_recon_complexity", - # models[bn_gt], - # models[bn_lt], - # fill_ratio_norm=True, - # ) - - # visualize_sort(data_loader, 
mean_precision, "mean_precision", list(models.values())) - # visualize_sort(data_loader, fft_measure, "fft") - # visualize_sort(data_loader, compression_measure, "compression") - # visualize_sort( - # data_loader, - # compression_measure, - # "compression_fill_norm", - # fill_ratio_norm=True, - # ) - # visualize_sort( - # data_loader, l2_distance_measure, "latent_l2_distance", models[bn_gt] - # ) + visualize_sort( + data_loader, + complexity_measure, + "recon_complexity", + models[bn_gt], + models[bn_lt], + fill_ratio_norm=True, + ) + visualize_sort( + data_loader, + pixelwise_complexity_measure, + "px_recon_complexity", + models[bn_gt], + models[bn_lt], + fill_ratio_norm=True, + ) + + visualize_sort(data_loader, mean_precision, "mean_precision", list(models.values())) + visualize_sort(data_loader, fft_measure, "fft") + visualize_sort(data_loader, compression_measure, "compression") + visualize_sort( + data_loader, + compression_measure, + "compression_fill_norm", + fill_ratio_norm=True, + ) + visualize_sort( + data_loader, l2_distance_measure, "latent_l2_distance", models[bn_gt] + ) if __name__ == "__main__": diff --git a/shape_complexity/shape_complexity_VAE.ipynb b/shape_complexity/shape_complexity_VAE.ipynb index 266a55e..8b674c2 100644 --- a/shape_complexity/shape_complexity_VAE.ipynb +++ b/shape_complexity/shape_complexity_VAE.ipynb @@ -648,10 +648,10 @@ ], "metadata": { "interpreter": { - "hash": "84a83425d3adf844de89132bf0187dcd8fd164734ace6339df0fa8b55b713a8b" + "hash": "44317fb6075a6ba1e2c604c943b79dc9125a0abefce44942588ea31f355d03c2" }, "kernelspec": { - "display_name": "Python 3.9.9 ('minerl-indexing')", + "display_name": "Python 3.9.12 ('deeplearning')", "language": "python", "name": "python3" }, diff --git a/shape_complexity/transform_dino_data.py b/shape_complexity/transform_dino_data.py new file mode 100644 index 0000000..02ddf8b --- /dev/null +++ b/shape_complexity/transform_dino_data.py @@ -0,0 +1,114 @@ +from imageio import v3 as iio +import os +import numpy as np +import numpy.typing as npt +import sys + + +# reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation +# perform depth first search for each candidate/unlabeled region +dx = [+1, 0, -1, 0] +dy = [0, +1, 0, -1] + + +def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int): + n_rows, n_cols = mask.shape + if x < 0 or x == n_rows: + return + if y < 0 or y == n_cols: + return + if labels[x][y] or not mask[x][y]: + return # already labeled or not marked with 1 in image + + # mark the current cell + labels[x][y] = current_label + + # recursively mark the neighbors + for direction in range(4): + dfs(mask, x + dx[direction], y + dy[direction], labels, current_label) + + +def find_components(mask: npt.NDArray): + label = 0 + + n_rows, n_cols = mask.shape + labels = np.zeros(mask.shape, dtype=np.int8) + + for i in range(n_rows): + for j in range(n_cols): + if not labels[i][j] and mask[i][j]: + label += 1 + dfs(mask, i, j, labels, label) + + return labels + + +# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array +def bbox(img): + rows = np.any(img, axis=1) + cols = np.any(img, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + return rmin, rmax, cmin, cmax + + +def extract_single_masks(labels: npt.NDArray): + masks = [] + for l in range(1, labels.max() + 1): + # TODO: ignor label 0 ???? 
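+        # note: label 0 is the background and is already skipped here, since the range starts at 1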
+ mask = (labels == l).astype(np.int8) + max_x, max_y = mask.shape + pixel_sum = mask.sum() + if pixel_sum > 128 and pixel_sum < 512: + # print(pixel_sum) + rmin, rmax, cmin, cmax = bbox(mask) + rmin = rmin - 1 if rmin > 0 else rmin + cmin = cmin - 1 if cmin > 0 else cmin + rmax = rmax + 1 if rmax < max_x else rmax + cmax = cmax + 1 if cmax < max_y else cmax + + aspect_ratio = (rmax - rmin) / (cmax - cmin) + if aspect_ratio > 0.5 and aspect_ratio < 2.0: + masks.append(mask[rmin : rmax + 1, cmin : cmax + 1]) + else: + pass + # print("rejected due to aspect ratio") + + return masks + + +total = 0 +clusters = 5 + +sys.setrecursionlimit(1000000) + + +path = f"/home/markus/dev/dino/OutputDir" +output_path = f"/home/markus/uni/navigation_project/shape_complexity/data/dino" + +if not os.path.exists(output_path): + os.makedirs(output_path) + +for file in os.listdir(path): + fullpath = os.path.join(path, file) + if os.path.isdir(fullpath): + continue + + image = iio.imread(fullpath) + # image = image[0] + image = (image / (255 / (clusters - 1))).astype(np.int8) + + for i in range(clusters): + cluster_image = (image == i).astype(np.int8) + + labels = find_components(cluster_image) + single_masks = extract_single_masks(labels) + for l, mask in enumerate(single_masks): + total += 1 + iio.imwrite( + os.path.join(output_path, f"{file.replace('.jpg', '')}_{i}_{l}.png"), + (mask * 255).astype(np.uint8), + ) + +print(f"total masks: {total}") diff --git a/visualize_results/complexity.py b/visualize_results/complexity.py new file mode 100644 index 0000000..e69de29 diff --git a/visualize_results/data.py b/visualize_results/data.py new file mode 100644 index 0000000..dbc8be4 --- /dev/null +++ b/visualize_results/data.py @@ -0,0 +1,99 @@ +import numpy as np +import torch +from kornia.morphology import closing +from torch import Tensor +from torchvision import transforms + +from visualize_results.utils import bbox + + +class BBoxTransform: + squared: bool = False + + def __init__(self, squared: bool = None) -> None: + if squared is not None: + self.squared = squared + + def __call__(self, img: any) -> Tensor: + img = transforms.F.to_tensor(img).squeeze(dim=0) + ymin, ymax, xmin, xmax = bbox(img) + if not self.squared: + return transforms.F.to_pil_image(img[ymin:ymax, xmin:xmax].unsqueeze(dim=0)) + + max_dim = (ymin, ymax) if ymax - ymin > xmax - xmin else (xmin, xmax) + n = max_dim[1] - max_dim[0] + if n % 2 != 0: + n += 1 + + n_med = np.round(n / 2) + + ymedian = np.round(ymin + (ymax - ymin) / 2) + xmedian = np.round(xmin + (xmax - xmin) / 2) + + M, N = img.shape + + ycutmin, ycutmax = int(ymedian - n_med if ymedian >= n_med else 0), int( + ymedian + n_med if ymedian + n_med <= M else M + ) + + xcutmin, xcutmax = int(xmedian - n_med if xmedian >= n_med else 0), int( + xmedian + n_med if xmedian + n_med <= N else N, + ) + + if (ycutmax - ycutmin) % 2 != 0: + ycutmin += 1 + if (xcutmax - xcutmin) % 2 != 0: + xcutmin += 1 + + squared_x = np.zeros((n, n)) + squared_cut_y = np.round((ycutmax - ycutmin) / 2) + squared_cut_x = np.round((xcutmax - xcutmin) / 2) + + dest_ymin, dest_ymax = int(n_med - squared_cut_y), int(n_med + squared_cut_y) + dest_xmin, dest_xmax = int(n_med - squared_cut_x), int(n_med + squared_cut_x) + + # print(ycutmin, ycutmax, ycutmax - ycutmin) + # print(dest_ymin, dest_ymax, squared_cut_y + squared_cut_y) + # print(xcutmin, xcutmax, xcutmax - xcutmin) + # print(dest_xmin, dest_xmax, squared_cut_x + squared_cut_x) + + squared_x[ + dest_ymin:dest_ymax, + dest_xmin:dest_xmax, + ] = 
img[ycutmin:ycutmax, xcutmin:xcutmax] + + return transforms.F.to_pil_image(torch.from_numpy(squared_x).unsqueeze(dim=0)) + + +class CloseTransform: + kernel = torch.ones(5, 5) + + def __init__(self, kernel=None): + if kernel is not None: + self.kernel = kernel + + def __call__(self, x): + x = transforms.F.to_tensor(x) + + if len(x.shape) < 4: + return transforms.F.to_pil_image( + closing(x.unsqueeze(dim=0), self.kernel).squeeze(dim=0) + ) + + return transforms.F.to_pil_image(closing(x, self.kernel)) + + +def get_dino_transforms(): + return transforms.Compose( + [ + transforms.Grayscale(), + # transforms.RandomApply([CloseTransform()], p=0.25), + BBoxTransform(squared=True), + transforms.Resize( + (64, 64), interpolation=transforms.InterpolationMode.NEAREST + ), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.ToTensor(), + ] + ) diff --git a/visualize_results/main.py b/visualize_results/main.py new file mode 100644 index 0000000..20a5aea --- /dev/null +++ b/visualize_results/main.py @@ -0,0 +1,119 @@ +import glob +import os +import sys +from typing import Generator +from venv import create + +import numpy as np +from PIL import Image as img +from PIL.Image import Image +from torchvision.transforms import transforms + +from visualize_results.data import get_dino_transforms +from visualize_results.utils import find_components, natsort + +sys.setrecursionlimit(1000000) +n_clusters = 5 + + +def create_vis(n_imgs, layer_range=range(12)) -> Generator[Image, Image, None]: + """ + yields three channel PIL image + + credit @wlad + """ + inp = "./OutputDir" + os.makedirs("kmeanimgs", exist_ok=True) + rltrenner = img.open("./data/rltrenner.png") + rltrenner = rltrenner.convert("RGB") + + yseperator = 20 + print("Creating images") + for idx, numimg in enumerate(range(n_imgs)): # TODO: add tqdm again.. 
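+        # hypothetical: the progress bar could be restored via enumerate(tqdm(range(n_imgs))) together with "from tqdm import tqdm"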
+ imagesvert = [] + name = "img" + str(numimg) + + aimg = img.open("OutputDir/" + name + ".png") + aimg = aimg.convert("RGB") + + for depth in layer_range: + imagesinline = [aimg] + for attention_type in ["q", "k", "v"]: + if attention_type is not "v": + imagesinline.append(rltrenner) + + templist = glob.glob( + os.path.join( + inp, f"{name}{attention_type}depth{str(depth)}head*.png" + ) + ).sort(key=natsort) + + for timg in templist: + timg = img.open(timg) + + timg = yield timg + yield + + timg = timg.convert("RGB") + + timg = timg.resize((480, 480), resample=img.NEAREST) + imagesinline.append(timg) + + # yield imagesinline + # imagesinline.insert(0, aimg) + + widths, heights = zip(*(i.size for i in imagesinline)) + + total_width = sum(widths) + max_height = max(heights) + + vertical_img = img.new("RGB", (total_width, max_height)) + + x_offset = 0 + for im in imagesinline: + vertical_img.paste(im, (x_offset, 0)) + x_offset += im.size[0] + imagesvert.append(vertical_img.convert("RGB")) + + widths, heights = zip(*(i.size for i in imagesvert)) + + total_height = sum(heights) + (len(layer_range) * yseperator) + max_width = max(widths) + + final_img = img.new("RGB", (max_width, total_height)) + + y_offset = 0 + for im in imagesvert: + final_img.paste(im, (0, y_offset)) + y_offset += im.size[1] + yseperator + + final_img.save(os.path.join("kmeanimgs/", "img" + str(idx) + ".png")) + + +def main(): + img_transformer = get_dino_transforms() + image_generator = create_vis() + + for img in image_generator: + np_img = np.array(img) + cluster_img = (np_img / (255 / (n_clusters - 1))).astype(np.int8) + + for i in range(n_clusters): + cluster_image = (cluster_img == i).astype(np.int8) + + labels = find_components(cluster_image) + for l in range(1, labels.max() + 1): + mask = (labels == l).astype(np.int8) + normalized_img = img_transformer(transforms.F.to_pil_image(mask)) + # TODO: calculate complexity for normalized image.. 
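+                # e.g. a compression- or FFT-based measure from shape_complexity.py could be computed on normalized_img here (a sketch follows the patch)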
+ complexity = 1.0 + np_img[labels == l] = complexity + + # TODO: return manipulated PIL img / modify `img` + image_generator.send(transforms.F.to_pil_image(np_img)) + + continue + + +if __name__ == "__main__": + main() diff --git a/visualize_results/models.py b/visualize_results/models.py new file mode 100644 index 0000000..54915a0 --- /dev/null +++ b/visualize_results/models.py @@ -0,0 +1,113 @@ +from torch import Tensor +import torch +import torch.nn as nn +from torch import functional as F + + +class CONVVAE(nn.Module): + def __init__( + self, + bottleneck=2, + ): + super(CONVVAE, self).__init__() + + self.bottleneck = bottleneck + self.feature_dim = 6 * 6 * 64 + + self.conv1 = nn.Sequential( + nn.Conv2d(1, 16, 5), + nn.ReLU(), + nn.MaxPool2d((2, 2)), # -> 30x30x16 + ) + self.conv2 = nn.Sequential( + nn.Conv2d(16, 32, 3), + nn.ReLU(), + nn.MaxPool2d((2, 2)), # -> 14x14x32 + ) + self.conv3 = nn.Sequential( + nn.Conv2d(32, 64, 3), + nn.ReLU(), + nn.MaxPool2d((2, 2)), # -> 6x6x64 + ) + # self.conv4 = nn.Sequential( + # nn.Conv2d(32, self.bottleneck, 5), + # nn.ReLU(), + # nn.MaxPool2d((2, 2), return_indices=True), # -> 1x1xbottleneck + # ) + + self.encode_mu = nn.Sequential( + nn.Flatten(), + nn.Linear(self.feature_dim, self.bottleneck), + ) + self.encode_logvar = nn.Sequential( + nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck) + ) + + self.decode_linear = nn.Linear(self.bottleneck, self.feature_dim) + + # self.decode4 = nn.Sequential( + # nn.ConvTranspose2d(self.bottleneck, 32, 5), + # nn.ReLU(), + # ) + self.decode3 = nn.Sequential( + nn.ConvTranspose2d(64, 64, 2, stride=2), # -> 12x12x64 + nn.ReLU(), + nn.ConvTranspose2d(64, 32, 3), # -> 14x14x32 + nn.ReLU(), + ) + self.decode2 = nn.Sequential( + nn.ConvTranspose2d(32, 32, 2, stride=2), # -> 28x28x32 + nn.ReLU(), + nn.ConvTranspose2d(32, 16, 3), # -> 30x30x16 + nn.ReLU(), + ) + self.decode1 = nn.Sequential( + nn.ConvTranspose2d(16, 16, 2, stride=2), # -> 60x60x16 + nn.ReLU(), + nn.ConvTranspose2d(16, 1, 5), # -> 64x64x1 + nn.Sigmoid(), + ) + + def encode(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + # x, idx4 = self.conv4(x) + mu = self.encode_mu(x) + logvar = self.encode_logvar(x) + + return mu, logvar + + def decode(self, z: Tensor): + z = self.decode_linear(z) + z = z.view((-1, 64, 6, 6)) + # z = F.max_unpool2d(z, idx4, (2, 2)) + # z = self.decode4(z) + z = self.decode3(z) + z = self.decode2(z) + z = self.decode1(z) + # z = z.view(-1, 128, 1, 1) + # return self.decode_conv(z) + return z + + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + + def forward(self, x): + mu, logvar = self.encode(x) + z = self.reparameterize(mu, logvar) + return self.decode(z), mu, logvar + + def loss(self, recon_x, x, mu, logvar): + """https://github.com/pytorch/examples/blob/main/vae/main.py""" + BCE = F.binary_cross_entropy(recon_x, x, reduction="sum") + + # see Appendix B from VAE paper: + # Kingma and Welling. Auto-Encoding Variational Bayes. 
ICLR, 2014 + # https://arxiv.org/abs/1312.6114 + # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) + KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + + return BCE + KLD diff --git a/visualize_results/utils.py b/visualize_results/utils.py new file mode 100644 index 0000000..b5bb453 --- /dev/null +++ b/visualize_results/utils.py @@ -0,0 +1,92 @@ +from imageio import v3 as iio +import os +import numpy as np +import numpy.typing as npt +import sys +import re + + +def natsort(s, _nsre=re.compile("([0-9]+)")): + return [ + int(text) if text.isdigit() else text.lower() for text in _nsre.split(str(s)) + ] + + +# reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation +# perform depth first search for each candidate/unlabeled region +dx = [+1, 0, -1, 0] +dy = [0, +1, 0, -1] + + +def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int): + n_rows, n_cols = mask.shape + if x < 0 or x == n_rows: + return + if y < 0 or y == n_cols: + return + if labels[x][y] or not mask[x][y]: + return # already labeled or not marked with 1 in image + + # mark the current cell + labels[x][y] = current_label + + # recursively mark the neighbors + for direction in range(4): + dfs(mask, x + dx[direction], y + dy[direction], labels, current_label) + + +def find_components(mask: npt.NDArray): + label = 0 + + n_rows, n_cols = mask.shape + labels = np.zeros(mask.shape, dtype=np.int8) + + for i in range(n_rows): + for j in range(n_cols): + if not labels[i][j] and mask[i][j]: + label += 1 + dfs(mask, i, j, labels, label) + + return labels + + +# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array +def bbox(img): + img = img.numpy() + max_x, max_y = img.shape + rows = np.any(img, axis=1) + cols = np.any(img, axis=0) + rmin, rmax = np.where(rows)[0][[0, -1]] + cmin, cmax = np.where(cols)[0][[0, -1]] + + rmin = rmin - 1 if rmin > 0 else rmin + cmin = cmin - 1 if cmin > 0 else cmin + rmax = rmax + 1 if rmax < max_x else rmax + cmax = cmax + 1 if cmax < max_y else cmax + + return rmin, rmax, cmin, cmax + + +def extract_single_masks(labels: npt.NDArray): + masks = [] + for l in range(1, labels.max() + 1): + # TODO: ignore label 0 ???? + mask = (labels == l).astype(np.int8) + max_x, max_y = mask.shape + pixel_sum = mask.sum() + if pixel_sum > 128 and pixel_sum < 512: + # print(pixel_sum) + rmin, rmax, cmin, cmax = bbox(mask) + rmin = rmin - 1 if rmin > 0 else rmin + cmin = cmin - 1 if cmin > 0 else cmin + rmax = rmax + 1 if rmax < max_x else rmax + cmax = cmax + 1 if cmax < max_y else cmax + + aspect_ratio = (rmax - rmin) / (cmax - cmin) + if aspect_ratio > 0.5 and aspect_ratio < 2.0: + masks.append(mask[rmin : rmax + 1, cmin : cmax + 1]) + else: + pass + # print("rejected due to aspect ratio") + + return masks -- GitLab
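
In visualize_results/main.py the complexity value is still the 1.0 placeholder behind the TODO. Below is a minimal sketch of a compression-ratio measure that could fill it, assuming a binary {0, 1} mask as input; the function name, the zlib choice, and the fill-ratio normalization are illustrative assumptions, not the compression_measure implementation used in shape_complexity.py.

    import zlib

    import numpy as np
    import numpy.typing as npt


    def compression_complexity(mask: npt.NDArray, fill_ratio_norm: bool = False) -> float:
        """Estimate shape complexity as how poorly the binary mask compresses."""
        raw = np.ascontiguousarray(mask.astype(np.uint8)).tobytes()
        # higher ratio -> less compressible -> more complex
        ratio = len(zlib.compress(raw, 9)) / len(raw)
        if fill_ratio_norm:
            # assumption: normalize by the foreground fraction (complexity per unit of fill)
            fill = mask.sum() / mask.size
            ratio /= max(fill, 1e-6)
        return ratio

In the main() loop this could be called as complexity = compression_complexity(normalized_img.squeeze(0).numpy() > 0.5, fill_ratio_norm=True), since the transform pipeline ends with ToTensor and returns a 1x64x64 tensor.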
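
main() in visualize_results/main.py currently calls create_vis() without its required n_imgs argument and mixes for-iteration with .send(), while the extra bare yield inside create_vis means the caller would receive None back after every send. A minimal sketch of the two-sided yield/send hand-off the code appears to aim for, where the producer yields an image and the consumer sends the processed one back; the names produce, consume, and process are illustrative, not part of the patch.

    from typing import Callable, Generator

    from PIL import Image


    def produce(paths: list[str]) -> Generator[Image.Image, Image.Image, None]:
        """Yield each image; receive its processed counterpart back via send()."""
        for path in paths:
            original = Image.open(path).convert("RGB")
            processed = yield original  # resumes when the consumer calls send()
            processed.save(path.replace(".png", "_processed.png"))


    def consume(
        gen: Generator[Image.Image, Image.Image, None],
        process: Callable[[Image.Image], Image.Image],
    ) -> None:
        try:
            image = next(gen)                     # prime the generator
            while True:
                image = gen.send(process(image))  # hand the result back, get the next image
        except StopIteration:
            pass

With this protocol, the clustering-and-complexity pass from main() would live in process, and create_vis would only produce and persist the composite images.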
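
Inside create_vis, glob.glob(...).sort(key=natsort) binds templist to None because list.sort() sorts in place, and attention_type is not "v" tests identity rather than equality. A sketch of that inner collection loop with those two points addressed, keeping the existing natsort helper and file-naming scheme; the function name collect_attention_row is illustrative, and the yield/send hand-off from the original is left out for brevity.

    import glob
    import os
    from typing import Callable

    from PIL import Image as img
    from PIL.Image import Image


    def collect_attention_row(
        inp: str,
        name: str,
        depth: int,
        rltrenner: Image,
        natsort: Callable,
    ) -> list[Image]:
        """Collect the q/k/v attention-head images for one layer into a single row."""
        imagesinline: list[Image] = []
        for attention_type in ["q", "k", "v"]:
            if attention_type != "v":  # equality, not identity
                imagesinline.append(rltrenner)

            # sorted() returns a new list; list.sort() would return None here
            templist = sorted(
                glob.glob(os.path.join(inp, f"{name}{attention_type}depth{depth}head*.png")),
                key=natsort,
            )

            for timg_path in templist:
                timg = img.open(timg_path).convert("RGB").resize((480, 480), resample=img.NEAREST)
                imagesinline.append(timg)

        return imagesinline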
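
visualize_results/utils.py shares a single bbox between BBoxTransform in data.py (which passes a torch.Tensor) and extract_single_masks (which passes a numpy array), so the unconditional img.numpy() would raise AttributeError on the numpy path. A small sketch of the same function with a type guard; the padding logic is copied from the patch.

    import numpy as np
    import torch


    def bbox(img):
        """Bounding box of the non-zero region, padded by one pixel where possible.

        Accepts either a torch.Tensor (BBoxTransform path) or an array-like
        (extract_single_masks path).
        """
        if isinstance(img, torch.Tensor):
            img = img.numpy()
        img = np.asarray(img)

        max_x, max_y = img.shape
        rows = np.any(img, axis=1)
        cols = np.any(img, axis=0)
        rmin, rmax = np.where(rows)[0][[0, -1]]
        cmin, cmax = np.where(cols)[0][[0, -1]]

        rmin = rmin - 1 if rmin > 0 else rmin
        cmin = cmin - 1 if cmin > 0 else cmin
        rmax = rmax + 1 if rmax < max_x else rmax
        cmax = cmax + 1 if cmax < max_y else cmax

        return rmin, rmax, cmin, cmax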
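
Both shape_complexity/transform_dino_data.py and visualize_results/utils.py label connected components with a recursive DFS plus sys.setrecursionlimit(1000000), which can still exhaust the interpreter stack on large regions, and the np.int8 label array cannot represent more than 127 components. A sketch of an equivalent iterative version using an explicit stack, offered as an alternative to the recursive dfs under the same mask/labels conventions.

    import numpy.typing as npt


    def dfs_iterative(
        mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int
    ) -> None:
        """Label the 4-connected component containing (x, y) without recursion."""
        n_rows, n_cols = mask.shape
        stack = [(x, y)]
        while stack:
            cx, cy = stack.pop()
            if cx < 0 or cx == n_rows or cy < 0 or cy == n_cols:
                continue
            if labels[cx][cy] or not mask[cx][cy]:
                continue  # already labeled or background
            labels[cx][cy] = current_label
            for dx_, dy_ in ((1, 0), (0, 1), (-1, 0), (0, -1)):
                stack.append((cx + dx_, cy + dy_))

find_components could call this in place of dfs without other changes; allocating labels with a wider dtype such as np.int32 would lift the 127-component limit.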