From 364079bab9a898f0f95c117c1219c8675d4dd517 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 6 Jan 2023 11:13:38 +0100
Subject: [PATCH] re-structure discriminator implementation

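Make the classification discriminator configurable through lists of
convolutional and fully connected features, move max pooling into the
Conv block, generalize the PatchDiscriminator in the same way and add
get_output_shape helpers for both models. The module can be run
directly to print and test a discriminator, e.g.
python -m mu_map.models.discriminator --type patch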
---
 mu_map/models/discriminator.py | 266 ++++++++++++++++++++++++---------
 1 file changed, 198 insertions(+), 68 deletions(-)

diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index 0409463..dd90d35 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -1,5 +1,5 @@
 from functools import reduce
-from typing import Tuple, Union
+from typing import List, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -7,17 +7,21 @@ import torch.nn as nn
 
 class Conv(nn.Sequential):
     """
-    A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function.
+    A wrapper around a 3D convolutional layer that also contains batch normalization, a ReLU activation function, and max pooling.
     """
 
-    def __init__(self, in_channels: int, out_channels: int):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+    ):
         """
-        Create a convolutional layer with batch normalization and a ReLU activation function.
+        Create a convolutional layer with batch normalization, a ReLU activation function and max pooling.
 
         Parameters
         ----------
         in_channels: int
-            number of channels receives as input
+            number of channels received as input
         out_channels: int
             number of filters and consequently channels in the output
         """
@@ -34,6 +38,7 @@ class Conv(nn.Sequential):
         )
         self.append(nn.BatchNorm3d(num_features=out_channels))
         self.append(nn.ReLU(inplace=True))
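+        # halve each spatial dimension after the activation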
+        self.append(nn.MaxPool3d(kernel_size=2, stride=2))
 
 
 class Discriminator(nn.Module):
@@ -42,7 +47,13 @@ class Discriminator(nn.Module):
-    It consists of three convolutional layers with max pooling, followed by three fully connected layers.
+    It consists of a configurable stack of convolutional layers with max pooling, followed by fully connected layers.
     """
 
-    def __init__(self, in_channels: int = 1, input_size: Union[int, Tuple[int]] = 16):
+    def __init__(
+        self,
+        in_channels: int,
+        input_size: Union[int, Tuple[int, int, int]],
+        conv_features: List[int] = [32, 64, 128],
+        fc_features: List[int] = [512, 128],
+    ):
         """
         Create the discriminator.
 
@@ -53,88 +64,206 @@ class Discriminator(nn.Module):
         input_size: int or tuple of int
             shape of the inputs either as an int (if equal in all dimensions) or as a tuple
             this is required to compute the number of features for the first fully connected layer
+        conv_features: list of int
+            each number represents the number of filters in a convolutional layer
+        fc_features: list of int
+            each number represents the number of features in a fully connected layer
+            note that an additional fully connected layer always maps the flattened convolution output to the first of these features
         """
         super().__init__()
-        # input is halved three time (// 2 ** 3) and we deal with 3D inputs (**3)
-        if type(input_size) is int:
-            fc_input_size = (input_size // 2**3) ** 3
-        elif type(input_size) is tuple:
-            fc_input_size = map(lambda x: x // 2**3, input_size)
-            fc_input_size = reduce(lambda x, y: x * y, fc_input_size)
-        else:
-            raise ValueError(
-                f"Cannot deal with input size {input_size} of type {type(input_size)}"
-            )
 
-        self.conv = nn.Sequential(
-            Conv(in_channels=in_channels, out_channels=32),
-            nn.MaxPool3d(kernel_size=2, stride=2),
-            Conv(in_channels=32, out_channels=64),
-            nn.MaxPool3d(kernel_size=2, stride=2),
-            Conv(in_channels=64, out_channels=128),
-            nn.MaxPool3d(kernel_size=2, stride=2),
-        )
-        self.fully_connected = nn.Sequential(
-            nn.Linear(in_features=128 * fc_input_size, out_features=512),
-            nn.ReLU(inplace=True),
-            nn.Linear(in_features=512, out_features=128),
-            nn.ReLU(inplace=True),
-            nn.Linear(in_features=128, out_features=1),
+        self.in_channels = in_channels
+        self.input_size = (input_size,) * 3 if isinstance(input_size, int) else tuple(input_size)
+        self.conv_features = conv_features
+        self.fc_features = fc_features
+
+        conv_layers = []
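+        # chain Conv blocks: in_channels -> conv_features[0] -> ... -> conv_features[-1]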
+        for in_channels, out_channels in zip(
+            [in_channels, *conv_features[:-1]], conv_features
+        ):
+            conv_layers.append(Conv(in_channels, out_channels))
+        self.conv = nn.Sequential(*conv_layers)
+
+        # each convolutional layer halves every spatial dimension, so the
+        # flattened feature size is the product of the reduced dimensions
+        self.fc_input_size = map(lambda x: x // 2 ** len(conv_features), self.input_size)
+        self.fc_input_size = reduce(lambda x, y: x * y, self.fc_input_size)
+
+        fc_layers = []
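+        # the first fully connected layer maps the flattened convolution output to fc_features[0]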
+        fc_layers.append(
+            nn.Linear(
+                in_features=self.fc_input_size * conv_features[-1],
+                out_features=fc_features[0],
+            )
         )
+        fc_layers.append(nn.ReLU(inplace=True))
+        for in_features, out_features in zip(fc_features[:-1], fc_features[1:]):
+            fc_layers.append(
+                nn.Linear(in_features=in_features, out_features=out_features)
+            )
+            fc_layers.append(nn.ReLU(inplace=True))
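+        # the final layer produces a single classification logit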
+        fc_layers.append(nn.Linear(in_features=fc_features[-1], out_features=1))
+        self.fc = nn.Sequential(*fc_layers)
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv(x)
         x = torch.flatten(x, 1)
-        x = self.fully_connected(x)
+        x = self.fc(x)
         return x
 
+    def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
+        """
+        Get the output shape of the discriminator with respect to an input shape.
+
+        Parameters
+        ----------
+        input_shape: Tuple[int, ...]
+            the shape for which the output is computed; it needs to have at least 4 dimensions
+
+        Returns
+        -------
+        Tuple[int, ...]
+            the corresponding output shape
+        """
+        assert len(input_shape) > 3, "Input must contain at least 4 dimensions"
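+        # the fully connected head reduces every sample to a single logit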
+        return (*input_shape[:-4], 1)
 
-class PatchDiscriminator(nn.Module):
-    def __init__(self, in_channels: int = 2):
+
+class PatchDiscriminator(nn.Sequential):
+    """
+    Patch discriminator (PatchGAN) used by Isola et al. 2017 for their image-to-image
+    translation GAN.
+    """
+
+    def __init__(self, in_channels: int = 2, features: List[int] = [64, 128, 256, 512]):
+        """
+        Create a new patch discriminator.
+
+        Parameters
+        ----------
+        in_channels: int
+            the number of channels of inputs for the model
+        features: list of int
+            the number of features in each convolutional layer of the model
+            note that each entry adds a convolutional layer that halves the spatial input size
+        """
         super().__init__()
+        self.features = features
+        self.in_channels = in_channels
 
-        self.conv = nn.Sequential(
-            nn.Conv3d(
-                in_channels=in_channels,
-                out_channels=64,
-                kernel_size=4,
-                stride=2,
-                padding=1,
-            ),
-            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-            nn.Conv3d(
-                in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1
-            ),
-            nn.BatchNorm3d(num_features=128),
-            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-            nn.Conv3d(
-                in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1
-            ),
-            nn.BatchNorm3d(num_features=256),
-            nn.LeakyReLU(negative_slope=0.2, inplace=True),
-            nn.Conv3d(
-                in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1
-            ),
+        self.append(self.create_conv(in_channels, features[0]))
+        self.append(nn.LeakyReLU(inplace=True, negative_slope=0.2))
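+        # intermediate convolutions, each followed by batch normalization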
+        for in_channels, out_channels in zip(features[:-1], features[1:]):
+            self.append(self.create_conv(in_channels, out_channels))
+            self.append(nn.BatchNorm3d(num_features=out_channels))
+            self.append(nn.LeakyReLU(inplace=True, negative_slope=0.2))
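+        # the final convolution reduces the features to one output channel per patch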
+        self.append(self.create_conv(features[-1], 1))
+
+    def create_conv(self, in_channels: int, out_channels: int) -> nn.Conv3d:
+        """
+        Create a typical convolution used by the patch discriminator (kernel_size=4, stride=2, padding=1).
+
+        Parameters
+        ----------
+        in_channels: int
+            the number of input channels for the convolution
+        out_channels: int
+            the number of output channels (number of filters) of the convolution
+
+        Returns
+        -------
+        nn.Conv3d
+        """
+        return nn.Conv3d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=4,
+            stride=2,
+            padding=1,
         )
 
-    def forward(self, x: torch.Tensor):
-        return self.conv(x)
+    def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
+        """
+        Get the output shape of the discriminator with respect to an input shape.
+
+        Parameters
+        ----------
+        input_shape: Tuple[int, ...]
+            the shape for which the output is computed; it needs to have at least 4 dimensions
+
+        Returns
+        -------
+        Tuple[int, ...]
+            the corresponding output shape
+        """
+        assert len(input_shape) > 3, "Input must contain at least 4 dimensions"
+
+        n_convs = len(self.features) + 1  # each convolution halves the input dimensions
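+        # the channel dimension collapses to 1 and each spatial dimension is divided by 2**n_convs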
+        return (
+            *input_shape[:-4],
+            1,
+            *map(lambda x: x // 2**n_convs, input_shape[-3:]),
+        )
 
 
 if __name__ == "__main__":
-    batch_size = 4
-    input_size = (32, 64, 64)
-    in_channels = 2
+    import argparse
 
-    net = Discriminator(in_channels=in_channels, input_size=input_size)
-    print(net)
+    parser = argparse.ArgumentParser(
+        description="print and test different discriminators",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--type",
+        choices=["class", "patch"],
+        default="class",
+        help="test either a classification or a PatchGAN discriminator",
+    )
+    parser.add_argument(
+        "--conv_features",
+        type=int,
+        nargs="+",
+        default=[32, 64, 128],
+        help="the number of features for each convolutional layer",
+    )
+    parser.add_argument(
+        "--fc_features",
+        type=int,
+        nargs="+",
+        default=[512, 128],
+        help="the number of features for each fully connected layer",
+    )
+    parser.add_argument(
+        "--input_size",
+        type=int,
+        nargs=3,
+        default=[32, 32, 32],
+        help="shape of inputs for the discriminator",
+    )
+    parser.add_argument(
+        "--in_channels", type=int, default=2, help="number of channels for inputs"
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8,
+        help="batch size of inputs for a test computation",
+    )
+    args = parser.parse_args()
 
-    if type(input_size) is int:
-        _inputs = torch.rand(
-            (batch_size, in_channels, input_size, input_size, input_size)
+    if args.type == "class":
+        net = Discriminator(
+            in_channels=args.in_channels,
+            input_size=tuple(args.input_size),
+            conv_features=args.conv_features,
+            fc_features=args.fc_features,
         )
-    else:
-        _inputs = torch.rand((batch_size, in_channels, *input_size))
+    elif args.type == "patch":
+        net = PatchDiscriminator(
+            in_channels=args.in_channels, features=args.conv_features
+        )
+    print(net)
+
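+    # forward a random batch through the network to verify the architecture end-to-end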
+    _inputs = torch.rand((args.batch_size, args.in_channels, *args.input_size))
     _outputs = net(_inputs)
 
     _targets = torch.full(_outputs.shape, 1.0)
@@ -143,3 +272,4 @@ if __name__ == "__main__":
     print(loss.item())
 
     print(f"Transform {_inputs.shape} to {_outputs.shape}")
+    assert _outputs.shape == net.get_output_shape(_inputs.shape)
-- 
GitLab