import torch
import torch.nn as nn


class Conv(nn.Sequential):
    """
    A wrapper around a 3D convolutional layer that also contains batch
    normalization and a ReLU activation function.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Create a convolutional layer with batch normalization and a ReLU
        activation function.

        :param in_channels: number of channels received as input
        :param out_channels: number of filters and consequently channels in the output
        """
        super().__init__()
        self.append(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding="same",
            )
        )
        self.append(nn.BatchNorm3d(num_features=out_channels))
        self.append(nn.ReLU(inplace=True))


class Discriminator(nn.Module):
    """
    The discriminator as specified by Shi et al. (2020). It consists of three
    convolutional layers with max pooling, followed by three fully connected
    layers.
    """

    def __init__(self, in_channels: int = 1, input_size: int = 16):
        """
        Create the discriminator.

        :param in_channels: number of channels received as input
        :param input_size: spatial size of the input images, required to compute
            the number of features in the first fully connected layer
        """
        super().__init__()

        # The input is halved three times (// 2 ** 3) and we deal with 3D inputs (** 3).
        fc_input_size = (input_size // 2 ** 3) ** 3

        self.conv = nn.Sequential(
            Conv(in_channels=in_channels, out_channels=32),
            nn.MaxPool3d(kernel_size=2, stride=2),
            Conv(in_channels=32, out_channels=64),
            nn.MaxPool3d(kernel_size=2, stride=2),
            Conv(in_channels=64, out_channels=128),
            nn.MaxPool3d(kernel_size=2, stride=2),
        )
        self.fully_connected = nn.Sequential(
            nn.Linear(in_features=128 * fc_input_size, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=1),
            # nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = torch.flatten(x, 1)
        x = self.fully_connected(x)
        return x


class PatchDiscriminator(nn.Module):
    """
    A 3D PatchGAN-style discriminator: four strided convolutions that map the
    input volume to a grid of patch-level real/fake logits.
    """

    def __init__(self, in_channels: int = 2, input_size: int = 32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_features=128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_features=256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv3d(in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x: torch.Tensor):
        return self.conv(x)


if __name__ == "__main__":
    input_size = 32
    # net = Discriminator(input_size=input_size)
    net = PatchDiscriminator(input_size=input_size)
    print(net)

    _inputs = torch.rand((4, 2, input_size, input_size, input_size))
    _outputs = net(_inputs)
    _targets = torch.full(_outputs.shape, 1.0)

    criterion = torch.nn.MSELoss()
    loss = criterion(_outputs, _targets)
    print(loss.item())
    print(f"Transform {_inputs.shape} to {_outputs.shape}")
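
    # Illustrative sketch (not part of the original script): exercising the
    # voxel-level Discriminator as well. Its final nn.Sigmoid() is commented
    # out, so it returns raw logits and BCEWithLogitsLoss is a matching
    # criterion for real/fake labels. The size and label values below are
    # assumptions chosen only to demonstrate the expected shapes.
    disc_input_size = 16
    disc = Discriminator(in_channels=1, input_size=disc_input_size)
    _disc_inputs = torch.rand((4, 1, disc_input_size, disc_input_size, disc_input_size))
    _disc_outputs = disc(_disc_inputs)  # shape (4, 1): one logit per volume
    _disc_targets = torch.full(_disc_outputs.shape, 1.0)  # "real" labels
    bce = torch.nn.BCEWithLogitsLoss()
    print(bce(_disc_outputs, _disc_targets).item())
    print(f"Transform {_disc_inputs.shape} to {_disc_outputs.shape}")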