diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index 68dbafa95408a816895b8b95d13c3becf4023169..c57c94c7633371428ac6de6bbdc6b645213e97d1 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -1,21 +1,50 @@
 import torch
 import torch.nn as nn
 
+
 class Conv(nn.Sequential):
+    """
+    A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function.
+    """
 
     def __init__(self, in_channels, out_channels):
+        """
+        Create a convolutional layer with batch normalization and a ReLU activation function.
+
+        :param in_channels: number of channels received as input
+        :param out_channels: number of filters and consequently channels in the output
+        """
         super().__init__()
 
-        self.append(nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same"))
+        self.append(
+            nn.Conv3d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=1,
+                padding="same",
+            )
+        )
         self.append(nn.BatchNorm3d(num_features=out_channels))
         self.append(nn.ReLU(inplace=True))
 
+
 class Discriminator(nn.Module):
+    """
+    The discriminator as specified by Shi et al. (2020).
+    It consists of three convolutional layers with max pooling, followed by three fully connected layers.
+    """
 
-    def __init__(self, in_channels=1):
+    def __init__(self, in_channels=1, input_size=16):
+        """
+        Create the discriminator.
+
+        :param in_channels: number of channels received as input
+        :param input_size: size of the input images, which is required to compute the number of features in the first fully connected layer
+        """
         super().__init__()
+        # the input is halved three times (// 2 ** 3) and the inputs are 3D (** 3)
+        fc_input_size = (input_size // 2 ** 3) ** 3
-        #TODO: make fully connected layer dependent on input shape
-        #TODO: write doc
 
         self.conv = nn.Sequential(
             Conv(in_channels=in_channels, out_channels=32),
@@ -26,13 +55,13 @@ class Discriminator(nn.Module):
             nn.MaxPool3d(kernel_size=2, stride=2),
         )
         self.fully_connected = nn.Sequential(
-            nn.Linear(in_features=128 * 2 ** 3, out_features=512),
+            nn.Linear(in_features=128 * fc_input_size, out_features=512),
             nn.ReLU(inplace=True),
             nn.Linear(in_features=512, out_features=128),
             nn.ReLU(inplace=True),
             nn.Linear(in_features=128, out_features=1),
         )
-        
+
     def forward(self, x):
         x = self.conv(x)
         x = torch.flatten(x, 1)
@@ -41,10 +70,12 @@
 
 
 if __name__ == "__main__":
-    net = Discriminator()
+    input_size = 16
+
+    net = Discriminator(input_size=input_size)
     print(net)
 
-    _inputs = torch.rand((1, 1, 16, 16, 16))
+    _inputs = torch.rand((1, 1, input_size, input_size, input_size))
     _outputs = net(_inputs)
     print(f"Transform {_inputs.shape} to {_outputs.shape}")
 
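
As a quick sanity check of the new `fc_input_size` computation (illustrative only, not part of the patch): with the default `input_size=16`, the three max pooling layers halve each spatial dimension to 2, so the flattened feature vector fed to the first fully connected layer keeps its previous size of 128 * 2 ** 3 = 1024.

```python
# Sanity check for fc_input_size (illustrative, not part of the patch).
input_size = 16

# Each of the three MaxPool3d layers halves the spatial size: 16 -> 8 -> 4 -> 2.
side = input_size // 2 ** 3   # 2
fc_input_size = side ** 3     # 8 voxels per channel after the conv stack

# 128 channels * 8 voxels = 1024 features, matching the previously hard-coded 128 * 2 ** 3.
assert 128 * fc_input_size == 128 * 2 ** 3 == 1024
```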