Skip to content
Snippets Groups Projects
Commit 828569e1 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

implement comparator functions for weighted loss and normalization transform

parent ff167f5d
No related branches found
No related tags found
No related merge requests found
......@@ -2,7 +2,7 @@
Module containing normalization methods either as functions
or transformers.
"""
from typing import Callable, Optional, Tuple
from typing import Any, Callable, Optional, Tuple
from torch import Tensor
......@@ -69,6 +69,23 @@ class NormTransform(Transform):
"""
return (self.norm_func(tensors[0]), *tensors[1:])
def __eq__(self, other: Any) -> bool:
"""
Implementation of the comparison operator.
This implementation just checks that self and other are of
the same class.
Parameters
----------
other: Any
Returns
-------
bool
"""
return self.__class__ == other.__class__
class MaxNormTransform(NormTransform):
"""
......
from typing import List
from typing import Any, List
import torch
import torch.nn as nn
......@@ -68,6 +68,33 @@ class WeightedLoss(nn.Module):
map(lambda x: f"{x[0]:.3f} * {x[1]}", zip(self.weights, self.losses))
)
def __eq__(self, other: Any) -> bool:
"""
Implementation of the comparison operator.
This implementation makes sure that both weighted losses consist
of the same loss classes with the same weights. Note that this
is not a `smart` implementation as the order matters. For example,
`L2+GDL` is not equal to `GDL+L2`.
Parameters
----------
other: Any
Returns
-------
bool
"""
if self.__class__ != other.__class__:
return False
if len(self.losses) != len(other.losses):
return False
loss_classes_self = tuple(map(lambda loss: loss.__class__, self.losses))
loss_classes_other = tuple(map(lambda loss: loss.__class__, other.losses))
return loss_classes_self == loss_classes_other and self.weights == other.weights
@classmethod
def from_str(cls, loss_func_str: str):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment