Skip to content
Snippets Groups Projects
Commit c7936a67 authored by Markus Rothgänger's avatar Markus Rothgänger
Browse files

wip

parent 2887232f
No related branches found
No related tags found
No related merge requests found
...@@ -142,7 +142,7 @@ class CONVVAE(nn.Module): ...@@ -142,7 +142,7 @@ class CONVVAE(nn.Module):
super(CONVVAE, self).__init__() super(CONVVAE, self).__init__()
self.bottleneck = bottleneck self.bottleneck = bottleneck
self.feature_dim = 128 self.feature_dim = 32
self.conv1 = nn.Sequential( self.conv1 = nn.Sequential(
nn.Conv2d(1, 16, 5), nn.Conv2d(1, 16, 5),
...@@ -160,44 +160,65 @@ class CONVVAE(nn.Module): ...@@ -160,44 +160,65 @@ class CONVVAE(nn.Module):
nn.MaxPool2d((2, 2), return_indices=True), # -> 6x6x64 nn.MaxPool2d((2, 2), return_indices=True), # -> 6x6x64
) )
self.conv4 = nn.Sequential( self.conv4 = nn.Sequential(
nn.Conv2d(64, 128, 5), nn.Conv2d(64, 2 * self.bottleneck, 5),
nn.ReLU(), nn.ReLU(),
nn.MaxPool2d((2, 2), return_indices=True), # -> 1x1x128 nn.MaxPool2d((2, 2), return_indices=True), # -> 1x1x2*bottleneck
) )
self.encode_mu = nn.Sequential( self.encode_mu = nn.Sequential(
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck) nn.Flatten(),
nn.Linear(
2 * self.bottleneck, self.bottleneck
), # TODO: maybe only FC from bn x bn
) )
self.encode_logvar = nn.Sequential( self.encode_logvar = nn.Sequential(
nn.Flatten(), nn.Linear(self.feature_dim, self.bottleneck) nn.Flatten(), nn.Linear(2 * self.bottleneck, self.bottleneck)
) )
self.decode = nn.Sequential( self.decode_linear = nn.Linear(self.bottleneck, 2 * self.bottleneck)
nn.MaxUnpool2d((2, 2)),
nn.ConvTranspose2d(128, 64, 5), self.decode4 = nn.Sequential(
nn.ConvTranspose2d(2 * self.bottleneck, 64, 5),
nn.ReLU(), nn.ReLU(),
nn.MaxUnpool2d((2, 2)), )
self.decode3 = nn.Sequential(
nn.ConvTranspose2d(64, 32, 3), nn.ConvTranspose2d(64, 32, 3),
nn.ReLU(), nn.ReLU(),
nn.MaxUnpool2d((2, 2)), )
self.decode2 = nn.Sequential(
nn.ConvTranspose2d(32, 16, 3), nn.ConvTranspose2d(32, 16, 3),
nn.ReLU(), nn.ReLU(),
nn.MaxUnpool2d((2, 2)), )
self.decode1 = nn.Sequential(
nn.ConvTranspose2d(16, 1, 5), nn.ConvTranspose2d(16, 1, 5),
nn.Sigmoid(), nn.Sigmoid(),
) )
def encode(self, x):
    """Encode an input batch into latent Gaussian parameters.

    Runs the four conv+maxpool stages in sequence, keeping the pooling
    indices from each stage so that decode() can mirror them with
    max-unpooling.

    Args:
        x: input image batch; assumes shape (N, 1, H, W) — the first conv
           takes 1 input channel (TODO confirm against caller).

    Returns:
        Tuple of (mu, logvar, (idx1, idx2, idx3, idx4)) where mu/logvar
        each have shape (N, bottleneck) and the idx tuple holds the
        max-pool indices for the decoder.
    """
    x, idx1 = self.conv1(x)
    x, idx2 = self.conv2(x)
    x, idx3 = self.conv3(x)
    x, idx4 = self.conv4(x)
    mu = self.encode_mu(x)
    logvar = self.encode_logvar(x)
    return mu, logvar, (idx1, idx2, idx3, idx4)
def decode(self, z: Tensor, indexes: tuple):
    """Decode a latent sample back to image space.

    Mirrors encode(): a linear projection lifts z to the conv feature
    size, then each stage max-unpools with the indices recorded during
    encoding (in reverse order: idx4 first) before its transposed-conv
    block. The final stage ends in Sigmoid, so the output is in [0, 1].

    Args:
        z: latent batch of shape (N, bottleneck).
        indexes: tuple (idx1, idx2, idx3, idx4) of pooling indices
            returned by encode().

    Returns:
        Reconstructed batch with 1 output channel.
    """
    idx1, idx2, idx3, idx4 = indexes
    z = self.decode_linear(z)
    # Reshape flat features to a 1x1 spatial map for the conv decoder.
    z = z.view((-1, 2 * self.bottleneck, 1, 1))
    z = F.max_unpool2d(z, idx4, (2, 2))
    z = self.decode4(z)
    z = F.max_unpool2d(z, idx3, (2, 2))
    z = self.decode3(z)
    z = F.max_unpool2d(z, idx2, (2, 2))
    z = self.decode2(z)
    z = F.max_unpool2d(z, idx1, (2, 2))
    z = self.decode1(z)
    return z
def reparameterize(self, mu, logvar): def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar) std = torch.exp(0.5 * logvar)
...@@ -205,9 +226,9 @@ class CONVVAE(nn.Module): ...@@ -205,9 +226,9 @@ class CONVVAE(nn.Module):
return mu + eps * std return mu + eps * std
def forward(self, x):
    """Full VAE pass: encode, reparameterize, decode.

    Args:
        x: input image batch.

    Returns:
        Tuple (reconstruction, mu, logvar) — mu/logvar are returned so
        the caller can compute the KL term of the VAE loss.
    """
    mu, logvar, indexes = self.encode(x)
    z = self.reparameterize(mu, logvar)
    return self.decode(z, indexes), mu, logvar
def loss(self, recon_x, x, mu, logvar): def loss(self, recon_x, x, mu, logvar):
"""https://github.com/pytorch/examples/blob/main/vae/main.py""" """https://github.com/pytorch/examples/blob/main/vae/main.py"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment