From d7e441c1e22e2f409b9d53b8472ea5b9e41a5e93 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Wed, 17 Aug 2022 13:05:12 +0200
Subject: [PATCH] discriminator can now be configured to use variable input
 sizes; add documentation

---
 mu_map/models/discriminator.py | 47 ++++++++++++++++++++++++++++------
 1 file changed, 39 insertions(+), 8 deletions(-)

diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index 68dbafa..c57c94c 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -1,21 +1,50 @@
 import torch
 import torch.nn as nn
 
+
 class Conv(nn.Sequential):
+    """
+    A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function.
+    """
 
     def __init__(self, in_channels, out_channels):
+        """
+        Create a convolutional layer with batch normalization and a ReLU activation function.
+
+        :param in_channels: number of channels received as input
+        :param out_channels: number of filters and consequently channels in the output
+        """
         super().__init__()
 
-        self.append(nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same"))
+        self.append(
+            nn.Conv3d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=1,
+                padding="same",
+            )
+        )
         self.append(nn.BatchNorm3d(num_features=out_channels))
         self.append(nn.ReLU(inplace=True))
 
+
 class Discriminator(nn.Module):
+    """
+    Create the discriminator as specified by Shi et al. (2020).
+    It consists of three convolutional layers with max pooling, followed by three fully connected layers.
+    """
 
-    def __init__(self, in_channels=1):
+    def __init__(self, in_channels=1, input_size=16):
+        """
+        Create the discriminator.
+
+        :param in_channels: number of channels received as input
+        :param input_size: edge length of the (cubic) input images, which is required to compute the number of features in the first fully connected layer
+        """
         super().__init__()
-        #TODO: make fully connected layer dependent on input shape
-        #TODO: write doc
+        # the input is halved three times (// 2 ** 3) and we deal with 3D inputs (** 3)
+        fc_input_size = (input_size // 2 ** 3) ** 3
 
         self.conv = nn.Sequential(
             Conv(in_channels=in_channels, out_channels=32),
@@ -26,13 +55,13 @@ class Discriminator(nn.Module):
             nn.MaxPool3d(kernel_size=2, stride=2),
         )
         self.fully_connected = nn.Sequential(
-            nn.Linear(in_features=128 * 2 ** 3, out_features=512),
+            nn.Linear(in_features=128 * fc_input_size, out_features=512),
             nn.ReLU(inplace=True),
             nn.Linear(in_features=512, out_features=128),
             nn.ReLU(inplace=True),
             nn.Linear(in_features=128, out_features=1),
         )
-    
+
     def forward(self, x):
         x = self.conv(x)
         x = torch.flatten(x, 1)
@@ -41,10 +70,12 @@ class Discriminator(nn.Module):
 
 
 if __name__ == "__main__":
-    net = Discriminator()
+    input_size = 16
+
+    net = Discriminator(input_size=input_size)
     print(net)
 
-    _inputs = torch.rand((1, 1, 16, 16, 16))
+    _inputs = torch.rand((1, 1, input_size, input_size, input_size))
     _outputs = net(_inputs)
 
     print(f"Transform {_inputs.shape} to {_outputs.shape}")
-- 
GitLab