From 7fd7c7b2c81c1f814dac7c44bdd71873101e1e22 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 6 Jan 2023 11:57:50 +0100
Subject: [PATCH] fix bug in initialization of discriminator

---
 mu_map/models/discriminator.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index dd90d35..e1fe365 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -85,7 +85,7 @@ class Discriminator(nn.Module):
         self.conv = nn.Sequential(*conv_layers)
 
         # input is halved by each convolutional layer
-        self.fc_input_size = map(lambda x: x // 2 ** (len(conv_features)), input_size)
+        self.fc_input_size = map(lambda x: x // 2 ** (len(conv_features)), self.input_size)
         self.fc_input_size = reduce(lambda x, y: x * y, self.fc_input_size)
 
         fc_layers = []
@@ -249,6 +249,7 @@ if __name__ == "__main__":
         help="batch size of inputs for a test computation",
     )
     args = parser.parse_args()
+    args.input_size = tuple(args.input_size)
 
     if args.type == "class":
         net = Discriminator(
-- 
GitLab