From 06a38be4a276dcd874d74ba6036b3bfaa8f8f8a2 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Wed, 5 Oct 2022 15:56:10 +0200 Subject: [PATCH] update discriminator to use a sigmnoid in the end --- mu_map/models/discriminator.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py index 5f1a6bc..05de5ca 100644 --- a/mu_map/models/discriminator.py +++ b/mu_map/models/discriminator.py @@ -60,6 +60,7 @@ class Discriminator(nn.Module): nn.Linear(in_features=512, out_features=128), nn.ReLU(inplace=True), nn.Linear(in_features=128, out_features=1), + nn.Sigmoid(), ) def forward(self, x: torch.Tensor): @@ -75,7 +76,12 @@ if __name__ == "__main__": net = Discriminator(input_size=input_size) print(net) - _inputs = torch.rand((1, 1, input_size, input_size, input_size)) + _inputs = torch.rand((4, 1, input_size, input_size, input_size)) _outputs = net(_inputs) + _targets = torch.full((4, 1), 1.0) + criterion = torch.nn.MSELoss() + loss = criterion(_outputs, _targets) + print(loss.item()) + print(f"Transform {_inputs.shape} to {_outputs.shape}") -- GitLab