Skip to content
Snippets Groups Projects
Commit 2887232f authored by Markus Rothgänger's avatar Markus Rothgänger
Browse files

wip

parent d8cb6252
No related branches found
No related tags found
No related merge requests found
...@@ -7,7 +7,7 @@ import numpy.typing as npt ...@@ -7,7 +7,7 @@ import numpy.typing as npt
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from PIL import Image from PIL import Image
from torch import Tensor, nn from torch import Tensor, conv2d, nn
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from torchvision.datasets import ImageFolder from torchvision.datasets import ImageFolder
...@@ -142,11 +142,29 @@ class CONVVAE(nn.Module): ...@@ -142,11 +142,29 @@ class CONVVAE(nn.Module):
super(CONVVAE, self).__init__() super(CONVVAE, self).__init__()
self.bottleneck = bottleneck self.bottleneck = bottleneck
self.feature_dim = 32 * 56 * 56 self.feature_dim = 128
self.conv = nn.Sequential( self.conv1 = nn.Sequential(
nn.Conv2d(1, 16, 5), nn.ReLU(), nn.Conv2d(16, 32, 5), nn.ReLU() nn.Conv2d(1, 16, 5),
nn.ReLU(),
nn.MaxPool2d((2, 2), return_indices=True), # -> 30x30x16
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 3),
nn.ReLU(),
nn.MaxPool2d((2, 2), return_indices=True), # -> 14x14x32
)
self.conv3 = nn.Sequential(
nn.Conv2d(32, 64, 3),
nn.ReLU(),
nn.MaxPool2d((2, 2), return_indices=True), # -> 6x6x64
)
self.conv4 = nn.Sequential(
nn.Conv2d(64, 128, 5),
nn.ReLU(),
nn.MaxPool2d((2, 2), return_indices=True), # -> 1x1x128
) )
self.encode_mu = nn.Sequential( self.encode_mu = nn.Sequential(
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck) nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
) )
...@@ -154,26 +172,32 @@ class CONVVAE(nn.Module): ...@@ -154,26 +172,32 @@ class CONVVAE(nn.Module):
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck) nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck)
) )
self.decode_linear = nn.Sequential( self.decode = nn.Sequential(
nn.Linear(self.bottleneck, self.feature_dim), nn.MaxUnpool2d((2, 2)),
nn.ConvTranspose2d(128, 64, 5),
nn.ReLU(), nn.ReLU(),
) nn.MaxUnpool2d((2, 2)),
nn.ConvTranspose2d(64, 32, 3),
self.decode_conv = nn.Sequential(
nn.ConvTranspose2d(32, 16, 5),
nn.ReLU(), nn.ReLU(),
nn.MaxUnpool2d((2, 2)),
nn.ConvTranspose2d(32, 16, 3),
nn.ReLU(),
nn.MaxUnpool2d((2, 2)),
nn.ConvTranspose2d(16, 1, 5), nn.ConvTranspose2d(16, 1, 5),
nn.Sigmoid(), nn.Sigmoid(),
) )
def encode(self, x): def encode(self, x):
x = self.conv(x) x, indices = self.conv(x)
return self.encode_mu(x), self.encode_logvar(x) mu = self.encode_mu(x)
logvar = self.encode_logvar(x)
return mu, logvar
def decode(self, z): # def decode(self, z):
z = self.decode_linear(z) # z = self.decode_linear(z)
z = z.view(-1, 32, 56, 56) # # z = z.view(-1, 128, 1, 1)
return self.decode_conv(z) # return self.decode_conv(z)
def reparameterize(self, mu, logvar): def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar) std = torch.exp(0.5 * logvar)
...@@ -752,7 +776,7 @@ def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module): ...@@ -752,7 +776,7 @@ def visualize_sort_group(data_loader, model_gb: nn.Module, model_lb: nn.Module):
LR = 1e-3 LR = 1e-3
EPOCHS = 20 EPOCHS = 20
LOAD_PRETRAINED = True LOAD_PRETRAINED = False
def main(): def main():
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment