diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf4898dde7d64085f47ad1ced6616692dd5ab9e2
--- /dev/null
+++ b/mu_map/models/discriminator.py
@@ -0,0 +1,68 @@
+import torch
+import torch.nn as nn
+
+
+class Conv(nn.Sequential):
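+    """A 3x3x3 convolution (stride 1, same padding) followed by batch
+    normalization and an in-place ReLU activation.
+    """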
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+
+        self.append(
+            nn.Conv3d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=1,
+                padding="same",
+            )
+        )
+        self.append(nn.BatchNorm3d(num_features=out_channels))
+        self.append(nn.ReLU(inplace=True))
+
+
+class Discriminator(nn.Module):
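+    """CNN discriminator for 3D volumes producing a single scalar output per input.
+
+    Note: the first linear layer expects 128 * 2**3 flattened features, so
+    inputs are assumed to have shape (N, in_channels, 16, 16, 16).
+    """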
+    def __init__(self, in_channels=1):
+        super().__init__()
+
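+        # feature extractor: three Conv blocks, each followed by 2x2x2 max pooling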
+        self.conv = nn.Sequential(
+            Conv(in_channels=in_channels, out_channels=32),
+            nn.MaxPool3d(kernel_size=2, stride=2),
+            Conv(in_channels=32, out_channels=64),
+            nn.MaxPool3d(kernel_size=2, stride=2),
+            Conv(in_channels=64, out_channels=128),
+            nn.MaxPool3d(kernel_size=2, stride=2),
+        )
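+        # classifier head on the flattened features (128 channels * 2^3 voxels)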
+        self.fully_connected = nn.Sequential(
+            nn.Linear(in_features=128 * 2 ** 3, out_features=512),
+            nn.ReLU(inplace=True),
+            nn.Linear(in_features=512, out_features=128),
+            nn.ReLU(inplace=True),
+            nn.Linear(in_features=128, out_features=1),
+        )
+
+    def forward(self, x):
+        x = self.conv(x)
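+        # flatten all dimensions except the batch dimension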
+        x = torch.flatten(x, 1)
+        x = self.fully_connected(x)
+        return x
+
+
+if __name__ == "__main__":
+    net = Discriminator()
+    print(net)
+
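+    # quick sanity check: a 16^3 input volume should yield a single scalar output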
+    _inputs = torch.rand((1, 1, 16, 16, 16))
+    _outputs = net(_inputs)
+
+    print(f"Transform {_inputs.shape} to {_outputs.shape}")