From 160d0dffea6ad589a3086ce81dc840cc0d570b53 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Tue, 11 Oct 2022 11:33:10 +0200 Subject: [PATCH] add function to retrieve a normalization by str --- mu_map/dataset/normalization.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py index 9cfbfc6..c9362bd 100644 --- a/mu_map/dataset/normalization.py +++ b/mu_map/dataset/normalization.py @@ -43,6 +43,23 @@ class GaussianNormTransform(Transform): return norm_gaussian(inputs), outputs_expected +norm_choices = ["max", "mean", "gaussian"] + + +def norm_by_str(norm: str): + if norm is None: + return None + + if norm == "mean": + return MeanNormTransform() + elif norm == "max": + return MaxNormTransform() + elif norm == "gaussian": + return GaussianNormTransform() + + raise ValueError(f"Unknown normalization {norm}") + + __all__ = [ norm_max.__name__, norm_mean.__name__, -- GitLab