diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py index 9cfbfc6cbdf31c203ae30e595be956b9c41741c9..c9362bd8bd766b600d1325e8ecc86c0c36a412e7 100644 --- a/mu_map/dataset/normalization.py +++ b/mu_map/dataset/normalization.py @@ -43,6 +43,23 @@ class GaussianNormTransform(Transform): return norm_gaussian(inputs), outputs_expected +norm_choices = ["max", "mean", "gaussian"] + + +def norm_by_str(norm: str): + if norm is None: + return None + + if norm == "mean": + return MeanNormTransform() + elif norm == "max": + return MaxNormTransform() + elif norm == "gaussian": + return GaussianNormTransform() + + raise ValueError(f"Unknown normalization {norm}") + + __all__ = [ norm_max.__name__, norm_mean.__name__,