From 5f983873d2717c7a0cd92a663c238ecafc9f96d1 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 11 Oct 2022 09:13:27 +0200 Subject: [PATCH] cgan training allows to init generator weights --- mu_map/training/cgan.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 25d9fc2..6352b21 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -295,7 +295,7 @@ if __name__ == "__main__": parser.add_argument( "--number_of_patches", type=int, - default=1, + default=100, help="number of patches extracted for each image", ) parser.add_argument( @@ -379,6 +379,11 @@ if __name__ == "__main__": default=10, help="frequency in epochs at which snapshots are stored", ) + parser.add_argument( + "--generator_weights", + type=str, + help="use pre-trained weights for the generator", + ) # Logging Args add_logging_args(parser, defaults={"--logfile": "train.log"}) @@ -416,6 +421,9 @@ if __name__ == "__main__": generator = UNet(in_channels=1, features=args.features) generator = generator.to(device) + if args.generator_weights: + logger.debug(f"Load generator weights from {args.generator_weights}") + generator.load_state_dict(torch.load(args.generator_weights, map_location=device)) transform_normalization = None if args.input_norm == "mean": -- GitLab