Skip to content
Snippets Groups Projects
Commit a9a61785 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

update parameters in default training

parent 2c050de5
No related branches found
No related tags found
No related merge requests found
...@@ -109,8 +109,6 @@ class Training: ...@@ -109,8 +109,6 @@ class Training:
if __name__ == "__main__": if __name__ == "__main__":
import argparse import argparse
# from mu_map.dataset.mock import MuMapMockDataset
# from mu_map.dataset.default import MuMapDataset
from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import MeanNormTransform from mu_map.dataset.normalization import MeanNormTransform
from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.logging import add_logging_args, get_logger_by_args
...@@ -126,7 +124,7 @@ if __name__ == "__main__": ...@@ -126,7 +124,7 @@ if __name__ == "__main__":
"--features", "--features",
type=int, type=int,
nargs="+", nargs="+",
default=[8, 16], default=[64, 128, 256, 512],
help="number of features in the layers of the UNet structure", help="number of features in the layers of the UNet structure",
) )
...@@ -142,7 +140,7 @@ if __name__ == "__main__": ...@@ -142,7 +140,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--batch_size", "--batch_size",
type=int, type=int,
default=4, default=64,
help="the batch size used for training", help="the batch size used for training",
) )
parser.add_argument( parser.add_argument(
...@@ -154,7 +152,7 @@ if __name__ == "__main__": ...@@ -154,7 +152,7 @@ if __name__ == "__main__":
parser.add_argument( parser.add_argument(
"--epochs", "--epochs",
type=int, type=int,
default=10, default=100,
help="the number of epochs for which the model is trained", help="the number of epochs for which the model is trained",
) )
parser.add_argument( parser.add_argument(
...@@ -164,7 +162,7 @@ if __name__ == "__main__": ...@@ -164,7 +162,7 @@ if __name__ == "__main__":
help="the device (cpu or gpu) with which the training is performed", help="the device (cpu or gpu) with which the training is performed",
) )
parser.add_argument( parser.add_argument(
"--lr", type=float, default=0.1, help="the initial learning rate for training" "--lr", type=float, default=0.001, help="the initial learning rate for training"
) )
parser.add_argument( parser.add_argument(
"--lr_decay_factor", "--lr_decay_factor",
...@@ -208,6 +206,7 @@ if __name__ == "__main__": ...@@ -208,6 +206,7 @@ if __name__ == "__main__":
device = torch.device(args.device) device = torch.device(args.device)
logger = get_logger_by_args(args) logger = get_logger_by_args(args)
model = UNet(in_channels=1, features=args.features) model = UNet(in_channels=1, features=args.features)
model = model.to(device)
data_loaders = {} data_loaders = {}
for split in ["train", "validation"]: for split in ["train", "validation"]:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment