import torch
import torch.nn as nn

class Conv(nn.Sequential):
    """Basic 3D convolution block: ``Conv3d`` -> ``BatchNorm3d`` -> ``ReLU``.

    The convolution uses a 3x3x3 kernel with stride 1 and ``"same"`` padding,
    so only the channel dimension of the input changes.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the convolution (and
            normalised by the batch-norm layer).
    """

    def __init__(self, in_channels, out_channels):
        convolution = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding="same",
        )
        normalisation = nn.BatchNorm3d(num_features=out_channels)
        activation = nn.ReLU(inplace=True)

        # Positional modules are registered under the keys "0", "1", "2",
        # in the order given.
        super().__init__(convolution, normalisation, activation)

class Discriminator(nn.Module):
    """3D convolutional discriminator producing one raw score per sample.

    The feature extractor stacks three ``Conv`` blocks (32, 64 and 128 output
    channels), each followed by a max-pooling step that halves every spatial
    dimension.  The flattened feature map is fed to a fully connected head
    (512 -> 128 -> 1).  No activation is applied to the final output.

    Args:
        in_channels: Number of channels of the input volume.
        input_shape: Spatial size of the input volume as three integers
            (the three trailing dimensions of the tensor passed to
            ``forward``).  It determines the number of input features of the
            first fully connected layer.  The default ``(16, 16, 16)`` yields
            ``128 * 2 ** 3`` features.

    Raises:
        ValueError: If ``input_shape`` does not contain exactly three entries,
            or if any entry is smaller than 8 and would therefore be reduced
            to zero by the three pooling steps.
    """

    def __init__(self, in_channels=1, input_shape=(16, 16, 16)):
        super().__init__()

        input_shape = tuple(int(size) for size in input_shape)
        if len(input_shape) != 3:
            raise ValueError(
                f"input_shape must have exactly 3 spatial dimensions, got {input_shape}"
            )

        # Each MaxPool3d(kernel_size=2, stride=2) floors the size to half, so
        # three of them reduce a dimension of size d to d // 8.
        num_pools = 3
        pooled_shape = tuple(size // 2 ** num_pools for size in input_shape)
        if min(pooled_shape) < 1:
            raise ValueError(
                f"every entry of input_shape must be at least {2 ** num_pools}, "
                f"got {input_shape}"
            )

        last_conv_channels = 128
        flat_features = last_conv_channels
        for size in pooled_shape:
            flat_features *= size

        # Kept for introspection; forward() does not read it.
        self.input_shape = input_shape

        self.conv = nn.Sequential(
            Conv(in_channels=in_channels, out_channels=32),
            nn.MaxPool3d(kernel_size=2, stride=2),
            Conv(in_channels=32, out_channels=64),
            nn.MaxPool3d(kernel_size=2, stride=2),
            Conv(in_channels=64, out_channels=last_conv_channels),
            nn.MaxPool3d(kernel_size=2, stride=2),
        )
        self.fully_connected = nn.Sequential(
            nn.Linear(in_features=flat_features, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=1),
        )

    def forward(self, x):
        """Score a batch of volumes.

        Args:
            x: Tensor of shape ``(batch, in_channels, *input_shape)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the output of the last
            linear layer.
        """
        x = self.conv(x)
        # Keep the batch dimension and flatten channels plus spatial dims.
        x = torch.flatten(x, 1)
        x = self.fully_connected(x)
        return x


if __name__ == "__main__":
    # Smoke test: build the default discriminator, show its layout, and push
    # one random single-channel 16x16x16 volume through it.
    discriminator = Discriminator()
    print(discriminator)

    _inputs = torch.rand(1, 1, 16, 16, 16)
    _outputs = discriminator(_inputs)

    print(f"Transform {_inputs.shape} to {_outputs.shape}")