import torch
import torch.nn as nn


class Conv(nn.Sequential):
    """3-D convolution followed by batch normalisation and an in-place ReLU.

    The convolution uses a 3x3x3 kernel with stride 1 and ``"same"`` padding,
    so the spatial extent of the input is preserved and only the number of
    channels changes.

    Args:
        in_channels: number of channels of the incoming volume.
        out_channels: number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        stages = (
            nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding="same",
            ),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True),
        )
        # nn.Sequential registers positional modules as children "0", "1", "2",
        # the same names repeated .append() calls would assign.
        super().__init__(*stages)


class Discriminator(nn.Module):
    """3-D convolutional network mapping a volume to one unnormalised score.

    Three ``Conv`` stages (32 -> 64 -> 128 channels) each followed by a 2x
    ``MaxPool3d`` shrink the volume; the result is flattened and fed through a
    three-layer MLP (-> 512 -> 128 -> 1). No activation is applied to the
    final layer, so the output is a raw score (no sigmoid).

    Args:
        in_channels: number of channels of the input volume.
        input_size: spatial size of the input volume, either a single int for
            a cube or a ``(depth, height, width)`` triple. The first linear
            layer is sized from it. Defaults to 16, which gives the original
            ``128 * 2 ** 3`` flattened features.

    Raises:
        ValueError: if ``input_size`` is not an int or a 3-element sequence, or
            if any dimension is smaller than 8 (it would pool down to zero).
    """

    def __init__(self, in_channels=1, input_size=16):
        super().__init__()

        self.conv = nn.Sequential(
            Conv(in_channels=in_channels, out_channels=32),
            nn.MaxPool3d(kernel_size=2, stride=2),
            Conv(in_channels=32, out_channels=64),
            nn.MaxPool3d(kernel_size=2, stride=2),
            Conv(in_channels=64, out_channels=128),
            nn.MaxPool3d(kernel_size=2, stride=2),
        )
        # 128 channels leave the last Conv stage and the volume is pooled 3 times.
        flat_features = self._flat_features(input_size, channels=128, num_pools=3)
        self.fully_connected = nn.Sequential(
            nn.Linear(in_features=flat_features, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=1),
        )

    @staticmethod
    def _flat_features(input_size, channels, num_pools):
        """Return the number of features produced by flattening the conv stack.

        ``padding="same"`` convolutions keep the spatial size, and each
        ``MaxPool3d(kernel_size=2, stride=2)`` floors it by 2, so after
        ``num_pools`` pools every dimension ``d`` becomes ``d // 2**num_pools``.
        """
        if isinstance(input_size, int):
            sizes = (input_size,) * 3
        else:
            sizes = tuple(input_size)
        if len(sizes) != 3:
            raise ValueError(
                f"input_size must be an int or a (depth, height, width) triple, "
                f"got {input_size!r}"
            )

        factor = 2 ** num_pools
        count = channels
        for dim in sizes:
            reduced = int(dim) // factor
            if reduced < 1:
                raise ValueError(
                    f"every spatial dimension must be at least {factor}, "
                    f"got input_size={input_size!r}"
                )
            count *= reduced
        return count

    def forward(self, x):
        """Score a batch of volumes.

        Args:
            x: tensor of shape ``(batch, in_channels, D, H, W)`` whose spatial
                size matches the ``input_size`` given at construction.

        Returns:
            Tensor of shape ``(batch, 1)`` with one raw score per sample.
        """
        x = self.conv(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, merge the rest
        x = self.fully_connected(x)
        return x


if __name__ == "__main__":
    # Smoke test: build the default model, show its layout and push one
    # random 16x16x16 single-channel volume through it.
    model = Discriminator()
    print(model)

    sample = torch.rand((1, 1, 16, 16, 16))
    scores = model(sample)

    print(f"Transform {sample.shape} to {scores.shape}")