From d60df9f1893ded1d52349edb03ce948f959abe69 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 16 Aug 2022 16:36:42 +0200
Subject: [PATCH] implement discrimnator

---
 mu_map/models/discriminator.py | 56 ++++++++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)
 create mode 100644 mu_map/models/discriminator.py

diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
new file mode 100644
index 0000000..cf4898d
--- /dev/null
+++ b/mu_map/models/discriminator.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn as nn
+
+
+class Conv(nn.Sequential):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+        self.append(
+            nn.Conv3d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=1,
+                padding="same",
+            )
+        )
+        self.append(nn.BatchNorm3d(num_features=out_channels))
+        self.append(nn.ReLU(inplace=True))
+
+
+class Discriminator(nn.Module):
+    def __init__(self, in_channels=1):
+        super().__init__()
+
+        self.conv = nn.Sequential(
+            Conv(in_channels=in_channels, out_channels=32),
+            nn.MaxPool3d(kernel_size=2, stride=2),
+            Conv(in_channels=32, out_channels=64),
+            nn.MaxPool3d(kernel_size=2, stride=2),
+            Conv(in_channels=64, out_channels=128),
+            nn.MaxPool3d(kernel_size=2, stride=2),
+        )
+        self.fully_connected = nn.Sequential(
+            nn.Linear(in_features=128 * 2 ** 3, out_features=512),
+            nn.ReLU(inplace=True),
+            nn.Linear(in_features=512, out_features=128),
+            nn.ReLU(inplace=True),
+            nn.Linear(in_features=128, out_features=1),
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = torch.flatten(x, 1)
+        x = self.fully_connected(x)
+        return x
+
+
+if __name__ == "__main__":
+    net = Discriminator()
+    print(net)
+
+    _inputs = torch.rand((1, 1, 16, 16, 16))
+    _outputs = net(_inputs)
+
+    print(f"Transform {_inputs.shape} to {_outputs.shape}")
-- 
GitLab