diff --git a/mu_map/models/discriminator.py b/mu_map/models/discriminator.py index 5f1a6bc8e8b6e66124b0355fe269a89eb0165234..05de5caee053aa089ded8f1551e3b8c146814b23 100644 --- a/mu_map/models/discriminator.py +++ b/mu_map/models/discriminator.py @@ -60,6 +60,7 @@ class Discriminator(nn.Module): nn.Linear(in_features=512, out_features=128), nn.ReLU(inplace=True), nn.Linear(in_features=128, out_features=1), + nn.Sigmoid(), ) def forward(self, x: torch.Tensor): @@ -75,7 +76,12 @@ if __name__ == "__main__": net = Discriminator(input_size=input_size) print(net) - _inputs = torch.rand((1, 1, input_size, input_size, input_size)) + _inputs = torch.rand((4, 1, input_size, input_size, input_size)) _outputs = net(_inputs) + _targets = torch.full((4, 1), 1.0) + criterion = torch.nn.MSELoss() + loss = criterion(_outputs, _targets) + print(loss.item()) + print(f"Transform {_inputs.shape} to {_outputs.shape}")