diff --git a/shape_complexity/shape_complexity.py b/shape_complexity/shape_complexity.py
index 93a3b01979a08ee7d82ed422f9da034a7ccc61de..e9147938c24d4f05abef207a0abeb29369e6483b 100644
--- a/shape_complexity/shape_complexity.py
+++ b/shape_complexity/shape_complexity.py
@@ -62,7 +62,8 @@ def find_components(mask: npt.NDArray):
 
 
 # https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
-def bbox(img):
+def bbox(img: Tensor):
+    img = np.asarray(img)  # works for CPU tensors and plain numpy arrays alike
     max_x, max_y = img.shape
     rows = np.any(img, axis=1)
     cols = np.any(img, axis=0)
@@ -249,6 +250,64 @@ class CONVVAE(nn.Module):
         return BCE + KLD
 
 
+class BBoxTransform:
+    squared: bool = False
+
+    def __init__(self, squared: bool = False) -> None:
+        self.squared = squared
+
+    def __call__(self, img):
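+        """Crop `img` to its bounding box. If `squared`, paste the crop
+        centered into an n x n zero canvas, where n is the longer bbox
+        side rounded up to an even number; returns a PIL image."""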
+        img = transforms.F.to_tensor(img).squeeze(dim=0)
+        ymin, ymax, xmin, xmax = bbox(img)
+        if not self.squared:
+            return transforms.F.to_pil_image(img[ymin:ymax, xmin:xmax].unsqueeze(dim=0))
+
+        max_dim = (ymin, ymax) if ymax - ymin > xmax - xmin else (xmin, xmax)
+        n = max_dim[1] - max_dim[0]
+        if n % 2 != 0:
+            n += 1
+
+        n_med = np.round(n / 2)
+
+        ymedian = np.round(ymin + (ymax - ymin) / 2)
+        xmedian = np.round(xmin + (xmax - xmin) / 2)
+
+        M, N = img.shape
+
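+        # clamp the n x n window so it does not leave the image bounds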
+        ycutmin, ycutmax = int(ymedian - n_med if ymedian >= n_med else 0), int(
+            ymedian + n_med if ymedian + n_med <= M else M
+        )
+
+        xcutmin, xcutmax = int(xmedian - n_med if xmedian >= n_med else 0), int(
+            xmedian + n_med if xmedian + n_med <= N else N,
+        )
+
+        if (ycutmax - ycutmin) % 2 != 0:
+            ycutmin += 1
+        if (xcutmax - xcutmin) % 2 != 0:
+            xcutmin += 1
+
+        squared_x = np.zeros((n, n))
+        squared_cut_y = np.round((ycutmax - ycutmin) / 2)
+        squared_cut_x = np.round((xcutmax - xcutmin) / 2)
+
+        dest_ymin, dest_ymax = int(n_med - squared_cut_y), int(n_med + squared_cut_y)
+        dest_xmin, dest_xmax = int(n_med - squared_cut_x), int(n_med + squared_cut_x)
+
+
+        squared_x[
+            dest_ymin:dest_ymax,
+            dest_xmin:dest_xmax,
+        ] = img[ycutmin:ycutmax, xcutmin:xcutmax]
+
+        return transforms.F.to_pil_image(torch.from_numpy(squared_x).unsqueeze(dim=0))
+
+
 class CloseTransform:
     kernel = torch.ones(5, 5)
 
@@ -311,8 +370,9 @@ def load_mpeg7_data():
         [
             transforms.Grayscale(),
             # transforms.RandomApply([CloseTransform()], p=0.25),
+            BBoxTransform(squared=True),
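+            # NEAREST keeps the 0/1 mask values intact (no interpolated grays)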
             transforms.Resize(
-                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
+                (64, 64), interpolation=transforms.InterpolationMode.NEAREST
             ),
             transforms.RandomHorizontalFlip(),
             transforms.RandomVerticalFlip(),
@@ -326,6 +386,27 @@ def load_mpeg7_data():
     return data_loader, dataset
 
 
+def load_dino_data():
+    transform = transforms.Compose(
+        [
+            transforms.Grayscale(),
+            # transforms.RandomApply([CloseTransform()], p=0.25),
+            BBoxTransform(squared=True),
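+            # NEAREST keeps the 0/1 mask values intact (no interpolated grays)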
+            transforms.Resize(
+                (64, 64), interpolation=transforms.InterpolationMode.NEAREST
+            ),
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomVerticalFlip(),
+            transforms.ToTensor(),
+        ]
+    )
+
+    dataset = ImageFolder("shape_complexity/data/dino", transform=transform)
+
+    data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
+    return data_loader, dataset
+
+
 def load_data():
     transform = transforms.Compose(
         [
@@ -453,39 +534,6 @@ def test(epoch, models: list[CONVVAE] or list[VAE], dataset, save_results=False)
         plt.clf()
 
 
-def test_mask(model: nn.Module, path: str, label: int, epsilon=0.4):
-    model.eval()
-    image = transforms.F.to_tensor(transforms.F.to_grayscale(Image.open(path)))
-    labels = find_components(image[0])
-    single_masks = extract_single_masks(labels)
-    mask = transforms.F.to_tensor(
-        transforms.F.resize(
-            transforms.F.to_pil_image((single_masks[label] * 255).astype(np.uint8)),
-            (64, 64),
-        )
-    )
-
-    with torch.no_grad():
-        mask = mask.to(device)
-        recon_x, _, _ = model(mask)
-        recon_bits = recon_x.view(64, 64).cpu().numpy() > epsilon
-        mask_bits = mask.cpu().numpy() > 0
-
-        TP = (mask_bits & recon_bits).sum()
-        FP = (recon_bits & ~mask_bits).sum()
-        FN = (mask_bits & ~recon_bits).sum()
-
-        prec = TP / (TP + FP)
-        rec = TP / (TP + FN)
-        # loss = pixelwise_loss(recon_x, mask)
-        comp_data = torch.cat(
-            [mask[0].cpu(), recon_x.view(64, 64).cpu(), torch.from_numpy(recon_bits)]
-        )
-        # print(f"mask loss: {loss:.4f}")
-
-        return prec, rec, comp_data
-
-
 def l2_distance_measure(img: Tensor, model: VAE):
     model.eval()
 
@@ -792,7 +840,6 @@ def visualize_sort_3dim(
     for i, (mask, _) in enumerate(data_loader, 0):
         c_compress, _ = compression_measure(mask, fill_ratio_norm=True)
         c_fft, _ = fft_measure(mask)
-        # TODO: maybe exchange by diff or mean measure instead of precision
         c_vae, mask_recon_grid = complexity_measure(
             mask, model_gb, model_lb, fill_ratio_norm=True
         )
@@ -825,11 +872,8 @@ def visualize_sort_3dim(
 
 
 LR = 1.5e-3
-EPOCHS = 100
-LOAD_PRETRAINED = True
-
-# TODO: might be a good idea to implement a bbox cut preprocessing transform thingy
-# TODO: try 2dim rep (fft, comp) (fft, vae) (comp, vae)
+EPOCHS = 80
+LOAD_PRETRAINED = False
 
 
 def main():
@@ -839,6 +883,7 @@ def main():
 
     # data_loader, dataset = load_data()
-    data_loader, dataset = load_mpeg7_data()
+    # data_loader, dataset = load_mpeg7_data()
+    data_loader, dataset = load_dino_data()
 
     train_size = int(0.8 * len(dataset))
     test_size = len(dataset) - train_size
@@ -894,6 +939,19 @@ def main():
         ],
         max_norm=True,
     )
+    visualize_sort_multidim(
+        data_loader,
+        [
+            (
+                "px_recon_comp",
+                pixelwise_complexity_measure,
+                [models[bn_gt], models[bn_lt], True],
+            ),
+            ("fft", fft_measure, []),
+            ("compression", compression_measure, [True]),
+        ],
+        max_norm=True,
+    )
     visualize_sort_multidim(
         data_loader,
         [
@@ -927,35 +985,35 @@ def main():
         max_norm=True,
     )
 
-    # visualize_sort(
-    #     data_loader,
-    #     complexity_measure,
-    #     "recon_complexity",
-    #     models[bn_gt],
-    #     models[bn_lt],
-    #     fill_ratio_norm=True,
-    # )
-    # visualize_sort(
-    #     data_loader,
-    #     pixelwise_complexity_measure,
-    #     "px_recon_complexity",
-    #     models[bn_gt],
-    #     models[bn_lt],
-    #     fill_ratio_norm=True,
-    # )
-
-    # visualize_sort(data_loader, mean_precision, "mean_precision", list(models.values()))
-    # visualize_sort(data_loader, fft_measure, "fft")
-    # visualize_sort(data_loader, compression_measure, "compression")
-    # visualize_sort(
-    #     data_loader,
-    #     compression_measure,
-    #     "compression_fill_norm",
-    #     fill_ratio_norm=True,
-    # )
-    # visualize_sort(
-    #     data_loader, l2_distance_measure, "latent_l2_distance", models[bn_gt]
-    # )
+    visualize_sort(
+        data_loader,
+        complexity_measure,
+        "recon_complexity",
+        models[bn_gt],
+        models[bn_lt],
+        fill_ratio_norm=True,
+    )
+    visualize_sort(
+        data_loader,
+        pixelwise_complexity_measure,
+        "px_recon_complexity",
+        models[bn_gt],
+        models[bn_lt],
+        fill_ratio_norm=True,
+    )
+
+    visualize_sort(data_loader, mean_precision, "mean_precision", list(models.values()))
+    visualize_sort(data_loader, fft_measure, "fft")
+    visualize_sort(data_loader, compression_measure, "compression")
+    visualize_sort(
+        data_loader,
+        compression_measure,
+        "compression_fill_norm",
+        fill_ratio_norm=True,
+    )
+    visualize_sort(
+        data_loader, l2_distance_measure, "latent_l2_distance", models[bn_gt]
+    )
 
 
 if __name__ == "__main__":
diff --git a/shape_complexity/shape_complexity_VAE.ipynb b/shape_complexity/shape_complexity_VAE.ipynb
index 266a55ef0c30902aaa2f83386bd71756ba0b4744..8b674c281a4c65d9134f0f39cc9cef52f1efcbef 100644
--- a/shape_complexity/shape_complexity_VAE.ipynb
+++ b/shape_complexity/shape_complexity_VAE.ipynb
@@ -648,10 +648,10 @@
  ],
  "metadata": {
   "interpreter": {
-   "hash": "84a83425d3adf844de89132bf0187dcd8fd164734ace6339df0fa8b55b713a8b"
+   "hash": "44317fb6075a6ba1e2c604c943b79dc9125a0abefce44942588ea31f355d03c2"
   },
   "kernelspec": {
-   "display_name": "Python 3.9.9 ('minerl-indexing')",
+   "display_name": "Python 3.9.12 ('deeplearning')",
    "language": "python",
    "name": "python3"
   },
diff --git a/shape_complexity/transform_dino_data.py b/shape_complexity/transform_dino_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..02ddf8bd6af28ed93c506564ab5470e54bfd6991
--- /dev/null
+++ b/shape_complexity/transform_dino_data.py
@@ -0,0 +1,114 @@
+from imageio import v3 as iio
+import os
+import numpy as np
+import numpy.typing as npt
+import sys
+
+
+# reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation
+# perform depth first search for each candidate/unlabeled region
+dx = [+1, 0, -1, 0]
+dy = [0, +1, 0, -1]
+
+
+def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int):
+    n_rows, n_cols = mask.shape
+    if x < 0 or x == n_rows:
+        return
+    if y < 0 or y == n_cols:
+        return
+    if labels[x][y] or not mask[x][y]:
+        return  # already labeled or not marked with 1 in image
+
+    # mark the current cell
+    labels[x][y] = current_label
+
+    # recursively mark the neighbors
+    for direction in range(4):
+        dfs(mask, x + dx[direction], y + dy[direction], labels, current_label)
+
+
+def find_components(mask: npt.NDArray):
+    label = 0
+
+    n_rows, n_cols = mask.shape
+    labels = np.zeros(mask.shape, dtype=np.int32)  # int32: images may contain more than 127 components
+
+    for i in range(n_rows):
+        for j in range(n_cols):
+            if not labels[i][j] and mask[i][j]:
+                label += 1
+                dfs(mask, i, j, labels, label)
+
+    return labels
+
+
+# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
+def bbox(img):
+    rows = np.any(img, axis=1)
+    cols = np.any(img, axis=0)
+    rmin, rmax = np.where(rows)[0][[0, -1]]
+    cmin, cmax = np.where(cols)[0][[0, -1]]
+
+    return rmin, rmax, cmin, cmax
+
+
+def extract_single_masks(labels: npt.NDArray):
+    masks = []
+    for l in range(1, labels.max() + 1):
+        # label 0 is the background; starting the range at 1 skips it
+        mask = (labels == l).astype(np.int8)
+        max_x, max_y = mask.shape
+        pixel_sum = mask.sum()
+        if pixel_sum > 128 and pixel_sum < 512:
+            # print(pixel_sum)
+            rmin, rmax, cmin, cmax = bbox(mask)
+            rmin = rmin - 1 if rmin > 0 else rmin
+            cmin = cmin - 1 if cmin > 0 else cmin
+            rmax = rmax + 1 if rmax < max_x - 1 else rmax
+            cmax = cmax + 1 if cmax < max_y - 1 else cmax
+
+            aspect_ratio = (rmax - rmin) / (cmax - cmin)
+            if aspect_ratio > 0.5 and aspect_ratio < 2.0:
+                masks.append(mask[rmin : rmax + 1, cmin : cmax + 1])
+            else:
+                pass
+                # print("rejected due to aspect ratio")
+
+    return masks
+
+
+total = 0
+clusters = 5
+
+sys.setrecursionlimit(1000000)
+
+
+path = f"/home/markus/dev/dino/OutputDir"
+output_path = f"/home/markus/uni/navigation_project/shape_complexity/data/dino"
+
+if not os.path.exists(output_path):
+    os.makedirs(output_path)
+
+for file in os.listdir(path):
+    fullpath = os.path.join(path, file)
+    if os.path.isdir(fullpath):
+        continue
+
+    image = iio.imread(fullpath)
+    # image = image[0]
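+    # quantize gray values 0..255 down to cluster ids 0..clusters-1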
+    image = (image / (255 / (clusters - 1))).astype(np.int8)
+
+    for i in range(clusters):
+        cluster_image = (image == i).astype(np.int8)
+
+        labels = find_components(cluster_image)
+        single_masks = extract_single_masks(labels)
+        for l, mask in enumerate(single_masks):
+            total += 1
+            iio.imwrite(
+                os.path.join(output_path, f"{file.replace('.jpg', '')}_{i}_{l}.png"),
+                (mask * 255).astype(np.uint8),
+            )
+
+print(f"total masks: {total}")
diff --git a/visualize_results/complexity.py b/visualize_results/complexity.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/visualize_results/data.py b/visualize_results/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc8be4c532f52c2e1ca588884331d262bc5f0bf
--- /dev/null
+++ b/visualize_results/data.py
@@ -0,0 +1,99 @@
+import numpy as np
+import torch
+from kornia.morphology import closing
+from torch import Tensor
+from torchvision import transforms
+
+from visualize_results.utils import bbox
+
+
+class BBoxTransform:
+    squared: bool = False
+
+    def __init__(self, squared: bool = False) -> None:
+        self.squared = squared
+
+    def __call__(self, img):
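+        """Crop `img` to its bounding box. If `squared`, paste the crop
+        centered into an n x n zero canvas, where n is the longer bbox
+        side rounded up to an even number; returns a PIL image."""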
+        img = transforms.F.to_tensor(img).squeeze(dim=0)
+        ymin, ymax, xmin, xmax = bbox(img)
+        if not self.squared:
+            return transforms.F.to_pil_image(img[ymin:ymax, xmin:xmax].unsqueeze(dim=0))
+
+        max_dim = (ymin, ymax) if ymax - ymin > xmax - xmin else (xmin, xmax)
+        n = max_dim[1] - max_dim[0]
+        if n % 2 != 0:
+            n += 1
+
+        n_med = np.round(n / 2)
+
+        ymedian = np.round(ymin + (ymax - ymin) / 2)
+        xmedian = np.round(xmin + (xmax - xmin) / 2)
+
+        M, N = img.shape
+
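+        # clamp the n x n window so it does not leave the image bounds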
+        ycutmin, ycutmax = int(ymedian - n_med if ymedian >= n_med else 0), int(
+            ymedian + n_med if ymedian + n_med <= M else M
+        )
+
+        xcutmin, xcutmax = int(xmedian - n_med if xmedian >= n_med else 0), int(
+            xmedian + n_med if xmedian + n_med <= N else N,
+        )
+
+        if (ycutmax - ycutmin) % 2 != 0:
+            ycutmin += 1
+        if (xcutmax - xcutmin) % 2 != 0:
+            xcutmin += 1
+
+        squared_x = np.zeros((n, n))
+        squared_cut_y = np.round((ycutmax - ycutmin) / 2)
+        squared_cut_x = np.round((xcutmax - xcutmin) / 2)
+
+        dest_ymin, dest_ymax = int(n_med - squared_cut_y), int(n_med + squared_cut_y)
+        dest_xmin, dest_xmax = int(n_med - squared_cut_x), int(n_med + squared_cut_x)
+
+
+        squared_x[
+            dest_ymin:dest_ymax,
+            dest_xmin:dest_xmax,
+        ] = img[ycutmin:ycutmax, xcutmin:xcutmax]
+
+        return transforms.F.to_pil_image(torch.from_numpy(squared_x).unsqueeze(dim=0))
+
+
+class CloseTransform:
+    kernel = torch.ones(5, 5)
+
+    def __init__(self, kernel=None):
+        if kernel is not None:
+            self.kernel = kernel
+
+    def __call__(self, x):
+        x = transforms.F.to_tensor(x)
+
+        if len(x.shape) < 4:
+            return transforms.F.to_pil_image(
+                closing(x.unsqueeze(dim=0), self.kernel).squeeze(dim=0)
+            )
+
+        return transforms.F.to_pil_image(closing(x, self.kernel))
+
+
+def get_dino_transforms():
+    return transforms.Compose(
+        [
+            transforms.Grayscale(),
+            # transforms.RandomApply([CloseTransform()], p=0.25),
+            BBoxTransform(squared=True),
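+            # NEAREST keeps the 0/1 mask values intact (no interpolated grays)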
+            transforms.Resize(
+                (64, 64), interpolation=transforms.InterpolationMode.NEAREST
+            ),
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomVerticalFlip(),
+            transforms.ToTensor(),
+        ]
+    )
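+
+
+# Minimal usage sketch (assumes masks stored in an ImageFolder layout,
+# e.g. shape_complexity/data/dino as in shape_complexity.py):
+#
+#   from torchvision.datasets import ImageFolder
+#   dataset = ImageFolder("shape_complexity/data/dino", transform=get_dino_transforms())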
diff --git a/visualize_results/main.py b/visualize_results/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..20a5aea7d115eed518d7623c7bd97ec11d158801
--- /dev/null
+++ b/visualize_results/main.py
@@ -0,0 +1,119 @@
+import glob
+import os
+import sys
+from typing import Generator
+
+import numpy as np
+from PIL import Image as img
+from PIL.Image import Image
+from torchvision.transforms import transforms
+
+from visualize_results.data import get_dino_transforms
+from visualize_results.utils import find_components, natsort
+
+sys.setrecursionlimit(1000000)
+n_clusters = 5
+
+
+def create_vis(n_imgs, layer_range=range(12)) -> Generator[Image, Image, None]:
+    """
+    yields three channel PIL image
+
+    credit @wlad
+    """
+    inp = "./OutputDir"
+    os.makedirs("kmeanimgs", exist_ok=True)
+    rltrenner = img.open("./data/rltrenner.png")
+    rltrenner = rltrenner.convert("RGB")
+
+    yseperator = 20
+    print("Creating images")
+    for numimg in range(n_imgs):  # TODO: add tqdm again..
+        imagesvert = []
+        name = "img" + str(numimg)
+
+        aimg = img.open("OutputDir/" + name + ".png")
+        aimg = aimg.convert("RGB")
+
+        for depth in layer_range:
+            imagesinline = [aimg]
+            for attention_type in ["q", "k", "v"]:
+                if attention_type is not "v":
+                    imagesinline.append(rltrenner)
+
+                # list.sort() returns None, so sort the glob result with sorted()
+                templist = sorted(
+                    glob.glob(
+                        os.path.join(inp, f"{name}{attention_type}depth{depth}head*.png")
+                    ),
+                    key=natsort,
+                )
+
+                for timg in templist:
+                    timg = img.open(timg)
+
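+                    # Handshake: yield the raw image to the caller, which
+                    # sends back an annotated image via generator.send().
+                    # The bare `yield` below absorbs that send() so the
+                    # caller's next iteration resumes the loop here.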
+                    timg = yield timg
+                    yield
+
+                    timg = timg.convert("RGB")
+
+                    timg = timg.resize((480, 480), resample=img.NEAREST)
+                    imagesinline.append(timg)
+
+            # yield imagesinline
+            # imagesinline.insert(0, aimg)
+
+            widths, heights = zip(*(i.size for i in imagesinline))
+
+            total_width = sum(widths)
+            max_height = max(heights)
+
+            vertical_img = img.new("RGB", (total_width, max_height))
+
+            x_offset = 0
+            for im in imagesinline:
+                vertical_img.paste(im, (x_offset, 0))
+                x_offset += im.size[0]
+            imagesvert.append(vertical_img.convert("RGB"))
+
+        widths, heights = zip(*(i.size for i in imagesvert))
+
+        total_height = sum(heights) + (len(layer_range) * yseperator)
+        max_width = max(widths)
+
+        final_img = img.new("RGB", (max_width, total_height))
+
+        y_offset = 0
+        for im in imagesvert:
+            final_img.paste(im, (0, y_offset))
+            y_offset += im.size[1] + yseperator
+
+        final_img.save(os.path.join("kmeanimgs/", "img" + str(numimg) + ".png"))
+
+
+def main():
+    img_transformer = get_dino_transforms()
+    # NOTE: hypothetical placeholder count; create_vis() requires n_imgs,
+    # set it to the real number of images in OutputDir
+    image_generator = create_vis(n_imgs=1)
+
+    # renamed loop variable so it does not shadow the `PIL.Image as img` alias
+    for mask_img in image_generator:
+        np_img = np.array(mask_img)
+        cluster_img = (np_img / (255 / (n_clusters - 1))).astype(np.int8)
+
+        for i in range(n_clusters):
+            cluster_image = (cluster_img == i).astype(np.int8)
+
+            labels = find_components(cluster_image)
+            for l in range(1, labels.max() + 1):
+                mask = (labels == l).astype(np.int8)
+                normalized_img = img_transformer(
+                    transforms.F.to_pil_image((mask * 255).astype(np.uint8))
+                )
+                # TODO: calculate complexity for normalized image..
+                complexity = 1.0
+                np_img[labels == l] = complexity
+
+        # TODO: return manipulated PIL img / modify `mask_img`
+        image_generator.send(transforms.F.to_pil_image(np_img))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/visualize_results/models.py b/visualize_results/models.py
new file mode 100644
index 0000000000000000000000000000000000000000..54915a07f5b57823ab26dff1da455bb659f58425
--- /dev/null
+++ b/visualize_results/models.py
@@ -0,0 +1,113 @@
+from torch import Tensor
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+
+class CONVVAE(nn.Module):
+    def __init__(
+        self,
+        bottleneck=2,
+    ):
+        super(CONVVAE, self).__init__()
+
+        self.bottleneck = bottleneck
+        self.feature_dim = 6 * 6 * 64
+
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(1, 16, 5),
+            nn.ReLU(),
+            nn.MaxPool2d((2, 2)),  # -> 30x30x16
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(16, 32, 3),
+            nn.ReLU(),
+            nn.MaxPool2d((2, 2)),  # -> 14x14x32
+        )
+        self.conv3 = nn.Sequential(
+            nn.Conv2d(32, 64, 3),
+            nn.ReLU(),
+            nn.MaxPool2d((2, 2)),  # -> 6x6x64
+        )
+        # self.conv4 = nn.Sequential(
+        #     nn.Conv2d(32, self.bottleneck, 5),
+        #     nn.ReLU(),
+        #     nn.MaxPool2d((2, 2), return_indices=True),  # -> 1x1xbottleneck
+        # )
+
+        self.encode_mu = nn.Sequential(
+            nn.Flatten(),
+            nn.Linear(self.feature_dim, self.bottleneck),
+        )
+        self.encode_logvar = nn.Sequential(
+            nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
+        )
+
+        self.decode_linear = nn.Linear(self.bottleneck, self.feature_dim)
+
+        # self.decode4 = nn.Sequential(
+        #     nn.ConvTranspose2d(self.bottleneck, 32, 5),
+        #     nn.ReLU(),
+        # )
+        self.decode3 = nn.Sequential(
+            nn.ConvTranspose2d(64, 64, 2, stride=2),  # -> 12x12x64
+            nn.ReLU(),
+            nn.ConvTranspose2d(64, 32, 3),  # -> 14x14x32
+            nn.ReLU(),
+        )
+        self.decode2 = nn.Sequential(
+            nn.ConvTranspose2d(32, 32, 2, stride=2),  # -> 28x28x32
+            nn.ReLU(),
+            nn.ConvTranspose2d(32, 16, 3),  # -> 30x30x16
+            nn.ReLU(),
+        )
+        self.decode1 = nn.Sequential(
+            nn.ConvTranspose2d(16, 16, 2, stride=2),  # -> 60x60x16
+            nn.ReLU(),
+            nn.ConvTranspose2d(16, 1, 5),  # -> 64x64x1
+            nn.Sigmoid(),
+        )
+
+    def encode(self, x):
+        x = self.conv1(x)
+        x = self.conv2(x)
+        x = self.conv3(x)
+        # x, idx4 = self.conv4(x)
+        mu = self.encode_mu(x)
+        logvar = self.encode_logvar(x)
+
+        return mu, logvar
+
+    def decode(self, z: Tensor):
+        z = self.decode_linear(z)
+        z = z.view((-1, 64, 6, 6))
+        # z = F.max_unpool2d(z, idx4, (2, 2))
+        # z = self.decode4(z)
+        z = self.decode3(z)
+        z = self.decode2(z)
+        z = self.decode1(z)
+        # z = z.view(-1, 128, 1, 1)
+        # return self.decode_conv(z)
+        return z
+
+    def reparameterize(self, mu, logvar):
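+        """Reparameterization trick: z = mu + sigma * eps with
+        eps ~ N(0, I), keeping sampling differentiable w.r.t. mu/logvar."""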
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, x):
+        mu, logvar = self.encode(x)
+        z = self.reparameterize(mu, logvar)
+        return self.decode(z), mu, logvar
+
+    def loss(self, recon_x, x, mu, logvar):
+        """https://github.com/pytorch/examples/blob/main/vae/main.py"""
+        BCE = F.binary_cross_entropy(recon_x, x, reduction="sum")
+
+        # see Appendix B from VAE paper:
+        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
+        # https://arxiv.org/abs/1312.6114
+        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
+        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
+
+        return BCE + KLD
diff --git a/visualize_results/utils.py b/visualize_results/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5bb4539d29f2e9774e1375fec1c952708f89eca
--- /dev/null
+++ b/visualize_results/utils.py
@@ -0,0 +1,92 @@
+import re
+
+import numpy as np
+import numpy.typing as npt
+
+
+def natsort(s, _nsre=re.compile("([0-9]+)")):
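+    """Natural-sort key: splits digit runs so e.g. "img10" sorts after "img2"."""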
+    return [
+        int(text) if text.isdigit() else text.lower() for text in _nsre.split(str(s))
+    ]
+
+
+# reference: https://stackoverflow.com/questions/14465297/connected-component-labeling-implementation
+# perform depth first search for each candidate/unlabeled region
+dx = [+1, 0, -1, 0]
+dy = [0, +1, 0, -1]
+
+
+def dfs(mask: npt.NDArray, x: int, y: int, labels: npt.NDArray, current_label: int):
+    n_rows, n_cols = mask.shape
+    if x < 0 or x == n_rows:
+        return
+    if y < 0 or y == n_cols:
+        return
+    if labels[x][y] or not mask[x][y]:
+        return  # already labeled or not marked with 1 in image
+
+    # mark the current cell
+    labels[x][y] = current_label
+
+    # recursively mark the neighbors
+    for direction in range(4):
+        dfs(mask, x + dx[direction], y + dy[direction], labels, current_label)
+
+
+def find_components(mask: npt.NDArray):
+    label = 0
+
+    n_rows, n_cols = mask.shape
+    labels = np.zeros(mask.shape, dtype=np.int32)  # int32: images may contain more than 127 components
+
+    for i in range(n_rows):
+        for j in range(n_cols):
+            if not labels[i][j] and mask[i][j]:
+                label += 1
+                dfs(mask, i, j, labels, label)
+
+    return labels
+
+
+# https://stackoverflow.com/questions/31400769/bounding-box-of-numpy-array
+def bbox(img):
+    img = np.asarray(img)  # accepts CPU tensors and the ndarrays passed by extract_single_masks
+    max_x, max_y = img.shape
+    rows = np.any(img, axis=1)
+    cols = np.any(img, axis=0)
+    rmin, rmax = np.where(rows)[0][[0, -1]]
+    cmin, cmax = np.where(cols)[0][[0, -1]]
+
+    rmin = rmin - 1 if rmin > 0 else rmin
+    cmin = cmin - 1 if cmin > 0 else cmin
+    rmax = rmax + 1 if rmax < max_x - 1 else rmax
+    cmax = cmax + 1 if cmax < max_y - 1 else cmax
+
+    return rmin, rmax, cmin, cmax
+
+
+def extract_single_masks(labels: npt.NDArray):
+    masks = []
+    for l in range(1, labels.max() + 1):
+        # label 0 is the background; starting the range at 1 skips it
+        mask = (labels == l).astype(np.int8)
+        max_x, max_y = mask.shape
+        pixel_sum = mask.sum()
+        if pixel_sum > 128 and pixel_sum < 512:
+            # print(pixel_sum)
+            # bbox() already pads the box by one pixel where possible,
+            # so no second expansion is needed here
+            rmin, rmax, cmin, cmax = bbox(mask)
+
+            aspect_ratio = (rmax - rmin) / (cmax - cmin)
+            if aspect_ratio > 0.5 and aspect_ratio < 2.0:
+                masks.append(mask[rmin : rmax + 1, cmin : cmax + 1])
+            else:
+                pass
+                # print("rejected due to aspect ratio")
+
+    return masks