diff --git a/mu_map/training/default.py b/mu_map/training/default.py index 6ef94d9d2e0b29213cb9fdd6cf950866ee54f892..e328f453b8f406e37d462bf27c4b77f212dd95d2 100644 --- a/mu_map/training/default.py +++ b/mu_map/training/default.py @@ -110,7 +110,12 @@ if __name__ == "__main__": import argparse from mu_map.dataset.patches import MuMapPatchDataset - from mu_map.dataset.normalization import MeanNormTransform + from mu_map.dataset.normalization import ( + MeanNormTransform, + MaxNormTransform, + GaussianNormTransform, + ) + from mu_map.dataset.transform import ScaleTransform from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.models.unet import UNet @@ -135,6 +140,19 @@ if __name__ == "__main__": default="data/initial/", help="the directory where the dataset for training is found", ) + parser.add_argument( + "--output_scale", + type=float, + default=1.0, + help="scale the attenuation map by this coefficient", + ) + parser.add_argument( + "--input_norm", + type=str, + choices=["none", "mean", "max", "gaussian"], + default="mean", + help="type of normalization applied to the reconstructions", + ) # Training Args parser.add_argument( @@ -200,6 +218,13 @@ if __name__ == "__main__": args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir) if not os.path.exists(args.snapshot_dir): os.mkdir(args.snapshot_dir) + else: + if len(os.listdir(args.snapshot_dir)) > 0: + print( + f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!" + ) + print(f" Exit so that data is not accidentally overwritten!") + exit(1) args.logfile = os.path.join(args.output_dir, args.logfile) @@ -210,6 +235,16 @@ if __name__ == "__main__": model = UNet(in_channels=1, features=args.features) model = model.to(device) + transform_normalization = None + if args.input_norm == "mean": + transform_normalization = MeanNormTransform() + elif args.input_norm == "max": + transform_normalization = MaxNormTransform() + elif args.input_norm == "gaussian": + transform_normalization = GaussianNormTransform() + + transform_augmentation = ScaleTransform(scale_outputs=args.output_scale) + data_loaders = {} for split in ["train", "validation"]: dataset = MuMapPatchDataset(