diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 0b34ff39603ba244bef034b95f2c019dfce750e0..21de59f985a6363e7c6c1073490c9ab9bf30849c 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -421,8 +421,8 @@ if __name__ == "__main__": torch.manual_seed(args.seed) np.random.seed(args.seed) - # discriminator = Discriminator(in_channels=2, input_size=args.patch_size) - discriminator = PatchDiscriminator(in_channels=2, input_size=args.patch_size) + discriminator = Discriminator(in_channels=2, input_size=args.patch_size) + # discriminator = PatchDiscriminator(in_channels=2, input_size=args.patch_size) discriminator = discriminator.to(device) generator = UNet(in_channels=1, features=args.features)