    import torch
    import torch.nn as nn
    
    
    class Conv(nn.Sequential):
    
        """
        A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function.
        """
    
    
    
        def __init__(self, in_channels: int, out_channels: int):
    
            """
            Create a convolutional layer with batch normalization and a ReLU activation function.
    
            :param in_channels: number of channels received as input
            :param out_channels: number of filters and consequently the number of channels in the output
            """
    
            super().__init__()
    
    
            self.append(
                nn.Conv3d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=1,
                    padding="same",
                )
            )
    
            self.append(nn.BatchNorm3d(num_features=out_channels))
            self.append(nn.ReLU(inplace=True))
    
    
    class Discriminator(nn.Module):
    
        """
        The discriminator as specified by Shi et al. (2020).
        It consists of three convolutional layers with max pooling, followed by three fully connected layers.
        """
    
    
    
        def __init__(self, in_channels: int = 1, input_size: int = 16):
    
            """
            Create the discriminator.
    
            :param in_channels: number of channels received as input
            :param input_size: edge length of the cubic input images, which is required to compute the number of features of the first fully connected layer
            """
    
            super().__init__()
    
            # the input is halved three times (// 2 ** 3) and the inputs are 3D (** 3)
            fc_input_size = (input_size // 2 ** 3) ** 3
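
            # Worked example, assuming the default input_size of 16: three pooling steps
            # reduce 16 voxels per axis to 16 // 2 ** 3 = 2, so the final feature map has
            # 2 ** 3 = 8 spatial positions; with its 128 channels this gives
            # 128 * 8 = 1024 input features for the first fully connected layer.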
    
    
            self.conv = nn.Sequential(
                Conv(in_channels=in_channels, out_channels=32),
                nn.MaxPool3d(kernel_size=2, stride=2),
                Conv(in_channels=32, out_channels=64),
                nn.MaxPool3d(kernel_size=2, stride=2),
                Conv(in_channels=64, out_channels=128),
                nn.MaxPool3d(kernel_size=2, stride=2),
            )
            self.fully_connected = nn.Sequential(
    
                nn.Linear(in_features=128 * fc_input_size, out_features=512),
    
                nn.ReLU(inplace=True),
                nn.Linear(in_features=512, out_features=128),
                nn.ReLU(inplace=True),
                nn.Linear(in_features=128, out_features=1),
    
            )
    
        def forward(self, x: torch.Tensor):
    
            x = self.conv(x)
            x = torch.flatten(x, 1)
            x = self.fully_connected(x)
            return x
    
    
    
    class PatchDiscriminator(nn.Module):
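
        """
        A patch-based discriminator made of four strided 3D convolutions.
        It maps the input to a single-channel grid of scores, each of which rates a local patch of the input
        instead of judging the whole volume at once; for a 32**3 input the grid has a spatial shape of 2**3.
        """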
    
        def __init__(self, in_channels: int = 2, input_size: int = 32):
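
            """
            Create the patch discriminator.

            :param in_channels: number of channels received as input
            :param input_size: edge length of the cubic input images; currently not used by this class
            """
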
            super().__init__()
    
            self.conv = nn.Sequential(
                nn.Conv3d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv3d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm3d(num_features=128),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv3d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm3d(num_features=256),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv3d(in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1),
            )
    
        def forward(self, x: torch.Tensor):
            return self.conv(x)
    
    
    if __name__ == "__main__":
    
        input_size = 32  # matches the default input_size of PatchDiscriminator

        # net = Discriminator(input_size=input_size)
        net = PatchDiscriminator(input_size=input_size)
    
        print(net)
    
    
        _inputs = torch.rand((4, 2, input_size, input_size, input_size))
    
        _outputs = net(_inputs)
    
    
        _targets = torch.full(_outputs.shape, 1.0)
    
        criterion = torch.nn.MSELoss()
        loss = criterion(_outputs, _targets)
        print(loss.item())
    
    
        print(f"Transform {_inputs.shape} to {_outputs.shape}")