Skip to content
Snippets Groups Projects
Commit 2887232f authored by Markus Rothgänger
Browse files

wip

parent d8cb6252
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ import numpy.typing as npt
import torch
import torch.nn.functional as F
from PIL import Image
from torch import Tensor, nn
from torch import Tensor, conv2d, nn
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler
from torchvision.datasets import ImageFolder
......@@ -142,11 +142,29 @@ class CONVVAE(nn.Module):
super(CONVVAE, self).__init__()
self.bottleneck = bottleneck
self.feature_dim = 32 * 56 * 56
self.feature_dim = 128
self.conv = nn.Sequential(
nn.Conv2d(1, 16, 5), nn.ReLU(), nn.Conv2d(16, 32, 5), nn.ReLU()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 16, 5),
nn.ReLU(),
nn.MaxPool2d((2, 2), return_indices=True), # -> 30x30x16
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 3),
nn.ReLU(),
nn.MaxPool2d((2, 2), return_indices=True), # -> 14x14x32
)
self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, 3),
nn.ReLU(),
nn.MaxPool2d((2, 2), return_indices=True), # -> 6x6x64
)
self.conv4 = nn.Sequential(
nn.Conv2d(64, 128, 5),
nn.ReLU(),
nn.MaxPool2d((2, 2), return_indices=True), # -> 1x1x128
)
self.encode_mu = nn.Sequential(
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
)
......@@ -154,26 +172,32 @@ class CONVVAE(nn.Module):
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
)
self.decode_linear = nn.Sequential(
nn.Linear(self.bottleneck, self.feature_dim),
self.decode = nn.Sequential(
nn.MaxUnpool2d((2, 2)),
nn.ConvTranspose2d(128, 64, 5),
nn.ReLU(),
)
self.decode_conv = nn.Sequential(
nn.ConvTranspose2d(32, 16, 5),
nn.MaxUnpool2d((2, 2)),
nn.ConvTranspose2d(64, 32, 3),
nn.ReLU(),
nn.MaxUnpool2d((2, 2)),
nn.ConvTranspose2d(32, 16, 3),
nn.ReLU(),
nn.MaxUnpool2d((2, 2)),
nn.ConvTranspose2d(16, 1, 5),
nn.Sigmoid(),
)
def encode(self, x):
x = self.conv(x)
return self.encode_mu(x), self.encode_logvar(x)
x, indices = self.conv(x)
mu = self.encode_mu(x)
logvar = self.encode_logvar(x)
return mu, logvar
def decode(self, z):
z = self.decode_linear(z)
z = z.view(-1, 32, 56, 56)
return self.decode_conv(z)
# def decode(self, z):
# z = self.decode_linear(z)
# # z = z.view(-1, 128, 1, 1)
# return self.decode_conv(z)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)
......@@ -752,7 +776,7 @@ def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
LR = 1e-3
EPOCHS = 20
LOAD_PRETRAINED = True
LOAD_PRETRAINED = False
def main():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment