import torch
import torch.nn as nn


class Conv(nn.Sequential):
    """
    A 3D convolutional layer followed by batch normalization and a ReLU activation.
    """

    def __init__(self, in_channels: int, out_channels: int):
        """
        Create a convolutional layer with batch normalization and a ReLU activation function.

        :param in_channels: number of channels received as input
        :param out_channels: number of filters, and hence the number of channels in the output
        """
        super().__init__()

        self.append(
            nn.Conv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding="same",
            )
        )
        self.append(nn.BatchNorm3d(num_features=out_channels))
        self.append(nn.ReLU(inplace=True))


class Discriminator(nn.Module):
    """
    Create the discriminator as specified by Shi et al. (2020).
    It consists of three convolutional layers with max pooling, followed by three fully connected layers.
    """

    def __init__(self, in_channels: int = 1, input_size: int = 16):
        """
        Create the discriminator.

        :param in_channels: number of channels received as input
        :param input_size: spatial size of the input volumes, required to compute the number of features in the first fully connected layer
        """
        super().__init__()
        # the input is halved three times (// 2 ** 3) and the inputs are 3D (** 3)
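        # e.g. for the default input_size of 16: (16 // 8) ** 3 = 2 ** 3 = 8 spatial positions,
        # so the flattened convolutional output has 128 * 8 = 1024 features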
        fc_input_size = (input_size // 2 ** 3) ** 3

        self.conv = nn.Sequential(
            Conv(in_channels=in_channels, out_channels=32),
            nn.MaxPool3d(kernel_size=2, stride=2),
            Conv(in_channels=32, out_channels=64),
            nn.MaxPool3d(kernel_size=2, stride=2),
            Conv(in_channels=64, out_channels=128),
            nn.MaxPool3d(kernel_size=2, stride=2),
        )
        self.fully_connected = nn.Sequential(
            nn.Linear(in_features=128 * fc_input_size, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=1),
            # nn.Sigmoid(),  # left out, so the discriminator returns raw logits
        )

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = torch.flatten(x, 1)  # flatten everything except the batch dimension
        x = self.fully_connected(x)
        return x


class PatchDiscriminator(nn.Module):
    """
    A PatchGAN-style 3D discriminator: four strided convolutions map each input volume
    to a grid of patch logits (a 32^3 input yields a 1 x 2 x 2 x 2 output per sample).
    """

    def __init__(self, in_channels: int = 2, input_size: int = 32):
        """
        Create the patch discriminator.

        :param in_channels: number of channels received as input
        :param input_size: not used here; accepted for signature compatibility with Discriminator
        """
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_features=128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_features=256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv3d(in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x: torch.Tensor):
        return self.conv(x)


if __name__ == "__main__":
    input_size = 32

    # net = Discriminator(input_size=input_size)
    net = PatchDiscriminator(input_size=input_size)
    print(net)

    _inputs = torch.rand((4, 2, input_size, input_size, input_size))
    _outputs = net(_inputs)

    _targets = torch.full(_outputs.shape, 1.0)
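    # MSE against an all-ones target is the least-squares GAN loss for "real" labels,
    # presumably the intended objective given that the network outputs raw scores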
    criterion = torch.nn.MSELoss()
    loss = criterion(_outputs, _targets)
    print(loss.item())

    print(f"Transform {_inputs.shape} to {_outputs.shape}")