diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py
index 5ae800f4dde302f4a536c0f48f4e55adb7cb0538..bfefc39fe973072c5bcf7661b62d1c5d72816c8f 100644
--- a/mu_map/dataset/normalization.py
+++ b/mu_map/dataset/normalization.py
@@ -2,7 +2,7 @@
 Module containing normalization methods either as functions
 or transformers.
 """
-from typing import Callable, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 from torch import Tensor
 
@@ -69,6 +69,23 @@ class NormTransform(Transform):
         """
         return (self.norm_func(tensors[0]), *tensors[1:])
 
+    def __eq__(self, other: Any) -> bool:
+        """
+        Implementation of the comparison operator.
+
+        This implementation just checks that self and other are of
+        the same class.
+
+        Parameters
+        ----------
+        other: Any
+
+        Returns
+        -------
+        bool
+        """
+        return self.__class__ == other.__class__
+
 
 class MaxNormTransform(NormTransform):
     """
diff --git a/mu_map/training/loss.py b/mu_map/training/loss.py
index 6b175d1d6b21d84ba2648eb5cbef5c4c41ff86e8..7272b3b2b4f10f3aec4f5ef06e2a44c20eca4412 100644
--- a/mu_map/training/loss.py
+++ b/mu_map/training/loss.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List
 
 import torch
 import torch.nn as nn
@@ -68,6 +68,33 @@ class WeightedLoss(nn.Module):
             map(lambda x: f"{x[0]:.3f} * {x[1]}", zip(self.weights, self.losses))
         )
 
+    def __eq__(self, other: Any) -> bool:
+        """
+        Implementation of the comparison operator.
+
+        This implementation makes sure that both weighted losses consist
+        of the same loss classes with the same weights. Note that this
+        is not a `smart` implementation as the order matters. For example,
+        `L2+GDL` is not equal to `GDL+L2`.
+
+        Parameters
+        ----------
+        other: Any
+
+        Returns
+        -------
+        bool
+        """
+        if self.__class__ != other.__class__:
+            return False
+
+        if len(self.losses) != len(other.losses):
+            return False
+
+        loss_classes_self = tuple(map(lambda loss: loss.__class__, self.losses))
+        loss_classes_other = tuple(map(lambda loss: loss.__class__, other.losses))
+        return loss_classes_self == loss_classes_other and self.weights == other.weights
+
     @classmethod
     def from_str(cls, loss_func_str: str):
         """