diff --git a/shape_complexity/shape_complexity.py b/shape_complexity/shape_complexity.py
index f57f2a1d9df92e05627e72d969b5af42ce293464..135595b553d7466133f8ecfe74820fa75f429cde 100644
--- a/shape_complexity/shape_complexity.py
+++ b/shape_complexity/shape_complexity.py
@@ -2,9 +2,11 @@ import os
 
 # from zlib import compress
 from bz2 import compress
+import string
 
 import matplotlib
 import matplotlib.pyplot as plt
+from matplotlib.transforms import Transform
 import numpy as np
 import numpy.typing as npt
 import torch
@@ -13,8 +15,9 @@ from kornia.morphology import closing
 from PIL import Image
 from torch import Tensor, nn
 from torch.optim import Adam
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, Dataset
 from torchvision.datasets import ImageFolder
+from torchvision.io import read_image
 from torchvision.transforms import transforms
 from torchvision.utils import make_grid, save_image
 
@@ -265,6 +268,65 @@ class CloseTransform:
         return transforms.F.to_pil_image(closing(x, self.kernel))
 
 
+class MPEG7ShapeDataset(Dataset):
+    img_dir: str
+    filenames: list[str] = []
+    labels: list[str] = []
+    label_dict: dict[str]
+    transform: Transform = None
+
+    def __init__(self, img_dir, transform=None):
+        self.img_dir = img_dir
+        self.transform = transform
+
+        paths = os.listdir(self.img_dir)
+        labels = []
+        for file in paths:
+            fp = os.path.join(self.img_dir, file)
+            if os.path.isfile(fp):
+                label = file.split("-")[0]
+                self.filenames.append(fp)
+                labels.append(label)
+
+        label_name_dict = dict.fromkeys(labels)
+
+        self.label_dict = {i: v for (i, v) in enumerate(label_name_dict.keys())}
+        self.label_index_dict = {v: i for (i, v) in self.label_dict.items()}
+        self.labels = [self.label_index_dict[l] for l in labels]
+
+    def __len__(self):
+        return len(self.filenames)
+
+    def __getitem__(self, idx):
+        img_path = self.filenames[idx]
+        gif = Image.open(img_path)
+        gif.convert("RGB")
+        label = self.labels[idx]
+        if self.transform:
+            image = self.transform(gif)
+        return image, label
+
+
+def load_mpeg7_data():
+    transform = transforms.Compose(
+        [
+            transforms.Grayscale(),
+            # transforms.RandomApply([CloseTransform()], p=0.25),
+            transforms.Resize(
+                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
+            ),
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomVerticalFlip(),
+            transforms.ToTensor(),
+        ]
+    )
+
+    dataset = MPEG7ShapeDataset("shape_complexity/data/mpeg7", transform)
+
+    data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
+    return data_loader, dataset
+
+
 def load_data():
     transform = transforms.Compose(
         [
@@ -474,7 +536,6 @@ def complexity_measure(
     model_lb: nn.Module,
     img: Tensor,
     epsilon=0.4,
-    save_preliminary=False,
 ):
     model_gb.eval()
     model_lb.eval()
@@ -488,25 +549,6 @@ def complexity_measure(
         recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
         mask_bits = mask[0].cpu() > 0
 
-        if save_preliminary:
-            save_image(
-                torch.stack(
-                    [mask_bits.float(), recon_bits_gb.float(), recon_bits_lb.float()]
-                ).cpu(),
-                f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
-            )
-            save_image(
-                torch.stack(
-                    [
-                        (mask_bits & recon_bits_gb).float(),
-                        (recon_bits_gb & ~mask_bits).float(),
-                        (mask_bits & recon_bits_lb).float(),
-                        (recon_bits_lb & ~mask_bits).float(),
-                    ]
-                ).cpu(),
-                f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
-            )
-
         tp_gb = (mask_bits & recon_bits_gb).sum()
         fp_gb = (recon_bits_gb & ~mask_bits).sum()
         tp_lb = (mask_bits & recon_bits_lb).sum()
@@ -515,18 +557,9 @@ def complexity_measure(
         prec_gb = tp_gb / (tp_gb + fp_gb)
         prec_lb = tp_lb / (tp_lb + fp_lb)
         complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
-        complexity_lb = 1 - prec_lb
-        complexity_gb = 1 - prec_gb
-        #           1 - (0.4 - abs(0.4 - 0.7))    = 0.9
-        #           1 - 0.7                       = 0.3
 
         return (
             complexity,
-            complexity_lb,
-            complexity_gb,
-            prec_gb - prec_lb,
-            prec_lb,
-            prec_gb,
             make_grid(
                 torch.stack(
                     [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
@@ -730,9 +763,7 @@ def visualize_sort_3dim(
         c_compress = compression_complexity(mask)
         c_fft = fft_measure(mask)
         # TODO: maybe exchange by diff or mean measure instead of precision
-        c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
-            model_gb, model_lb, mask
-        )
+        c_vae, mask_recon_grid = complexity_measure(model_gb, model_lb, mask)
         masks_recon[i] = mask_recon_grid
         masks[i] = mask[0]
         measures[i] = torch.tensor([c_compress, c_fft, c_vae])
@@ -767,181 +798,55 @@ def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
 
     masks = torch.zeros((400, 3, 64, 192))
     complexities = torch.zeros((400,))
-    diffs = []
     for i, (mask, _) in enumerate(data_loader, 0):
-        complexity, _, _, diff, mask_recon_grid = complexity_measure(
+        complexity, mask_recon_grid = complexity_measure(
             model_gb, model_lb, mask, save_preliminary=True
         )
         masks[i] = mask_recon_grid
-        diffs.append(diff)
         complexities[i] = complexity
 
     sort_idx = np.argsort(np.array(complexities))
     masks_sorted = masks.numpy()[sort_idx]
 
-    plt.plot(np.arange(len(diffs)), np.sort(diffs))
-    plt.xlabel("images")
-    plt.ylabel("prec difference (L-H)")
-    plt.savefig("shape_complexity/results/diff_plot.png")
-    plt.clf()
-
     return plot_samples(masks_sorted, complexities[sort_idx])
 
 
 def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
     masks = torch.zeros((400, 3, 64, 192))
     complexities = torch.zeros((400,))
-    complexities_lb = torch.zeros((400,))
-    complexities_gb = torch.zeros((400,))
-    diffs = []
-    prec_lbs = []
-    prec_gbs = []
     for i, (mask, _) in enumerate(data_loader, 0):
         (
             complexity,
-            lb,
-            gb,
-            diff,
-            prec_lb,
-            prec_gb,
             mask_recon_grid,
         ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
         masks[i] = mask_recon_grid
-        diffs.append(diff)
-        prec_lbs.append(prec_lb)
-        prec_gbs.append(prec_gb)
         complexities[i] = complexity
-        complexities_lb[i] = lb
-        complexities_gb[i] = gb
 
     sort_idx = np.argsort(np.array(complexities))
-    sort_idx_lb = np.argsort(np.array(complexities_lb))
-    sort_idx_gb = np.argsort(np.array(complexities_gb))
     masks_sorted = masks.numpy()[sort_idx]
-    masks_sorted_lb = masks.numpy()[sort_idx_lb]
-    masks_sorted_gb = masks.numpy()[sort_idx_gb]
-
-    diff_sort_idx = np.argsort(diffs)
-    # plt.savefig("shape_complexity/results/diff_plot.png")
-    #    plt.clf
-
-    fig, ax1 = plt.subplots()
-    ax2 = ax1.twinx()
-    ax1.plot(
-        np.arange(len(prec_lbs)),
-        np.array(prec_lbs)[diff_sort_idx],
-        label=f"bottleneck {model_lb.bottleneck}",
-    )
-    ax1.plot(
-        np.arange(len(prec_gbs)),
-        np.array(prec_gbs)[diff_sort_idx],
-        label=f"bottleneck {model_gb.bottleneck}",
-    )
-    ax2.plot(
-        np.arange(len(diffs)),
-        np.sort(diffs),
-        color="red",
-        label="prec difference (H - L)",
-    )
-    ax1.legend(loc="lower left")
-    ax2.legend(loc="lower right")
-    ax1.set_ylabel("precision")
-    ax2.set_ylabel("prec difference (H-L)")
-    plt.savefig("shape_complexity/results/prec_plot.png")
-    plt.clf()
 
     fig = plot_samples(masks_sorted, complexities[sort_idx])
     fig.savefig("shape_complexity/results/abs.png")
     plt.close(fig)
-    fig = plot_samples(masks_sorted_lb, complexities_lb[sort_idx_lb])
-    fig.savefig("shape_complexity/results/lb.png")
-    plt.close(fig)
-    fig = plot_samples(masks_sorted_gb, complexities_gb[sort_idx_gb])
-    fig.savefig("shape_complexity/results/gb.png")
-    plt.close(fig)
 
 
 def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
     recon_masks = torch.zeros((400, 3, 64, 192))
     masks = torch.zeros((400, 1, 64, 64))
     complexities = torch.zeros((400,))
-    diffs = np.zeros((400,))
-    prec_gbs = np.zeros((400,))
-    prec_lbs = np.zeros((400,))
     for i, (mask, _) in enumerate(data_loader, 0):
         (
             complexity,
-            _,
-            _,
-            diff,
-            prec_lb,
-            prec_gb,
             mask_recon_grid,
         ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
         recon_masks[i] = mask_recon_grid
         masks[i] = mask[0]
-        diffs[i] = diff
-        prec_gbs[i] = prec_gb
-        prec_lbs[i] = prec_lb
         complexities[i] = complexity
 
     sort_idx = np.argsort(np.array(complexities))
     masks_sorted = masks.numpy()[sort_idx]
     recon_masks_sorted = recon_masks.numpy()[sort_idx]
 
-    # group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
-    # bin_edges = [-np.inf, 0.0, 0.05, np.inf]
-    # bins = np.digitize(diffs, bins=bin_edges, right=True)
-
-    # for i in range(bins.min(), bins.max() + 1):
-    #     bin_idx = bins == i
-    #     binned_prec_gb = prec_gbs[bin_idx]
-    #     prec_mean = binned_prec_gb.mean()
-    #     prec_idx = prec_gbs > prec_mean
-
-    #     binned_masks_high = recon_masks[bin_idx & prec_idx]
-    #     binned_masks_low = recon_masks[bin_idx & ~prec_idx]
-
-    #     save_image(
-    #         binned_masks_high,
-    #         f"shape_complexity/results/diff_{group_labels[i-1]}_high.png",
-    #         padding=10,
-    #     )
-    #     save_image(
-    #         binned_masks_low,
-    #         f"shape_complexity/results/diff_{group_labels[i-1]}_low.png",
-    #         padding=10,
-    #     )
-
-    # diff_sort_idx = np.argsort(diffs)
-
-    # fig, ax1 = plt.subplots()
-    # ax2 = ax1.twinx()
-    # ax1.plot(
-    #     np.arange(len(prec_lbs)),
-    #     np.array(prec_lbs)[diff_sort_idx],
-    #     label=f"bottleneck {model_lb.bottleneck}",
-    # )
-    # ax1.plot(
-    #     np.arange(len(prec_gbs)),
-    #     np.array(prec_gbs)[diff_sort_idx],
-    #     label=f"bottleneck {model_gb.bottleneck}",
-    # )
-    # ax2.plot(
-    #     np.arange(len(diffs)),
-    #     np.sort(diffs),
-    #     color="red",
-    #     label="prec difference (H - L)",
-    # )
-    # ax1.legend(loc="lower left")
-    # ax2.legend(loc="lower right")
-    # ax1.set_ylabel("precision")
-    # ax2.set_ylabel("prec difference (H-L)")
-    # ax1.set_xlabel("image")
-    # plt.savefig("shape_complexity/results/prec_plot.png")
-    # plt.tight_layout(pad=2)
-    # plt.clf()
-
     fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
     fig.savefig("shape_complexity/results/abs_recon.png")
     plt.close(fig)
@@ -951,9 +856,15 @@ def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
     plt.close(fig)
 
 
-LR = 1e-3
-EPOCHS = 10
-LOAD_PRETRAINED = True
+LR = 1.5e-3
+EPOCHS = 100
+LOAD_PRETRAINED = False
+
+
+# TODO: build pixel ratio normalization for large black areas
+#       -> ideally, this fixes the compression metric
+# TODO: try out pixelwise loss again (in 3d as well)
+# TODO: might be a good idea to implement a bbox cut preprocessing transform thingy
 
 
 def main():
@@ -961,7 +872,8 @@ def main():
     models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
     optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}
 
-    data_loader, dataset = load_data()
+    # data_loader, dataset = load_data()
+    data_loader, dataset = load_mpeg7_data()
 
     train_size = int(0.8 * len(dataset))
     test_size = len(dataset) - train_size