From 9b8cb70956460233437705623f4e5e77d406d684 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 16 Aug 2022 16:37:17 +0200
Subject: [PATCH] Add TODO notes to the discriminator model

---
 mu_map/models/discriminator.py | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py
index cf4898d..68dbafa 100644
--- a/mu_map/models/discriminator.py
+++ b/mu_map/models/discriminator.py
@@ -1,27 +1,21 @@
 import torch
 import torch.nn as nn
 
-
 class Conv(nn.Sequential):
+
     def __init__(self, in_channels, out_channels):
         super().__init__()
 
-        self.append(
-            nn.Conv3d(
-                in_channels=in_channels,
-                out_channels=out_channels,
-                kernel_size=3,
-                stride=1,
-                padding="same",
-            )
-        )
+        self.append(nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same"))
         self.append(nn.BatchNorm3d(num_features=out_channels))
         self.append(nn.ReLU(inplace=True))
 
-
 class Discriminator(nn.Module):
+
     def __init__(self, in_channels=1):
         super().__init__()
+        # TODO: make fully connected layer dependent on input shape
+        # TODO: write doc
 
         self.conv = nn.Sequential(
             Conv(in_channels=in_channels, out_channels=32),
@@ -38,7 +32,7 @@ class Discriminator(nn.Module):
             nn.ReLU(inplace=True),
             nn.Linear(in_features=128, out_features=1),
         )
-
+
     def forward(self, x):
         x = self.conv(x)
         x = torch.flatten(x, 1)
-- 
GitLab