From d60df9f1893ded1d52349edb03ce948f959abe69 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 16 Aug 2022 16:36:42 +0200 Subject: [PATCH] implement discrimnator --- mu_map/models/discriminator.py | 56 ++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 mu_map/models/discriminator.py diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py new file mode 100644 index 0000000..cf4898d --- /dev/null +++ b/mu_map/models/discriminator.py @@ -0,0 +1,56 @@ +import torch +import torch.nn as nn + + +class Conv(nn.Sequential): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.append( + nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding="same", + ) + ) + self.append(nn.BatchNorm3d(num_features=out_channels)) + self.append(nn.ReLU(inplace=True)) + + +class Discriminator(nn.Module): + def __init__(self, in_channels=1): + super().__init__() + + self.conv = nn.Sequential( + Conv(in_channels=in_channels, out_channels=32), + nn.MaxPool3d(kernel_size=2, stride=2), + Conv(in_channels=32, out_channels=64), + nn.MaxPool3d(kernel_size=2, stride=2), + Conv(in_channels=64, out_channels=128), + nn.MaxPool3d(kernel_size=2, stride=2), + ) + self.fully_connected = nn.Sequential( + nn.Linear(in_features=128 * 2 ** 3, out_features=512), + nn.ReLU(inplace=True), + nn.Linear(in_features=512, out_features=128), + nn.ReLU(inplace=True), + nn.Linear(in_features=128, out_features=1), + ) + + def forward(self, x): + x = self.conv(x) + x = torch.flatten(x, 1) + x = self.fully_connected(x) + return x + + +if __name__ == "__main__": + net = Discriminator() + print(net) + + _inputs = torch.rand((1, 1, 16, 16, 16)) + _outputs = net(_inputs) + + print(f"Transform {_inputs.shape} to {_outputs.shape}") -- GitLab