From d7e441c1e22e2f409b9d53b8472ea5b9e41a5e93 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Wed, 17 Aug 2022 13:05:12 +0200 Subject: [PATCH] discriminator is can now be configured to use variable input sizes and write documentation --- mu_map/models/discriminator.py | 47 ++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 8 deletions(-) diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py index 68dbafa..c57c94c 100644 --- a/mu_map/models/discriminator.py +++ b/mu_map/models/discriminator.py @@ -1,21 +1,50 @@ import torch import torch.nn as nn + class Conv(nn.Sequential): + """ + A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function. + """ def __init__(self, in_channels, out_channels): + """ + Create a convolutional layer with batch normalization and a ReLU activation function. + + :param in_channels: number of channels receives as input + :param out_channels: number of filters and consequently channels in the output + """ super().__init__() - self.append(nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same")) + self.append( + nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding="same", + ) + ) self.append(nn.BatchNorm3d(num_features=out_channels)) self.append(nn.ReLU(inplace=True)) + class Discriminator(nn.Module): + """ + Create the discriminator as specified by Shi et al. (2020). + If consists of three convolutional layers with max pooling, followed by three fully connected layers. + """ - def __init__(self, in_channels=1): + def __init__(self, in_channels=1, input_size=16): + """ + Create the discriminator. + + :param in_channels: number channels received as an input + :param input_size: shape of the inputs images which is required to compute the number of features in the first fully connected layer + """ super().__init__() - #TODO: make fully connected layer dependent on input shape - #TODO: write doc + # input is halved three time (// 2 ** 3) and we deal with 3D inputs (**3) + fc_input_size = (input_size // 2 ** 3) ** 3 self.conv = nn.Sequential( Conv(in_channels=in_channels, out_channels=32), @@ -26,13 +55,13 @@ class Discriminator(nn.Module): nn.MaxPool3d(kernel_size=2, stride=2), ) self.fully_connected = nn.Sequential( - nn.Linear(in_features=128 * 2 ** 3, out_features=512), + nn.Linear(in_features=128 * fc_input_size, out_features=512), nn.ReLU(inplace=True), nn.Linear(in_features=512, out_features=128), nn.ReLU(inplace=True), nn.Linear(in_features=128, out_features=1), ) - + def forward(self, x): x = self.conv(x) x = torch.flatten(x, 1) @@ -41,10 +70,12 @@ class Discriminator(nn.Module): if __name__ == "__main__": - net = Discriminator() + input_size = 16 + + net = Discriminator(input_size=input_size) print(net) - _inputs = torch.rand((1, 1, 16, 16, 16)) + _inputs = torch.rand((1, 1, input_size, input_size, input_size)) _outputs = net(_inputs) print(f"Transform {_inputs.shape} to {_outputs.shape}") -- GitLab