Commit d7e441c1 authored by Tamino Huxohl

discriminator can now be configured to use variable input sizes; documentation added

parent 9b8cb709
import torch
import torch.nn as nn


class Conv(nn.Sequential):
    """
    A wrapper around a 3D convolutional layer that also contains batch normalization
    and a ReLU activation function.
    """

    def __init__(self, in_channels, out_channels):
        """
        Create a convolutional layer with batch normalization and a ReLU activation function.

        :param in_channels: number of channels received as input
        :param out_channels: number of filters and consequently channels in the output
        """
        super().__init__()
        self.append(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding="same",
            )
        )
        self.append(nn.BatchNorm3d(num_features=out_channels))
        self.append(nn.ReLU(inplace=True))
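

# With kernel_size=3, stride=1, and padding="same", a Conv block preserves the spatial
# size of its input, e.g. (N, C, 16, 16, 16) -> (N, out_channels, 16, 16, 16); only the
# MaxPool3d layers in the Discriminator below reduce the resolution.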
class Discriminator(nn.Module):
    """
    Create the discriminator as specified by Shi et al. (2020).
    It consists of three convolutional layers with max pooling, followed by three
    fully connected layers.
    """

    def __init__(self, in_channels=1, input_size=16):
        """
        Create the discriminator.

        :param in_channels: number of channels received as input
        :param input_size: size of the cubic input images, which is required to compute
            the number of features in the first fully connected layer
        """
        super().__init__()
        # the input is halved three times (// 2 ** 3) and we deal with 3D inputs (** 3)
        fc_input_size = (input_size // 2 ** 3) ** 3
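        # e.g. with the default input_size=16: 16 // 2 ** 3 = 2 positions per axis and
        # 2 ** 3 = 8 positions in total, so the first fully connected layer receives
        # 128 * 8 = 1024 features (previously hard-coded as 128 * 2 ** 3)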
        self.conv = nn.Sequential(
            Conv(in_channels=in_channels, out_channels=32),
            nn.MaxPool3d(kernel_size=2, stride=2),
            # channel progression 32 -> 64 -> 128 assumed for the three convolutional blocks
            Conv(in_channels=32, out_channels=64),
            nn.MaxPool3d(kernel_size=2, stride=2),
            Conv(in_channels=64, out_channels=128),
            nn.MaxPool3d(kernel_size=2, stride=2),
        )
        self.fully_connected = nn.Sequential(
            nn.Linear(in_features=128 * fc_input_size, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=1),
        )

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = self.fully_connected(x)
        return x


if __name__ == "__main__":
    input_size = 16
    net = Discriminator(input_size=input_size)
    print(net)

    _inputs = torch.rand((1, 1, input_size, input_size, input_size))
    _outputs = net(_inputs)
    print(f"Transform {_inputs.shape} to {_outputs.shape}")