diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 25d9fc20e35fb7d6f73707d6cac6f4e77356fe06..6352b21c27efe0bf8af81a78cb108aec329ca6fb 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -295,7 +295,7 @@ if __name__ == "__main__": parser.add_argument( "--number_of_patches", type=int, - default=1, + default=100, help="number of patches extracted for each image", ) parser.add_argument( @@ -379,6 +379,11 @@ if __name__ == "__main__": default=10, help="frequency in epochs at which snapshots are stored", ) + parser.add_argument( + "--generator_weights", + type=str, + help="use pre-trained weights for the generator", + ) # Logging Args add_logging_args(parser, defaults={"--logfile": "train.log"}) @@ -416,6 +421,9 @@ if __name__ == "__main__": generator = UNet(in_channels=1, features=args.features) generator = generator.to(device) + if args.generator_weights: + logger.debug(f"Load generator weights from {args.generator_weights}") + generator.load_state_dict(torch.load(args.generator_weights, map_location=device)) transform_normalization = None if args.input_norm == "mean":