diff --git a/mu_map/training/loss.py b/mu_map/training/loss.py
index b62c07efad68fa72b2b9ea8e4a0b88f4e241f28e..51f62ef777c76d989f1cb6667646d7db799fc758 100644
--- a/mu_map/training/loss.py
+++ b/mu_map/training/loss.py
@@ -1,3 +1,5 @@
+from typing import List
+
 import torch
 import torch.nn as nn
 
@@ -17,11 +19,97 @@ class GradientDifferenceLoss(nn.Module):
         return gradient_diff / inputs.numel()
 
 
+def loss_by_string(loss_str: str) -> nn.Module:
+    """
+    Retrieve a loss function defined by a string.
+    E.g., L1 returns the torch module of the l1 loss function.
+
+    :param loss_str: loss function defined as a string
+    :returns: an executable loss function
+    """
+    loss_str = loss_str.lower()
+    if loss_str == "l1":
+        return nn.L1Loss(reduction="mean")
+    elif loss_str == "l2" or loss_str == "mse":
+        return nn.MSELoss(reduction="mean")
+    elif loss_str == "gdl":
+        return GradientDifferenceLoss()
+    else:
+        raise ValueError(f"Unknown loss function: {loss_str}")
+
+
+class WeightedLoss(nn.Module):
+    """
+    Definition of a weighted loss consisting of a number of losses
+    with according weights.
+
+    :param losses: the losses to be summed and weighted
+    :param weights: weights for each loss function
+    """
+
+    def __init__(self, losses: List[nn.Module], weights: List[float]):
+        super().__init__()
+
+        assert len(losses) == len(
+            weights
+        ), f"There is a different number of losses {len(losses)} compared to weights {len(weights)}"
+
+        self.losses = losses
+        self.weights = weights
+
+    def forward(self, outputs: torch.Tensor, targets: torch.Tensor):
+        loss = 0.0
+        for loss_func, weight in zip(self.losses, self.weights):
+            loss += weight * loss_func(outputs, targets)
+        return loss
+
+    def __repr__(self):
+        return " + ".join(
+            map(lambda x: f"{x[0]:.3f} * {x[1]}", zip(self.weights, self.losses))
+        )
+
+    @classmethod
+    def from_str(cls, loss_func_str: str):
+        """
+        Parse a weighted loss function from a string.
+        E.g.: 0.1*gdl+0.9*l2
+        """
+        addends = loss_func_str.split("+")
+
+        losses, weights = [], []
+        for addend in loss_func_str.split("+"):
+            factors = addend.strip().split("*")
+
+            if len(factors) == 1:
+                weights.append(1.0)
+                losses.append(loss_by_string(factors[0]))
+            else:
+                weights.append(float(factors[0]))
+                losses.append(loss_by_string(factors[1]))
+
+        return cls(losses, weights)
+
+
 if __name__ == "__main__":
-    torch.manual_seed(10)
-    inputs = torch.rand((4, 1, 32, 64, 64))
-    targets = torch.rand((4, 1, 32, 64, 64))
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Test building a loss function from a string"
+    )
+    parser.add_argument(
+        "loss", type=str, help="description of a loss function, e.g., 0.75*L1+0.25*GDL"
+    )
+    args = parser.parse_args()
+
+    criterion = WeightedLoss.from_str(args.loss)
+
+    inputs = torch.Tensor([[1, 2], [3, 4]])
+    inputs = inputs.unsqueeze(dim=0).unsqueeze(dim=0).unsqueeze(dim=0)
+    targets = torch.ones(inputs.shape)
 
-    criterion = GradientDifferenceLoss()
     loss = criterion(inputs, targets)
+
+    print(f"Inputs:\n{inputs}")
+    print(f"Targets:\n{targets}")
+    print(f"Criterion: {criterion}")
     print(f"Loss: {loss.item():.6f}")