diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py index c57c94c7633371428ac6de6bbdc6b645213e97d1..5f1a6bc8e8b6e66124b0355fe269a89eb0165234 100644 --- a/mu_map/models/discriminator.py +++ b/mu_map/models/discriminator.py @@ -7,7 +7,7 @@ class Conv(nn.Sequential): A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function. """ - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels: int, out_channels: int): """ Create a convolutional layer with batch normalization and a ReLU activation function. @@ -35,7 +35,7 @@ class Discriminator(nn.Module): If consists of three convolutional layers with max pooling, followed by three fully connected layers. """ - def __init__(self, in_channels=1, input_size=16): + def __init__(self, in_channels: int = 1, input_size: int = 16): """ Create the discriminator. @@ -62,7 +62,7 @@ class Discriminator(nn.Module): nn.Linear(in_features=128, out_features=1), ) - def forward(self, x): + def forward(self, x: torch.Tensor): x = self.conv(x) x = torch.flatten(x, 1) x = self.fully_connected(x)