From 102f3e4630784ecb704e2bf7bd07a96b2a0f7f53 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 4 Oct 2022 09:46:25 +0200
Subject: [PATCH] implement selection of normalization and scale in default
 training

---
 mu_map/training/default.py | 37 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)

diff --git a/mu_map/training/default.py b/mu_map/training/default.py
index 6ef94d9..e328f45 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/default.py
@@ -110,7 +110,12 @@ if __name__ == "__main__":
     import argparse
 
     from mu_map.dataset.patches import MuMapPatchDataset
-    from mu_map.dataset.normalization import MeanNormTransform
+    from mu_map.dataset.normalization import (
+        MeanNormTransform,
+        MaxNormTransform,
+        GaussianNormTransform,
+    )
+    from mu_map.dataset.transform import ScaleTransform
     from mu_map.logging import add_logging_args, get_logger_by_args
     from mu_map.models.unet import UNet
 
@@ -135,6 +140,19 @@ if __name__ == "__main__":
         default="data/initial/",
         help="the directory where the dataset for training is found",
     )
+    parser.add_argument(
+        "--output_scale",
+        type=float,
+        default=1.0,
+        help="scale the attenuation map by this coefficient",
+    )
+    parser.add_argument(
+        "--input_norm",
+        type=str,
+        choices=["none", "mean", "max", "gaussian"],
+        default="mean",
+        help="type of normalization applied to the reconstructions",
+    )
 
     # Training Args
     parser.add_argument(
@@ -200,6 +218,13 @@ if __name__ == "__main__":
     args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
     if not os.path.exists(args.snapshot_dir):
         os.mkdir(args.snapshot_dir)
+    else:
+        if len(os.listdir(args.snapshot_dir)) > 0:
+            print(
+                f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
+            )
+            print(f"           Exit so that data is not accidentally overwritten!")
+            exit(1)
 
     args.logfile = os.path.join(args.output_dir, args.logfile)
 
@@ -210,6 +235,16 @@ if __name__ == "__main__":
     model = UNet(in_channels=1, features=args.features)
     model = model.to(device)
 
+    transform_normalization = None
+    if args.input_norm == "mean":
+        transform_normalization = MeanNormTransform()
+    elif args.input_norm == "max":
+        transform_normalization = MaxNormTransform()
+    elif args.input_norm == "gaussian":
+        transform_normalization = GaussianNormTransform()
+
+    transform_augmentation = ScaleTransform(scale_outputs=args.output_scale)
+
     data_loaders = {}
     for split in ["train", "validation"]:
         dataset = MuMapPatchDataset(
-- 
GitLab