From df35525d122eeb7c0a53154ce252edc2c819a2b1 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Wed, 17 Aug 2022 16:23:47 +0200 Subject: [PATCH] add type annotations to discriminator --- mu_map/models/discriminator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py index c57c94c..5f1a6bc 100644 --- a/mu_map/models/discriminator.py +++ b/mu_map/models/discriminator.py @@ -7,7 +7,7 @@ class Conv(nn.Sequential): A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function. """ - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels: int, out_channels: int): """ Create a convolutional layer with batch normalization and a ReLU activation function. @@ -35,7 +35,7 @@ class Discriminator(nn.Module): If consists of three convolutional layers with max pooling, followed by three fully connected layers. """ - def __init__(self, in_channels=1, input_size=16): + def __init__(self, in_channels: int = 1, input_size: int = 16): """ Create the discriminator. @@ -62,7 +62,7 @@ class Discriminator(nn.Module): nn.Linear(in_features=128, out_features=1), ) - def forward(self, x): + def forward(self, x: torch.Tensor): x = self.conv(x) x = torch.flatten(x, 1) x = self.fully_connected(x) -- GitLab