From 9b8cb70956460233437705623f4e5e77d406d684 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 16 Aug 2022 16:37:17 +0200 Subject: [PATCH] add todos --- mu_map/models/discriminator.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py index cf4898d..68dbafa 100644 --- a/mu_map/models/discriminator.py +++ b/mu_map/models/discriminator.py @@ -1,27 +1,21 @@ import torch import torch.nn as nn - class Conv(nn.Sequential): + def __init__(self, in_channels, out_channels): super().__init__() - self.append( - nn.Conv3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - stride=1, - padding="same", - ) - ) + self.append(nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, padding="same")) self.append(nn.BatchNorm3d(num_features=out_channels)) self.append(nn.ReLU(inplace=True)) - class Discriminator(nn.Module): + def __init__(self, in_channels=1): super().__init__() + #TODO: make fully connected layer dependent on input shape + #TODO: write doc self.conv = nn.Sequential( Conv(in_channels=in_channels, out_channels=32), @@ -38,7 +32,7 @@ class Discriminator(nn.Module): nn.ReLU(inplace=True), nn.Linear(in_features=128, out_features=1), ) - + def forward(self, x): x = self.conv(x) x = torch.flatten(x, 1) -- GitLab