From a9a61785a7df25a322c73f5017cbf0e900596e76 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 4 Oct 2022 08:55:10 +0200
Subject: [PATCH] update parameters in default training

---
 mu_map/training/default.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/mu_map/training/default.py b/mu_map/training/default.py
index 5fc5066..90ab158 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/default.py
@@ -109,8 +109,6 @@ class Training:
 if __name__ == "__main__":
     import argparse
 
-    # from mu_map.dataset.mock import MuMapMockDataset
-    # from mu_map.dataset.default import MuMapDataset
     from mu_map.dataset.patches import MuMapPatchDataset
     from mu_map.dataset.normalization import MeanNormTransform
     from mu_map.logging import add_logging_args, get_logger_by_args
@@ -126,7 +124,7 @@ if __name__ == "__main__":
         "--features",
         type=int,
         nargs="+",
-        default=[8, 16],
+        default=[64, 128, 256, 512],
         help="number of features in the layers of the UNet structure",
     )
 
@@ -142,7 +140,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=4,
+        default=64,
         help="the batch size used for training",
     )
     parser.add_argument(
@@ -154,7 +152,7 @@ if __name__ == "__main__":
     parser.add_argument(
         "--epochs",
         type=int,
-        default=10,
+        default=100,
         help="the number of epochs for which the model is trained",
     )
     parser.add_argument(
@@ -164,7 +162,7 @@ if __name__ == "__main__":
         help="the device (cpu or gpu) with which the training is performed",
     )
     parser.add_argument(
-        "--lr", type=float, default=0.1, help="the initial learning rate for training"
+        "--lr", type=float, default=0.001, help="the initial learning rate for training"
     )
     parser.add_argument(
         "--lr_decay_factor",
@@ -208,6 +206,7 @@ if __name__ == "__main__":
     device = torch.device(args.device)
     logger = get_logger_by_args(args)
     model = UNet(in_channels=1, features=args.features)
+    model = model.to(device)
 
     data_loaders = {}
     for split in ["train", "validation"]:
-- 
GitLab