From df35525d122eeb7c0a53154ce252edc2c819a2b1 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Wed, 17 Aug 2022 16:23:47 +0200
Subject: [PATCH] add type annotations to discriminator

---
 mu_map/models/discriminator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index c57c94c..5f1a6bc 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -7,7 +7,7 @@ class Conv(nn.Sequential):
     A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function.
     """
 
-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_channels: int, out_channels: int):
         """
         Create a convolutional layer with batch normalization and a ReLU activation function.
 
@@ -35,7 +35,7 @@ class Discriminator(nn.Module):
     If consists of three convolutional layers with max pooling, followed by three fully connected layers.
     """
 
-    def __init__(self, in_channels=1, input_size=16):
+    def __init__(self, in_channels: int = 1, input_size: int = 16):
         """
         Create the discriminator.
 
@@ -62,7 +62,7 @@ class Discriminator(nn.Module):
             nn.Linear(in_features=128, out_features=1),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         x = self.conv(x)
         x = torch.flatten(x, 1)
         x = self.fully_connected(x)
-- 
GitLab