From 19046346765d9f38b37f766fdfdabcac77e30af1 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Thu, 5 Jan 2023 10:35:13 +0100 Subject: [PATCH] formatting --- mu_map/models/discriminator.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py index a7e77eb..0409463 100644 --- a/mu_map/models/discriminator.py +++ b/mu_map/models/discriminator.py @@ -57,12 +57,14 @@ class Discriminator(nn.Module): super().__init__() # input is halved three time (// 2 ** 3) and we deal with 3D inputs (**3) if type(input_size) is int: - fc_input_size = (input_size // 2 ** 3) ** 3 + fc_input_size = (input_size // 2**3) ** 3 elif type(input_size) is tuple: - fc_input_size = map(lambda x: x // 2 ** 3, input_size) + fc_input_size = map(lambda x: x // 2**3, input_size) fc_input_size = reduce(lambda x, y: x * y, fc_input_size) else: - raise ValueError(f"Cannot deal with input size {input_size} of type {type(input_size)}") + raise ValueError( + f"Cannot deal with input size {input_size} of type {type(input_size)}" + ) self.conv = nn.Sequential( Conv(in_channels=in_channels, out_channels=32), @@ -88,25 +90,37 @@ class Discriminator(nn.Module): class PatchDiscriminator(nn.Module): - def __init__(self, in_channels: int = 2): super().__init__() self.conv = nn.Sequential( - nn.Conv3d(in_channels=in_channels, out_channels=64, kernel_size=4, stride=2, padding=1), + nn.Conv3d( + in_channels=in_channels, + out_channels=64, + kernel_size=4, + stride=2, + padding=1, + ), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv3d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), + nn.Conv3d( + in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1 + ), nn.BatchNorm3d(num_features=128), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv3d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), + nn.Conv3d( + in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1 + ), nn.BatchNorm3d(num_features=256), nn.LeakyReLU(negative_slope=0.2, inplace=True), - nn.Conv3d(in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1), + nn.Conv3d( + in_channels=256, out_channels=1, kernel_size=4, stride=2, padding=1 + ), ) def forward(self, x: torch.Tensor): return self.conv(x) + if __name__ == "__main__": batch_size = 4 input_size = (32, 64, 64) @@ -116,7 +130,9 @@ if __name__ == "__main__": print(net) if type(input_size) is int: - _inputs = torch.rand((batch_size, in_channels, input_size, input_size, input_size)) + _inputs = torch.rand( + (batch_size, in_channels, input_size, input_size, input_size) + ) else: _inputs = torch.rand((batch_size, in_channels, *input_size)) _outputs = net(_inputs) -- GitLab