Skip to content
Snippets Groups Projects
Commit 325fa268 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

adapt cgan training to work with patch discriminator

parent d37eb79a
No related branches found
No related tags found
No related merge requests found
......@@ -106,15 +106,14 @@ class cGANTraining(AbstractTraining):
# compute fake mu maps with generator
mu_maps_fake = self.generator(recons)
# note: the batch size may differ for the last batch which is why self.batch_size is not reliable
batch_size = recons.shape[0]
labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device)
labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device)
# prepare inputs for the discriminator
inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
# prepare labels/targets for the discriminator
labels_fake = torch.full(self.discriminator.get_output_shape(inputs_d_fake.shape), LABEL_FAKE, device=self.device)
labels_real = torch.full(self.discriminator.get_output_shape(inputs_d_real.shape), LABEL_REAL, device=self.device)
# ======================= Discriminator =====================================
# compute discriminator loss for fake mu maps
# detach is called so that gradients are not computed for the generator
......@@ -363,6 +362,7 @@ if __name__ == "__main__":
)
discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
# discriminator = PatchDiscriminator(in_channels=2)
discriminator = discriminator.to(device)
optimizer = torch.optim.Adam(
discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999)
......
Loading…
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment