From 7fd7c7b2c81c1f814dac7c44bdd71873101e1e22 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Fri, 6 Jan 2023 11:57:50 +0100 Subject: [PATCH] fix bug in initialization of discriminator --- mu_map/models/discriminator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py index dd90d35..e1fe365 100644 --- a/mu_map/models/discriminator.py +++ b/mu_map/models/discriminator.py @@ -85,7 +85,7 @@ class Discriminator(nn.Module): self.conv = nn.Sequential(*conv_layers) # input is halved by each convolutional layer - self.fc_input_size = map(lambda x: x // 2 ** (len(conv_features)), input_size) + self.fc_input_size = map(lambda x: x // 2 ** (len(conv_features)), self.input_size) self.fc_input_size = reduce(lambda x, y: x * y, self.fc_input_size) fc_layers = [] @@ -249,6 +249,7 @@ if __name__ == "__main__": help="batch size of inputs for a test computation", ) args = parser.parse_args() + args.input_size = tuple(args.input_size) if args.type == "class": net = Discriminator( -- GitLab