diff --git a/shape_complexity/shape_complexity.py b/shape_complexity/shape_complexity.py
index f57f2a1d9df92e05627e72d969b5af42ce293464..135595b553d7466133f8ecfe74820fa75f429cde 100644
--- a/shape_complexity/shape_complexity.py
+++ b/shape_complexity/shape_complexity.py
@@ -2,9 +2,11 @@ import os
 # from zlib import compress
 from bz2 import compress
+import string
 
 import matplotlib
 import matplotlib.pyplot as plt
+from matplotlib.transforms import Transform
 import numpy as np
 import numpy.typing as npt
 import torch
@@ -13,8 +15,9 @@ from kornia.morphology import closing
 from PIL import Image
 from torch import Tensor, nn
 from torch.optim import Adam
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, Dataset
 from torchvision.datasets import ImageFolder
+from torchvision.io import read_image
 from torchvision.transforms import transforms
 from torchvision.utils import make_grid, save_image
 
@@ -265,6 +268,65 @@
         return transforms.F.to_pil_image(closing(x, self.kernel))
 
 
+class MPEG7ShapeDataset(Dataset):
+    img_dir: str
+    filenames: list[str] = []
+    labels: list[str] = []
+    label_dict: dict[int, str]
+    transform: Transform = None
+
+    def __init__(self, img_dir, transform=None):
+        self.img_dir = img_dir
+        self.transform = transform
+        self.filenames = []
+        paths = os.listdir(self.img_dir)
+        labels = []
+        for file in paths:
+            fp = os.path.join(self.img_dir, file)
+            if os.path.isfile(fp):
+                label = file.split("-")[0]
+                self.filenames.append(fp)
+                labels.append(label)
+
+        label_name_dict = dict.fromkeys(labels)
+
+        self.label_dict = {i: v for (i, v) in enumerate(label_name_dict.keys())}
+        self.label_index_dict = {v: i for (i, v) in self.label_dict.items()}
+        self.labels = [self.label_index_dict[l] for l in labels]
+
+    def __len__(self):
+        return len(self.filenames)
+
+    def __getitem__(self, idx):
+        img_path = self.filenames[idx]
+        gif = Image.open(img_path)
+        gif = gif.convert("RGB")
+        label = self.labels[idx]
+        if self.transform:
+            gif = self.transform(gif)
+        return gif, label
+
+
+def load_mpeg7_data():
+    transform = transforms.Compose(
+        [
+            transforms.Grayscale(),
+            # transforms.RandomApply([CloseTransform()], p=0.25),
+            transforms.Resize(
+                (64, 64), interpolation=transforms.InterpolationMode.BILINEAR
+            ),
+            transforms.RandomHorizontalFlip(),
+            transforms.RandomVerticalFlip(),
+            transforms.ToTensor(),
+        ]
+    )
+
+    dataset = MPEG7ShapeDataset("shape_complexity/data/mpeg7", transform)
+
+    data_loader = DataLoader(dataset, batch_size=128, shuffle=True)
+    return data_loader, dataset
+
+
 def load_data():
     transform = transforms.Compose(
         [
@@ -474,7 +536,6 @@ def complexity_measure(
     model_gb: nn.Module,
     model_lb: nn.Module,
     img: Tensor,
     epsilon=0.4,
-    save_preliminary=False,
 ):
     model_gb.eval()
     model_lb.eval()
@@ -488,25 +549,6 @@
     recon_bits_lb = recon_lb.view(-1, 64, 64).cpu() > epsilon
     mask_bits = mask[0].cpu() > 0
 
-    if save_preliminary:
-        save_image(
-            torch.stack(
-                [mask_bits.float(), recon_bits_gb.float(), recon_bits_lb.float()]
-            ).cpu(),
-            f"shape_complexity/results/mask_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
-        )
-        save_image(
-            torch.stack(
-                [
-                    (mask_bits & recon_bits_gb).float(),
-                    (recon_bits_gb & ~mask_bits).float(),
-                    (mask_bits & recon_bits_lb).float(),
-                    (recon_bits_lb & ~mask_bits).float(),
-                ]
-            ).cpu(),
-            f"shape_complexity/results/tp_fp_recon{model_gb.bottleneck}_{model_lb.bottleneck}.png",
-        )
-
     tp_gb = (mask_bits & recon_bits_gb).sum()
     fp_gb = (recon_bits_gb & ~mask_bits).sum()
     tp_lb = (mask_bits & recon_bits_lb).sum()
@@ -515,18 +557,9 @@
     prec_gb = tp_gb / (tp_gb + fp_gb)
     prec_lb = tp_lb / (tp_lb + fp_lb)
     complexity = 1 - (prec_gb - np.abs(prec_gb - prec_lb))
-    complexity_lb = 1 - prec_lb
-    complexity_gb = 1 - prec_gb
-    # 1 - (0.4 - abs(0.4 - 0.7)) = 0.9
-    # 1 - 0.7 = 0.3
 
     return (
        complexity,
-        complexity_lb,
-        complexity_gb,
-        prec_gb - prec_lb,
-        prec_lb,
-        prec_gb,
         make_grid(
             torch.stack(
                 [mask[0], recon_lb.view(-1, 64, 64), recon_gb.view(-1, 64, 64)]
@@ -730,9 +763,7 @@ def visualize_sort_3dim(
         c_compress = compression_complexity(mask)
         c_fft = fft_measure(mask)
         # TODO: maybe exchange by diff or mean measure instead of precision
-        c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
-            model_gb, model_lb, mask
-        )
+        c_vae, mask_recon_grid = complexity_measure(model_gb, model_lb, mask)
         masks_recon[i] = mask_recon_grid
         masks[i] = mask[0]
         measures[i] = torch.tensor([c_compress, c_fft, c_vae])
@@ -767,181 +798,55 @@ def visualize_sort(dataset, model_gb: nn.Module, model_lb: nn.Module):
     masks = torch.zeros((400, 3, 64, 192))
     complexities = torch.zeros((400,))
-    diffs = []
 
     for i, (mask, _) in enumerate(data_loader, 0):
-        complexity, _, _, diff, mask_recon_grid = complexity_measure(
-            model_gb, model_lb, mask, save_preliminary=True
-        )
+        complexity, mask_recon_grid = complexity_measure(
+            model_gb, model_lb, mask
+        )
         masks[i] = mask_recon_grid
-        diffs.append(diff)
         complexities[i] = complexity
 
     sort_idx = np.argsort(np.array(complexities))
     masks_sorted = masks.numpy()[sort_idx]
 
-    plt.plot(np.arange(len(diffs)), np.sort(diffs))
-    plt.xlabel("images")
-    plt.ylabel("prec difference (L-H)")
-    plt.savefig("shape_complexity/results/diff_plot.png")
-    plt.clf()
-
     return plot_samples(masks_sorted, complexities[sort_idx])
 
 
 def visualize_sort_fixed(data_loader, model_gb: nn.Module, model_lb: nn.Module):
     masks = torch.zeros((400, 3, 64, 192))
     complexities = torch.zeros((400,))
-    complexities_lb = torch.zeros((400,))
-    complexities_gb = torch.zeros((400,))
-    diffs = []
-    prec_lbs = []
-    prec_gbs = []
 
     for i, (mask, _) in enumerate(data_loader, 0):
         (
             complexity,
-            lb,
-            gb,
-            diff,
-            prec_lb,
-            prec_gb,
             mask_recon_grid,
-        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
+        ) = complexity_measure(model_gb, model_lb, mask)
         masks[i] = mask_recon_grid
-        diffs.append(diff)
-        prec_lbs.append(prec_lb)
-        prec_gbs.append(prec_gb)
         complexities[i] = complexity
-        complexities_lb[i] = lb
-        complexities_gb[i] = gb
 
     sort_idx = np.argsort(np.array(complexities))
-    sort_idx_lb = np.argsort(np.array(complexities_lb))
-    sort_idx_gb = np.argsort(np.array(complexities_gb))
     masks_sorted = masks.numpy()[sort_idx]
-    masks_sorted_lb = masks.numpy()[sort_idx_lb]
-    masks_sorted_gb = masks.numpy()[sort_idx_gb]
-
-    diff_sort_idx = np.argsort(diffs)
-    # plt.savefig("shape_complexity/results/diff_plot.png")
-    # plt.clf
-
-    fig, ax1 = plt.subplots()
-    ax2 = ax1.twinx()
-    ax1.plot(
-        np.arange(len(prec_lbs)),
-        np.array(prec_lbs)[diff_sort_idx],
-        label=f"bottleneck {model_lb.bottleneck}",
-    )
-    ax1.plot(
-        np.arange(len(prec_gbs)),
-        np.array(prec_gbs)[diff_sort_idx],
-        label=f"bottleneck {model_gb.bottleneck}",
-    )
-    ax2.plot(
-        np.arange(len(diffs)),
-        np.sort(diffs),
-        color="red",
-        label="prec difference (H - L)",
-    )
-    ax1.legend(loc="lower left")
-    ax2.legend(loc="lower right")
-    ax1.set_ylabel("precision")
-    ax2.set_ylabel("prec difference (H-L)")
-    plt.savefig("shape_complexity/results/prec_plot.png")
-    plt.clf()
 
     fig = plot_samples(masks_sorted, complexities[sort_idx])
     fig.savefig("shape_complexity/results/abs.png")
     plt.close(fig)
-    fig = plot_samples(masks_sorted_lb, complexities_lb[sort_idx_lb])
-    fig.savefig("shape_complexity/results/lb.png")
-    plt.close(fig)
-    fig = plot_samples(masks_sorted_gb, complexities_gb[sort_idx_gb])
-    fig.savefig("shape_complexity/results/gb.png")
-    plt.close(fig)
 
 
 def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
     recon_masks = torch.zeros((400, 3, 64, 192))
     masks = torch.zeros((400, 1, 64, 64))
     complexities = torch.zeros((400,))
-    diffs = np.zeros((400,))
-    prec_gbs = np.zeros((400,))
-    prec_lbs = np.zeros((400,))
 
     for i, (mask, _) in enumerate(data_loader, 0):
         (
             complexity,
-            _,
-            _,
-            diff,
-            prec_lb,
-            prec_gb,
             mask_recon_grid,
-        ) = complexity_measure(model_gb, model_lb, mask, save_preliminary=True)
+        ) = complexity_measure(model_gb, model_lb, mask)
         recon_masks[i] = mask_recon_grid
         masks[i] = mask[0]
-        diffs[i] = diff
-        prec_gbs[i] = prec_gb
-        prec_lbs[i] = prec_lb
         complexities[i] = complexity
 
     sort_idx = np.argsort(np.array(complexities))
     masks_sorted = masks.numpy()[sort_idx]
     recon_masks_sorted = recon_masks.numpy()[sort_idx]
 
-    # group_labels = ["lte_0", "gt_0_lte0.05", "gt_0.05"]
-    # bin_edges = [-np.inf, 0.0, 0.05, np.inf]
-    # bins = np.digitize(diffs, bins=bin_edges, right=True)
-
-    # for i in range(bins.min(), bins.max() + 1):
-    #     bin_idx = bins == i
-    #     binned_prec_gb = prec_gbs[bin_idx]
-    #     prec_mean = binned_prec_gb.mean()
-    #     prec_idx = prec_gbs > prec_mean
-
-    #     binned_masks_high = recon_masks[bin_idx & prec_idx]
-    #     binned_masks_low = recon_masks[bin_idx & ~prec_idx]
-
-    #     save_image(
-    #         binned_masks_high,
-    #         f"shape_complexity/results/diff_{group_labels[i-1]}_high.png",
-    #         padding=10,
-    #     )
-    #     save_image(
-    #         binned_masks_low,
-    #         f"shape_complexity/results/diff_{group_labels[i-1]}_low.png",
-    #         padding=10,
-    #     )
-
-    # diff_sort_idx = np.argsort(diffs)
-
-    # fig, ax1 = plt.subplots()
-    # ax2 = ax1.twinx()
-    # ax1.plot(
-    #     np.arange(len(prec_lbs)),
-    #     np.array(prec_lbs)[diff_sort_idx],
-    #     label=f"bottleneck {model_lb.bottleneck}",
-    # )
-    # ax1.plot(
-    #     np.arange(len(prec_gbs)),
-    #     np.array(prec_gbs)[diff_sort_idx],
-    #     label=f"bottleneck {model_gb.bottleneck}",
-    # )
-    # ax2.plot(
-    #     np.arange(len(diffs)),
-    #     np.sort(diffs),
-    #     color="red",
-    #     label="prec difference (H - L)",
-    # )
-    # ax1.legend(loc="lower left")
-    # ax2.legend(loc="lower right")
-    # ax1.set_ylabel("precision")
-    # ax2.set_ylabel("prec difference (H-L)")
-    # ax1.set_xlabel("image")
-    # plt.savefig("shape_complexity/results/prec_plot.png")
-    # plt.tight_layout(pad=2)
-    # plt.clf()
-
     fig = plot_samples(recon_masks_sorted, complexities[sort_idx])
     fig.savefig("shape_complexity/results/abs_recon.png")
     plt.close(fig)
@@ -951,9 +856,15 @@ def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
     plt.close(fig)
 
 
-LR = 1e-3
-EPOCHS = 10
-LOAD_PRETRAINED = True
+LR = 1.5e-3
+EPOCHS = 100
+LOAD_PRETRAINED = False
+
+
+# TODO: build pixel ratio normalization for large black areas
+# -> ideally, this fixes the compression metric
+# TODO: try out pixelwise loss again (in 3d as well)
+# TODO: might be a good idea to implement a bbox cut preprocessing transform thingy
 
 
 def main():
@@ -961,7 +872,8 @@ def main():
     models = {i: CONVVAE(bottleneck=i).to(device) for i in bottlenecks}
     optimizers = {i: Adam(model.parameters(), lr=LR) for i, model in models.items()}
 
-    data_loader, dataset = load_data()
+    # data_loader, dataset = load_data()
+    data_loader, dataset = load_mpeg7_data()
 
     train_size = int(0.8 * len(dataset))
     test_size = len(dataset) - train_size