Skip to content
Snippets Groups Projects
Commit df35525d authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

add type annotations to discriminator

parent d7e441c1
No related branches found
No related tags found
No related merge requests found
......@@ -7,7 +7,7 @@ class Conv(nn.Sequential):
A wrapper around a 3D convolutional layer that also contains batch normalization and a ReLU activation function.
"""
def __init__(self, in_channels, out_channels):
def __init__(self, in_channels: int, out_channels: int):
"""
Create a convolutional layer with batch normalization and a ReLU activation function.
......@@ -35,7 +35,7 @@ class Discriminator(nn.Module):
If consists of three convolutional layers with max pooling, followed by three fully connected layers.
"""
def __init__(self, in_channels=1, input_size=16):
def __init__(self, in_channels: int = 1, input_size: int = 16):
"""
Create the discriminator.
......@@ -62,7 +62,7 @@ class Discriminator(nn.Module):
nn.Linear(in_features=128, out_features=1),
)
def forward(self, x):
def forward(self, x: torch.Tensor):
x = self.conv(x)
x = torch.flatten(x, 1)
x = self.fully_connected(x)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment