Skip to content
Snippets Groups Projects
Commit 6451b30b authored by Markus Rothgänger's avatar Markus Rothgänger
Browse files

wip

parent eaa0af94
No related branches found
No related tags found
No related merge requests found
...@@ -140,7 +140,8 @@ class VAE(nn.Module): ...@@ -140,7 +140,8 @@ class VAE(nn.Module):
class CONVVAE(nn.Module): class CONVVAE(nn.Module):
def __init__( def __init__(
self, bottleneck=2, self,
bottleneck=2,
): ):
super(CONVVAE, self).__init__() super(CONVVAE, self).__init__()
...@@ -148,13 +149,19 @@ class CONVVAE(nn.Module): ...@@ -148,13 +149,19 @@ class CONVVAE(nn.Module):
self.feature_dim = 6 * 6 * 64 self.feature_dim = 6 * 6 * 64
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
nn.Conv2d(1, 16, 5), nn.ReLU(), nn.MaxPool2d((2, 2)), # -> 30x30x16 nn.Conv2d(1, 16, 5),
nn.ReLU(),
nn.MaxPool2d((2, 2)), # -> 30x30x16
) )
self.conv2 = nn.Sequential( self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 3), nn.ReLU(), nn.MaxPool2d((2, 2)), # -> 14x14x32 nn.Conv2d(16, 32, 3),
nn.ReLU(),
nn.MaxPool2d((2, 2)), # -> 14x14x32
) )
self.conv3 = nn.Sequential( self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, 3), nn.ReLU(), nn.MaxPool2d((2, 2)), # -> 6x6x64 nn.Conv2d(32, 64, 3),
nn.ReLU(),
nn.MaxPool2d((2, 2)), # -> 6x6x64
) )
# self.conv4 = nn.Sequential( # self.conv4 = nn.Sequential(
# nn.Conv2d(32, self.bottleneck, 5), # nn.Conv2d(32, self.bottleneck, 5),
...@@ -163,7 +170,8 @@ class CONVVAE(nn.Module): ...@@ -163,7 +170,8 @@ class CONVVAE(nn.Module):
# ) # )
self.encode_mu = nn.Sequential( self.encode_mu = nn.Sequential(
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck), nn.Flatten(),
nn.Linear(self.feature_dim, self.bottleneck),
) )
self.encode_logvar = nn.Sequential( self.encode_logvar = nn.Sequential(
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck) nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
...@@ -273,7 +281,7 @@ def load_data(): ...@@ -273,7 +281,7 @@ def load_data():
trajectories = ( trajectories = (
[] []
if True if False
else [ else [
# "v3_subtle_iceberg_lettuce_nymph-6_203-2056", # "v3_subtle_iceberg_lettuce_nymph-6_203-2056",
"v3_absolute_grape_changeling-16_2277-4441", "v3_absolute_grape_changeling-16_2277-4441",
...@@ -441,20 +449,6 @@ def compression_complexity(img: Tensor): ...@@ -441,20 +449,6 @@ def compression_complexity(img: Tensor):
return len(compressed) return len(compressed)
# https://stackoverflow.com/questions/21242011/most-efficient-way-to-calculate-radial-profile
def radial_profile(data, center=None):
y, x = np.indices((data.shape))
center = np.array([(x.max() - x.min()) / 2.0, (y.max() - y.min()) / 2.0])
r = np.sqrt((x - center[0]) ** 2 + (y - center[1]) ** 2)
r = r.astype(np.int)
tbin = np.bincount(r.ravel(), data.ravel())
nr = np.bincount(r.ravel())
radialprofile = tbin / nr
return radialprofile
def fft_measure(img: Tensor): def fft_measure(img: Tensor):
np_img = img[0][0].numpy() np_img = img[0][0].numpy()
fft = np.fft.fft2(np_img) fft = np.fft.fft2(np_img)
...@@ -465,11 +459,10 @@ def fft_measure(img: Tensor): ...@@ -465,11 +459,10 @@ def fft_measure(img: Tensor):
pos_f_idx = n // 2 pos_f_idx = n // 2
df = np.fft.fftfreq(n=n) df = np.fft.fftfreq(n=n)
norm = fft_abs[:pos_f_idx, :pos_f_idx].sum() amplitude_sum = fft_abs[:pos_f_idx, :pos_f_idx].sum()
mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / norm mean_x_freq = (fft_abs * df)[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / norm mean_y_freq = (fft_abs.T * df).T[:pos_f_idx, :pos_f_idx].sum() / amplitude_sum
# unidirectional amplitudes
mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2)) mean_freq = np.sqrt(np.power(mean_x_freq, 2) + np.power(mean_y_freq, 2))
# mean frequency in range 0 to 0.5 # mean frequency in range 0 to 0.5
...@@ -565,7 +558,9 @@ def mean_precision(models: list[nn.Module], img: Tensor, epsilon=0.4): ...@@ -565,7 +558,9 @@ def mean_precision(models: list[nn.Module], img: Tensor, epsilon=0.4):
def complexity_measure_diff( def complexity_measure_diff(
model_gb: nn.Module, model_lb: nn.Module, img: Tensor, model_gb: nn.Module,
model_lb: nn.Module,
img: Tensor,
): ):
model_gb.eval() model_gb.eval()
model_lb.eval() model_lb.eval()
...@@ -734,7 +729,7 @@ def visualize_sort_3dim( ...@@ -734,7 +729,7 @@ def visualize_sort_3dim(
for i, (mask, _) in enumerate(data_loader, 0): for i, (mask, _) in enumerate(data_loader, 0):
c_compress = compression_complexity(mask) c_compress = compression_complexity(mask)
c_fft = fft_measure(mask) c_fft = fft_measure(mask)
# TODO: maybe exchange by diff measure instead of precision # TODO: maybe exchange by diff or mean measure instead of precision
c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure( c_vae, _, _, _, _, _, mask_recon_grid = complexity_measure(
model_gb, model_lb, mask model_gb, model_lb, mask
) )
...@@ -746,6 +741,7 @@ def visualize_sort_3dim( ...@@ -746,6 +741,7 @@ def visualize_sort_3dim(
measure_norm = torch.linalg.vector_norm(measures, dim=1) measure_norm = torch.linalg.vector_norm(measures, dim=1)
fig = plt.figure() fig = plt.figure()
fig.clf()
ax = fig.add_subplot(projection="3d") ax = fig.add_subplot(projection="3d")
ax.scatter(measures[:, 0], measures[:, 1], measures[:, 2], marker="o") ax.scatter(measures[:, 0], measures[:, 1], measures[:, 2], marker="o")
...@@ -753,7 +749,7 @@ def visualize_sort_3dim( ...@@ -753,7 +749,7 @@ def visualize_sort_3dim(
ax.set_ylabel("FFT ratio") ax.set_ylabel("FFT ratio")
ax.set_zlabel(f"VAE ratio {model_gb.bottleneck}/{model_lb.bottleneck}") ax.set_zlabel(f"VAE ratio {model_gb.bottleneck}/{model_lb.bottleneck}")
plt.savefig("shape_complexity/results/3d_plot.png") plt.savefig("shape_complexity/results/3d_plot.png")
plt.clf() plt.close()
sort_idx = np.argsort(np.array(measure_norm)) sort_idx = np.argsort(np.array(measure_norm))
recon_masks_sorted = masks_recon.numpy()[sort_idx] recon_masks_sorted = masks_recon.numpy()[sort_idx]
...@@ -974,36 +970,36 @@ def main(): ...@@ -974,36 +970,36 @@ def main():
) )
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True) train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
# if LOAD_PRETRAINED: if LOAD_PRETRAINED:
# for i, model in models.items(): for i, model in models.items():
# model.load_state_dict( model.load_state_dict(
# torch.load(f"shape_complexity/trained/CONVVAE_{i}_split_data.pth") torch.load(f"shape_complexity/trained/CONVVAE_{i}_split_data.pth")
# ) )
# else: else:
# for epoch in range(EPOCHS): for epoch in range(EPOCHS):
# for i, model in models.items(): for i, model in models.items():
# train( train(
# epoch, epoch,
# model=model, model=model,
# optimizer=optimizers[i], optimizer=optimizers[i],
# data_loader=train_loader, data_loader=train_loader,
# ) )
# test(epoch, models=list(models.values()), dataset=test_dataset) test(epoch, models=list(models.values()), dataset=test_dataset)
# for bn in bottlenecks: for bn in bottlenecks:
# if not os.path.exists("shape_complexity/trained"): if not os.path.exists("shape_complexity/trained"):
# os.makedirs("shape_complexity/trained") os.makedirs("shape_complexity/trained")
# torch.save( torch.save(
# models[bn].state_dict(), models[bn].state_dict(),
# f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth", f"shape_complexity/trained/CONVVAE_{bn}_split_data.pth",
# ) )
# test(0, models=list(models.values()), dataset=test_dataset, save_results=True) test(0, models=list(models.values()), dataset=test_dataset, save_results=True)
bn_gt = 32 bn_gt = 32
bn_lt = 4 bn_lt = 8
# for i in range(10): # for i in range(10):
# figure = visualize_sort(dataset, models[bn_gt], models[bn_lt]) # figure = visualize_sort(dataset, models[bn_gt], models[bn_lt])
...@@ -1019,32 +1015,32 @@ def main(): ...@@ -1019,32 +1015,32 @@ def main():
sampler = RandomSampler(dataset, replacement=True, num_samples=400) sampler = RandomSampler(dataset, replacement=True, num_samples=400)
data_loader = DataLoader(dataset, batch_size=1, sampler=sampler) data_loader = DataLoader(dataset, batch_size=1, sampler=sampler)
# visualize_sort_group(data_loader, models[bn_gt], models[bn_lt]) visualize_sort_group(data_loader, models[bn_gt], models[bn_lt])
# visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt]) # visualize_sort_fixed(data_loader, models[bn_gt], models[bn_lt])
# fig, fig_recon = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt]) fig, fig_recon = visualize_sort_3dim(data_loader, models[bn_gt], models[bn_lt])
# fig.savefig(f"shape_complexity/results/sort_comp_fft_prec.png") fig.savefig(f"shape_complexity/results/sort_comp_fft_prec.png")
# fig_recon.savefig(f"shape_complexity/results/recon_sort_comp_fft_prec.png") fig_recon.savefig(f"shape_complexity/results/recon_sort_comp_fft_prec.png")
# plt.close(fig) plt.close(fig)
# plt.close(fig_recon) plt.close(fig_recon)
# fig = visualize_sort_mean_precision(list(models.values()), data_loader) fig = visualize_sort_mean_precision(list(models.values()), data_loader)
# fig.savefig(f"shape_complexity/results/sort_mean_prec.png") fig.savefig(f"shape_complexity/results/sort_mean_prec.png")
# plt.close(fig) plt.close(fig)
fig = visualize_sort_fft(data_loader) fig = visualize_sort_fft(data_loader)
fig.savefig(f"shape_complexity/results/sort_fft.png") fig.savefig(f"shape_complexity/results/sort_fft.png")
plt.close(fig) plt.close(fig)
fig = visualize_sort_compression(data_loader) fig = visualize_sort_compression(data_loader)
fig.savefig(f"shape_complexity/results/sort_compression.png") fig.savefig(f"shape_complexity/results/sort_compression.png")
# fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt]) fig, fig_recon = visualize_sort_mean(data_loader, models[bn_gt])
# fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png") fig.savefig(f"shape_complexity/results/sort_mean_bn{bn_gt}.png")
# fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png") fig_recon.savefig(f"shape_complexity/results/recon_sort_mean_bn{bn_gt}.png")
# plt.close(fig) plt.close(fig)
# plt.close(fig_recon) plt.close()
# fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt]) # fig, fig_recon = visualize_sort_diff(data_loader, models[bn_gt], models[bn_lt])
# fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png") # fig.savefig(f"shape_complexity/results/sort_diff_bn{bn_gt}_bn{bn_lt}.png")
# fig_recon.savefig( # fig_recon.savefig(
# f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png" # f"shape_complexity/results/recon_sort_diff_bn{bn_gt}_bn{bn_lt}.png"
# ) # )
plt.close(fig) # plt.close(fig)
# plt.close(fig_recon) # plt.close(fig_recon)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment