From 828569e12619cb146afecab4f34dd66c33fac273 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Thu, 26 Jan 2023 15:35:15 +0100 Subject: [PATCH] implement comparator functions for weighted loss and normalization transform --- mu_map/dataset/normalization.py | 19 ++++++++++++++++++- mu_map/training/loss.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/mu_map/dataset/normalization.py b/mu_map/dataset/normalization.py index 5ae800f..bfefc39 100644 --- a/mu_map/dataset/normalization.py +++ b/mu_map/dataset/normalization.py @@ -2,7 +2,7 @@ Module containing normalization methods either as functions or transformers. """ -from typing import Callable, Optional, Tuple +from typing import Any, Callable, Optional, Tuple from torch import Tensor @@ -69,6 +69,23 @@ class NormTransform(Transform): """ return (self.norm_func(tensors[0]), *tensors[1:]) + def __eq__(self, other: Any) -> bool: + """ + Implementation of the comparison operator. + + This implementation just checks that self and other are of + the same class. + + Parameters + ---------- + other: Any + + Returns + ------- + bool + """ + return self.__class__ == other.__class__ + class MaxNormTransform(NormTransform): """ diff --git a/mu_map/training/loss.py b/mu_map/training/loss.py index 6b175d1..7272b3b 100644 --- a/mu_map/training/loss.py +++ b/mu_map/training/loss.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, List import torch import torch.nn as nn @@ -68,6 +68,33 @@ class WeightedLoss(nn.Module): map(lambda x: f"{x[0]:.3f} * {x[1]}", zip(self.weights, self.losses)) ) + def __eq__(self, other: Any) -> bool: + """ + Implementation of the comparison operator. + + This implementation makes sure that both weighted losses consist + of the same loss classes with the same weights. Note that this + is not a `smart` implementation as the order matters. For example, + `L2+GDL` is not equal to `GDL+L2`. + + Parameters + ---------- + other: Any + + Returns + ------- + bool + """ + if self.__class__ != other.__class__: + return False + + if len(self.losses) != len(other.losses): + return False + + loss_classes_self = tuple(map(lambda loss: loss.__class__, self.losses)) + loss_classes_other = tuple(map(lambda loss: loss.__class__, other.losses)) + return loss_classes_self == loss_classes_other and self.weights == other.weights + @classmethod def from_str(cls, loss_func_str: str): """ -- GitLab