Skip to content
Snippets Groups Projects
Commit 102f3e46 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

implement selection of normalization and scale in default training

parent e95e5b6f
No related merge requests found
......@@ -110,7 +110,12 @@ if __name__ == "__main__":
import argparse
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import MeanNormTransform
from mu_map.dataset.normalization import (
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.dataset.transform import ScaleTransform
from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.models.unet import UNet
......@@ -135,6 +140,19 @@ if __name__ == "__main__":
default="data/initial/",
help="the directory where the dataset for training is found",
)
parser.add_argument(
"--output_scale",
type=float,
default=1.0,
help="scale the attenuation map by this coefficient",
)
parser.add_argument(
"--input_norm",
type=str,
choices=["none", "mean", "max", "gaussian"],
default="mean",
help="type of normalization applied to the reconstructions",
)
# Training Args
parser.add_argument(
......@@ -200,6 +218,13 @@ if __name__ == "__main__":
args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir)
else:
if len(os.listdir(args.snapshot_dir)) > 0:
print(
f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
)
print(f" Exit so that data is not accidentally overwritten!")
exit(1)
args.logfile = os.path.join(args.output_dir, args.logfile)
......@@ -210,6 +235,16 @@ if __name__ == "__main__":
model = UNet(in_channels=1, features=args.features)
model = model.to(device)
transform_normalization = None
if args.input_norm == "mean":
transform_normalization = MeanNormTransform()
elif args.input_norm == "max":
transform_normalization = MaxNormTransform()
elif args.input_norm == "gaussian":
transform_normalization = GaussianNormTransform()
transform_augmentation = ScaleTransform(scale_outputs=args.output_scale)
data_loaders = {}
for split in ["train", "validation"]:
dataset = MuMapPatchDataset(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment