diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index a7e77eb6e82d4bbedc6e23b4422d403cde029c0f..0409463f09d862d07ed24f734954b8de1e44c744 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -57,12 +57,14 @@ class Discriminator(nn.Module):
         super().__init__()
         # input is halved three time (// 2 ** 3) and we deal with 3D inputs (**3)
         if type(input_size) is int:
-            fc_input_size = (input_size // 2 ** 3) ** 3
+            fc_input_size = (input_size // 2**3) ** 3
         elif type(input_size) is tuple:
-            fc_input_size = map(lambda x: x // 2 ** 3, input_size)
+            fc_input_size = map(lambda x: x // 2**3, input_size)
             fc_input_size = reduce(lambda x, y: x * y, fc_input_size)
         else:
-            raise ValueError(f"Cannot deal with input size {input_size} of type {type(input_size)}")
+            raise ValueError(
+                f"Cannot deal with input size {input_size} of type {type(input_size)}"
+            )
 
         self.conv = nn.Sequential(
             Conv(in_channels=in_channels, out_channels=32),
@@ -88,25 +90,37 @@ class Discriminator(nn.Module):
 
 
 class PatchDiscriminator(nn.Module):
-
     def __init__(self, in_channels: int = 2):
         super().__init__()
 
         self.conv = nn.Sequential(
-            nn.Conv3d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1),
+            nn.Conv3d(
+                in_channels=in_channels,
+                out_channels=64,
+                kernel_size=4,
+                stride=2,
+                padding=1,
+            ),
             nn.LeakyReLU(negative_slope=0.2, inplace=True),
-            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),
+            nn.Conv3d(
+                in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1
+            ),
             nn.BatchNorm3d(num_features=128),
             nn.LeakyReLU(negative_slope=0.2, inplace=True),
-            nn.Conv3d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1),
+            nn.Conv3d(
+                in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1
+            ),
             nn.BatchNorm3d(num_features=256),
             nn.LeakyReLU(negative_slope=0.2, inplace=True),
-            nn.Conv3d(in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1),
+            nn.Conv3d(
+                in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1
+            ),
         )
 
     def forward(self, x: torch.Tensor):
         return self.conv(x)
 
+
 if __name__ == "__main__":
     batch_size = 4
     input_size = (32, 64, 64)
@@ -116,7 +130,9 @@ if __name__ == "__main__":
     print(net)
 
     if type(input_size) is int:
-        _inputs = torch.rand((batch_size, in_channels, input_size, input_size, input_size))
+        _inputs = torch.rand(
+            (batch_size, in_channels, input_size, input_size, input_size)
+        )
     else:
         _inputs = torch.rand((batch_size, in_channels, *input_size))
     _outputs = net(_inputs)