Skip to content
Snippets Groups Projects
Commit 102f3e46 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

implement selection of normalization and scale in default training

parent e95e5b6f
No related branches found
No related tags found
No related merge requests found
...@@ -110,7 +110,12 @@ if __name__ == "__main__": ...@@ -110,7 +110,12 @@ if __name__ == "__main__":
import argparse import argparse
from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import MeanNormTransform from mu_map.dataset.normalization import (
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.dataset.transform import ScaleTransform
from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.models.unet import UNet from mu_map.models.unet import UNet
...@@ -135,6 +140,19 @@ if __name__ == "__main__": ...@@ -135,6 +140,19 @@ if __name__ == "__main__":
default="data/initial/", default="data/initial/",
help="the directory where the dataset for training is found", help="the directory where the dataset for training is found",
) )
parser.add_argument(
"--output_scale",
type=float,
default=1.0,
help="scale the attenuation map by this coefficient",
)
parser.add_argument(
"--input_norm",
type=str,
choices=["none", "mean", "max", "gaussian"],
default="mean",
help="type of normalization applied to the reconstructions",
)
# Training Args # Training Args
parser.add_argument( parser.add_argument(
...@@ -200,6 +218,13 @@ if __name__ == "__main__": ...@@ -200,6 +218,13 @@ if __name__ == "__main__":
args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir) args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
if not os.path.exists(args.snapshot_dir): if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir) os.mkdir(args.snapshot_dir)
else:
if len(os.listdir(args.snapshot_dir)) > 0:
print(
f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
)
print(f" Exit so that data is not accidentally overwritten!")
exit(1)
args.logfile = os.path.join(args.output_dir, args.logfile) args.logfile = os.path.join(args.output_dir, args.logfile)
...@@ -210,6 +235,16 @@ if __name__ == "__main__": ...@@ -210,6 +235,16 @@ if __name__ == "__main__":
model = UNet(in_channels=1, features=args.features) model = UNet(in_channels=1, features=args.features)
model = model.to(device) model = model.to(device)
transform_normalization = None
if args.input_norm == "mean":
transform_normalization = MeanNormTransform()
elif args.input_norm == "max":
transform_normalization = MaxNormTransform()
elif args.input_norm == "gaussian":
transform_normalization = GaussianNormTransform()
transform_augmentation = ScaleTransform(scale_outputs=args.output_scale)
data_loaders = {} data_loaders = {}
for split in ["train", "validation"]: for split in ["train", "validation"]:
dataset = MuMapPatchDataset( dataset = MuMapPatchDataset(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment