From 06a38be4a276dcd874d74ba6036b3bfaa8f8f8a2 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Wed, 5 Oct 2022 15:56:10 +0200
Subject: [PATCH] update discriminator to use a sigmnoid in the end

---
 mu_map/models/discriminator.py | 8 +++++++-
 1 file changed, 7 insertions(+), 1 deletion(-)

diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index 5f1a6bc..05de5ca 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -60,6 +60,7 @@ class Discriminator(nn.Module):
             nn.Linear(in_features=512, out_features=128),
             nn.ReLU(inplace=True),
             nn.Linear(in_features=128, out_features=1),
+            nn.Sigmoid(),
         )
 
     def forward(self, x: torch.Tensor):
@@ -75,7 +76,12 @@ if __name__ == "__main__":
     net = Discriminator(input_size=input_size)
     print(net)
 
-    _inputs = torch.rand((1, 1, input_size, input_size, input_size))
+    _inputs = torch.rand((4, 1, input_size, input_size, input_size))
     _outputs = net(_inputs)
 
+    _targets = torch.full((4, 1), 1.0)
+    criterion = torch.nn.MSELoss()
+    loss = criterion(_outputs, _targets)
+    print(loss.item())
+
     print(f"Transform {_inputs.shape} to {_outputs.shape}")
-- 
GitLab