diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index c57c94c7633371428ac6de6bbdc6b645213e97d1..5f1a6bc8e8b6e66124b0355fe269a89eb0165234 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -7,7 +7,7 @@ class Conv(nn.Sequential):
     A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function.
     """
 
-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_channels: int, out_channels: int):
         """
         Create a convolutional layer with batch normalization and a ReLU activation function.
 
@@ -35,7 +35,7 @@ class Discriminator(nn.Module):
     If consists of three convolutional layers with max pooling, followed by three fully connected layers.
     """
 
-    def __init__(self, in_channels=1, input_size=16):
+    def __init__(self, in_channels: int = 1, input_size: int = 16):
         """
         Create the discriminator.
 
@@ -62,7 +62,7 @@ class Discriminator(nn.Module):
             nn.Linear(in_features=128, out_features=1),
         )
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         x = self.conv(x)
         x = torch.flatten(x, 1)
         x = self.fully_connected(x)