From 364079bab9a898f0f95c117c1219c8675d4dd517 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 6 Jan 2023 11:13:38 +0100
Subject: [PATCH] re-structure discriminator implementation

---
 mu_map/models/discriminator.py | 266 ++++++++++++++++++++++++---------
 1 file changed, 198 insertions(+), 68 deletions(-)

diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index 0409463..dd90d35 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -1,5 +1,5 @@
 from functools import reduce
-from typing import Tuple, Union
+from typing import List, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -7,17 +7,21 @@ import torch.nn as nn
 
 class Conv(nn.Sequential):
     """
-    A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function.
+    A wrapper around a 3D convolutional layer that also contains batch normalization, a ReLU activation function and max pooling.
     """
 
-    def __init__(self, in_channels: int, out_channels: int):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+    ):
         """
-        Create a convolutional layer with batch normalization and a ReLU activation function.
+        Create a convolutional layer with batch normalization, a ReLU activation function and max pooling.
 
         Parameters
         ----------
         in_channels: int
-            number of channels receives as input
+            number of channels received as input
         out_channels: int
             number of filters and consequently channels in the output
         """
@@ -34,6 +38,7 @@ class Conv(nn.Sequential):
         )
         self.append(nn.BatchNorm3d(num_features=out_channels))
         self.append(nn.ReLU(inplace=True))
+        self.append(nn.MaxPool3d(kernel_size=2, stride=2))
 
 
 class Discriminator(nn.Module):
@@ -42,7 +47,13 @@ class Discriminator(nn.Module):
     It consists of three convolutional layers with max pooling, followed by three fully connected layers.
     """
 
-    def __init__(self, in_channels: int = 1, input_size: Union[int, Tuple[int]] = 16):
+    def __init__(
+        self,
+        in_channels: int,
+        input_size: Union[int, Tuple[int, int, int]],
+        conv_features: List[int] = [32, 64, 128],
+        fc_features: List[int] = [512, 128],
+    ):
         """
         Create the discriminator.
 
@@ -53,88 +64,206 @@ class Discriminator(nn.Module):
         input_size: int or tuple of int
             shape of the inputs either as an int (if equal in all dimensions) or as a tuple
             this is required to compute the number of features for the first fully connected layer
+        conv_features: list of int
+            each number represents the number of filters in a convolutional layer
+        fc_features: list of int
+            each number represents the number of features in a fully connected layer
+            note that there will always be one additional fully connected layer mapping the flattened convolution output to the first of these features
         """
         super().__init__()
 
-        # input is halved three time (// 2 ** 3) and we deal with 3D inputs (**3)
-        if type(input_size) is int:
-            fc_input_size = (input_size // 2**3) ** 3
-        elif type(input_size) is tuple:
-            fc_input_size = map(lambda x: x // 2**3, input_size)
-            fc_input_size = reduce(lambda x, y: x * y, fc_input_size)
-        else:
-            raise ValueError(
-                f"Cannot deal with input size {input_size} of type {type(input_size)}"
-            )
-        self.conv = nn.Sequential(
-            Conv(in_channels=in_channels, out_channels=32),
-            nn.MaxPool3d(kernel_size=2, stride=2),
-            Conv(in_channels=32, out_channels=64),
-            nn.MaxPool3d(kernel_size=2, stride=2),
-            Conv(in_channels=64, out_channels=128),
-            nn.MaxPool3d(kernel_size=2, stride=2),
-        )
-        self.fully_connected = nn.Sequential(
-            nn.Linear(in_features=128 * fc_input_size, out_features=512),
-            nn.ReLU(inplace=True),
-            nn.Linear(in_features=512, out_features=128),
-            nn.ReLU(inplace=True),
-            nn.Linear(in_features=128, out_features=1),
+        self.in_channels = in_channels
+        self.input_size = tuple(input_size) if isinstance(input_size, (tuple, list)) else (input_size,) * 3
+        self.conv_features = conv_features
+        self.fc_features = fc_features
+
+        conv_layers = []
+        for in_channels, out_channels in zip(
+            [in_channels, *conv_features[:-1]], conv_features
+        ):
+            conv_layers.append(Conv(in_channels, out_channels))
+        self.conv = nn.Sequential(*conv_layers)
+
+        # input is halved by each convolutional layer
+        self.fc_input_size = map(lambda x: x // 2 ** len(conv_features), self.input_size)
+        self.fc_input_size = reduce(lambda x, y: x * y, self.fc_input_size)
+
+        fc_layers = []
+        fc_layers.append(
+            nn.Linear(
+                in_features=self.fc_input_size * conv_features[-1],
+                out_features=fc_features[0],
+            )
         )
+        fc_layers.append(nn.ReLU(inplace=True))
+        for in_features, out_features in zip(fc_features[:-1], fc_features[1:]):
+            fc_layers.append(
+                nn.Linear(in_features=in_features, out_features=out_features)
+            )
+            fc_layers.append(nn.ReLU(inplace=True))
+        fc_layers.append(nn.Linear(in_features=fc_features[-1], out_features=1))
+        self.fc = nn.Sequential(*fc_layers)
 
-    def forward(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv(x)
         x = torch.flatten(x, 1)
-        x = self.fully_connected(x)
+        x = self.fc(x)
         return x
 
+    def get_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]:
+        """
+        Get the output shape of the discriminator with respect to an input shape.
+
+        Parameters
+        ----------
+        input_shape: Tuple[int, ...]
+            the shape for which the output is computed, which needs to have at least 4 dimensions
+
+        Returns
+        -------
+        Tuple[int, ...]
+            the corresponding output shape
+        """
+        assert len(input_shape) > 3, "Input must contain at least 4 dimensions"
+        return (*input_shape[:-4], 1)
 
-class PatchDiscriminator(nn.Module):
-    def __init__(self, in_channels: int = 2):
+
+class PatchDiscriminator(nn.Sequential):
+    """
+    Patch discriminator (PatchGAN) used by Isola et al. 2017 for their image-to-image
+    translation GAN.
+ """ + + def __init__(self, in_channels: int = 2, features: List[int] = [64, 128, 256, 512]): + """ + Create a new patch discriminator. + + Parameters + ---------- + in_channels: int + the number of channels of inputs for the model + features: list of int + the number of features in each convolutional layer of the model + note that each feature introduces a new layer that halves the input size + """ super().__init__() + self.features = features + self.in_channels = in_channels - self.conv = nn.Sequential( - nn.Conv3d( - in_channels=in_channels, - out_channels=64, - kernel_size=4, - stride=2, - padding=1, - ), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv3d( - in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1 - ), - nn.BatchNorm3d(num_features=128), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv3d( - in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1 - ), - nn.BatchNorm3d(num_features=256), - nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv3d( - in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1 - ), + self.append(self.create_conv(in_channels, features[0])) + self.append(nn.LeakyReLU(inplace=True, negative_slope=0.2)) + for in_channels, out_channels in zip(features[:-1], features[1:]): + self.append(self.create_conv(in_channels, out_channels)) + self.append(nn.BatchNorm3d(num_features=out_channels)) + self.append(nn.LeakyReLU(inplace=True, negative_slope=0.2)) + self.append(self.create_conv(features[-1], 1)) + + def create_conv(self, in_channels: int, out_channels: int) -> nn.Conv3d: + """ + Create a typical convolution used by the patch discriminator (kernel_size=4, stride=2, padding=1). + + Parameters + ---------- + in_channels: int + the number of input channels for the convolution + out_channels: int + the number of output channels (number of filters) of the convolution + + Returns + ------- + nn.Conv3d + """ + return nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=4, + stride=2, + padding=1, ) - def forward(self, x: torch.Tensor): - return self.conv(x) + def get_output_shape(self, input_shape: Tuple[int]) -> Tuple[int]: + """ + Get the output shape of the discriminator with respect to an input shape. 
+
+        Parameters
+        ----------
+        input_shape: Tuple[int, ...]
+            the shape for which the output is computed, which needs to have at least 4 dimensions
+
+        Returns
+        -------
+        Tuple[int, ...]
+            the corresponding output shape
+        """
+        assert len(input_shape) > 3, "Input must contain at least 4 dimensions"
+
+        n_convs = len(self.features) + 1  # each convolution halves the input dimensions
+        return (
+            *input_shape[:-4],
+            1,
+            *map(lambda x: x // (2**n_convs), input_shape[-3:]),
+        )
 
 
 if __name__ == "__main__":
-    batch_size = 4
-    input_size = (32, 64, 64)
-    in_channels = 2
+    import argparse
 
-    net = Discriminator(in_channels=in_channels, input_size=input_size)
-    print(net)
+    parser = argparse.ArgumentParser(
+        description="print and test different discriminators",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--type",
+        choices=["class", "patch"],
+        default="class",
+        help="test either a classification or a PatchGAN discriminator",
+    )
+    parser.add_argument(
+        "--conv_features",
+        type=int,
+        nargs="+",
+        default=[32, 64, 128],
+        help="the number of features for each convolutional layer",
+    )
+    parser.add_argument(
+        "--fc_features",
+        type=int,
+        nargs="+",
+        default=[512, 128],
+        help="the number of features for each fully connected layer",
+    )
+    parser.add_argument(
+        "--input_size",
+        type=int,
+        nargs=3,
+        default=[32, 32, 32],
+        help="shape of inputs for the discriminator",
+    )
+    parser.add_argument(
+        "--in_channels", type=int, default=2, help="number of channels for inputs"
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8,
+        help="batch size of inputs for a test computation",
+    )
+    args = parser.parse_args()
 
-    if type(input_size) is int:
-        _inputs = torch.rand(
-            (batch_size, in_channels, input_size, input_size, input_size)
+    if args.type == "class":
+        net = Discriminator(
+            in_channels=args.in_channels,
+            input_size=args.input_size,
+            conv_features=args.conv_features,
+            fc_features=args.fc_features,
         )
-    else:
-        _inputs = torch.rand((batch_size, in_channels, *input_size))
+    elif args.type == "patch":
+        net = PatchDiscriminator(
+            in_channels=args.in_channels, features=args.conv_features
+        )
+    print(net)
+
+    _inputs = torch.rand((args.batch_size, args.in_channels, *args.input_size))
 
     _outputs = net(_inputs)
     _targets = torch.full(_outputs.shape, 1.0)
@@ -143,3 +272,4 @@ if __name__ == "__main__":
     print(loss.item())
 
     print(f"Transform {_inputs.shape} to {_outputs.shape}")
+    assert _outputs.shape == net.get_output_shape(_inputs.shape)
--
GitLab
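
With the patch applied, the restructured classes can be sanity-checked with a minimal sketch like the one below; everything it uses (module path, class names, parameters) comes from the diff above, and the expected shapes follow from the stride-2 halvings in the code:

    import torch

    from mu_map.models.discriminator import Discriminator, PatchDiscriminator

    # classification discriminator: three conv blocks halve 32 -> 4 per spatial
    # dimension, so the first fully connected layer sees 128 * 4**3 = 8192 features
    disc = Discriminator(in_channels=1, input_size=(32, 32, 32))
    x = torch.rand(4, 1, 32, 32, 32)
    assert disc(x).shape == disc.get_output_shape(x.shape) == (4, 1)

    # PatchGAN discriminator: five stride-2 convolutions shrink 64 -> 2 per spatial
    # dimension, producing a grid of patch scores instead of a single scalar
    patch = PatchDiscriminator(in_channels=2, features=[64, 128, 256, 512])
    y = torch.rand(4, 2, 64, 64, 64)
    assert patch(y).shape == patch.get_output_shape(y.shape) == (4, 1, 2, 2, 2)

Because get_output_shape only needs a shape, training code can allocate GAN target tensors of the right size without a forward pass; the new assert in __main__ exercises exactly this contract.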