From 828569e12619cb146afecab4f34dd66c33fac273 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Thu, 26 Jan 2023 15:35:15 +0100
Subject: [PATCH] implement comparator functions for weighted loss and
 normalization transform

---
 mu_map/dataset/normalization.py | 19 ++++++++++++++++++-
 mu_map/training/loss.py         | 29 ++++++++++++++++++++++++++++-
 2 files changed, 46 insertions(+), 2 deletions(-)

diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py
index 5ae800f..bfefc39 100644
--- a/mu_map/dataset/normalization.py
+++ b/mu_map/dataset/normalization.py
@@ -2,7 +2,7 @@
 Module containing normalization methods either as functions
 or transformers.
 """
-from typing import Callable, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 from torch import Tensor
 
@@ -69,6 +69,23 @@ class NormTransform(Transform):
         """
         return (self.norm_func(tensors[0]), *tensors[1:])
 
+    def __eq__(self, other: Any) -> bool:
+        """
+        Implementation of the equality operator (==).
+
+        This implementation just checks that self and other are of
+        the same class.
+
+        Parameters
+        ----------
+        other: Any
+
+        Returns
+        -------
+        bool
+        """
+        return self.__class__ == other.__class__
+
 
 class MaxNormTransform(NormTransform):
     """
diff --git a/mu_map/training/loss.py b/mu_map/training/loss.py
index 6b175d1..7272b3b 100644
--- a/mu_map/training/loss.py
+++ b/mu_map/training/loss.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List
 
 import torch
 import torch.nn as nn
@@ -68,6 +68,33 @@ class WeightedLoss(nn.Module):
             map(lambda x: f"{x[0]:.3f} * {x[1]}", zip(self.weights, self.losses))
         )
 
+    def __eq__(self, other: Any) -> bool:
+        """
+        Implementation of the equality operator (==).
+
+        This implementation makes sure that both weighted losses consist
+        of the same loss classes with the same weights. Note that this
+        is not a `smart` implementation as the order matters. For example,
+        `L2+GDL` is not equal to `GDL+L2`.
+
+        Parameters
+        ----------
+        other: Any
+
+        Returns
+        -------
+        bool
+        """
+        if self.__class__ != other.__class__:
+            return False
+
+        if len(self.losses) != len(other.losses):
+            return False
+
+        loss_classes_self = tuple(map(lambda loss: loss.__class__, self.losses))
+        loss_classes_other = tuple(map(lambda loss: loss.__class__, other.losses))
+        return loss_classes_self == loss_classes_other and self.weights == other.weights
+
     @classmethod
     def from_str(cls, loss_func_str: str):
         """
-- 
GitLab