From 5f983873d2717c7a0cd92a663c238ecafc9f96d1 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 11 Oct 2022 09:13:27 +0200
Subject: [PATCH] cgan training allows to init generator weights

---
 mu_map/training/cgan.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index 25d9fc2..6352b21 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -295,7 +295,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--number_of_patches",
         type=int,
-        default=1,
+        default=100,
         help="number of patches extracted for each image",
     )
     parser.add_argument(
@@ -379,6 +379,11 @@ if __name__ == "__main__":
         default=10,
         help="frequency in epochs at which snapshots are stored",
     )
+    parser.add_argument(
+        "--generator_weights",
+        type=str,
+        help="use pre-trained weights for the generator",
+    )
 
     # Logging Args
     add_logging_args(parser, defaults={"--logfile": "train.log"})
@@ -416,6 +421,9 @@ if __name__ == "__main__":
 
     generator = UNet(in_channels=1, features=args.features)
     generator = generator.to(device)
+    if args.generator_weights:
+        logger.debug(f"Load generator weights from {args.generator_weights}")
+        generator.load_state_dict(torch.load(args.generator_weights, map_location=device))
 
     transform_normalization = None
     if args.input_norm == "mean":
-- 
GitLab