def __eq__(self, other: Any) -> bool:
    """
    Compare this normalization transform with another object.

    Two transforms count as equal exactly when they are instances of
    the very same class; no per-instance state is inspected.

    Parameters
    ----------
    other: Any
        object to compare against

    Returns
    -------
    bool
        True iff ``other`` has the same class as ``self``
    """
    # Class objects compare by identity by default, so this is a
    # symmetric exact-class check (subclasses do not compare equal).
    return other.__class__ == self.__class__
def __eq__(self, other: Any) -> bool:
    """
    Implementation of the comparison operator.

    Two weighted losses are equal iff they consist of the same loss
    classes, in the same order, with the same weights. Note that this
    is not a `smart` implementation as the order matters. For example,
    `L2+GDL` is not equal to `GDL+L2`.

    Parameters
    ----------
    other: Any
        object to compare against

    Returns
    -------
    bool
    """
    if self.__class__ != other.__class__:
        return False

    # NOTE: no explicit len() guard is needed — tuple equality below
    # already yields False when the two loss sequences differ in length.
    loss_classes_self = tuple(loss.__class__ for loss in self.losses)
    loss_classes_other = tuple(loss.__class__ for loss in other.losses)
    return loss_classes_self == loss_classes_other and self.weights == other.weights