diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4898dde7d64085f47ad1ced6616692dd5ab9e2 --- /dev/null +++ b/mu_map/models/discriminator.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn + + +class Conv(nn.Sequential): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.append( + nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding="same", + ) + ) + self.append(nn.BatchNorm3d(num_features=out_channels)) + self.append(nn.ReLU(inplace=True)) + + +class Discriminator(nn.Module): + def __init__(self, in_channels=1): + super().__init__() + + self.conv = nn.Sequential( + Conv(in_channels=in_channels, out_channels=32), + nn.MaxPool3d(kernel_size=2, stride=2), + Conv(in_channels=32, out_channels=64), + nn.MaxPool3d(kernel_size=2, stride=2), + Conv(in_channels=64, out_channels=128), + nn.MaxPool3d(kernel_size=2, stride=2), + ) + self.fully_connected = nn.Sequential( + nn.Linear(in_features=128 * 2 ** 3, out_features=512), + nn.ReLU(inplace=True), + nn.Linear(in_features=512, out_features=128), + nn.ReLU(inplace=True), + nn.Linear(in_features=128, out_features=1), + ) + + def forward(self, x): + x = self.conv(x) + x = torch.flatten(x, 1) + x = self.fully_connected(x) + return x + + +if __name__ == "__main__": + net = Discriminator() + print(net) + + _inputs = torch.rand((1, 1, 16, 16, 16)) + _outputs = net(_inputs) + + print(f"Transform {_inputs.shape} to {_outputs.shape}")