diff --git a/mu_map/training/default.py b/mu_map/training/default.py index 5fc5066d7e12d7c043e83bb7bce2466624700686..90ab158e2fec846cb4ef7da077a1e527dafc0744 100644 --- a/mu_map/training/default.py +++ b/mu_map/training/default.py @@ -109,8 +109,6 @@ class Training: if __name__ == "__main__": import argparse - # from mu_map.dataset.mock import MuMapMockDataset - # from mu_map.dataset.default import MuMapDataset from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.normalization import MeanNormTransform from mu_map.logging import add_logging_args, get_logger_by_args @@ -126,7 +124,7 @@ if __name__ == "__main__": "--features", type=int, nargs="+", - default=[8, 16], + default=[64, 128, 256, 512], help="number of features in the layers of the UNet structure", ) @@ -142,7 +140,7 @@ if __name__ == "__main__": parser.add_argument( "--batch_size", type=int, - default=4, + default=64, help="the batch size used for training", ) parser.add_argument( @@ -154,7 +152,7 @@ if __name__ == "__main__": parser.add_argument( "--epochs", type=int, - default=10, + default=100, help="the number of epochs for which the model is trained", ) parser.add_argument( @@ -164,7 +162,7 @@ if __name__ == "__main__": help="the device (cpu or gpu) with which the training is performed", ) parser.add_argument( - "--lr", type=float, default=0.1, help="the initial learning rate for training" + "--lr", type=float, default=0.001, help="the initial learning rate for training" ) parser.add_argument( "--lr_decay_factor", @@ -208,6 +206,7 @@ if __name__ == "__main__": device = torch.device(args.device) logger = get_logger_by_args(args) model = UNet(in_channels=1, features=args.features) + model = model.to(device) data_loaders = {} for split in ["train", "validation"]: