From a9a61785a7df25a322c73f5017cbf0e900596e76 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 4 Oct 2022 08:55:10 +0200 Subject: [PATCH] update parameters in default training --- mu_map/training/default.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mu_map/training/default.py b/mu_map/training/default.py index 5fc5066..90ab158 100644 --- a/mu_map/training/default.py +++ b/mu_map/training/default.py @@ -109,8 +109,6 @@ class Training: if __name__ == "__main__": import argparse - # from mu_map.dataset.mock import MuMapMockDataset - # from mu_map.dataset.default import MuMapDataset from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.normalization import MeanNormTransform from mu_map.logging import add_logging_args, get_logger_by_args @@ -126,7 +124,7 @@ if __name__ == "__main__": "--features", type=int, nargs="+", - default=[8, 16], + default=[64, 128, 256, 512], help="number of features in the layers of the UNet structure", ) @@ -142,7 +140,7 @@ if __name__ == "__main__": parser.add_argument( "--batch_size", type=int, - default=4, + default=64, help="the batch size used for training", ) parser.add_argument( @@ -154,7 +152,7 @@ if __name__ == "__main__": parser.add_argument( "--epochs", type=int, - default=10, + default=100, help="the number of epochs for which the model is trained", ) parser.add_argument( @@ -164,7 +162,7 @@ if __name__ == "__main__": help="the device (cpu or gpu) with which the training is performed", ) parser.add_argument( - "--lr", type=float, default=0.1, help="the initial learning rate for training" + "--lr", type=float, default=0.001, help="the initial learning rate for training" ) parser.add_argument( "--lr_decay_factor", @@ -208,6 +206,7 @@ if __name__ == "__main__": device = torch.device(args.device) logger = get_logger_by_args(args) model = UNet(in_channels=1, features=args.features) + model = model.to(device) data_loaders = {} for split in ["train", "validation"]: -- GitLab