Commit 828569e1 authored by Tamino Huxohl

implement comparator functions for weighted loss and normalization transform

parent ff167f5d
@@ -2,7 +2,7 @@
 Module containing normalization methods either as functions
 or transformers.
 """
-from typing import Callable, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 from torch import Tensor
@@ -69,6 +69,23 @@ class NormTransform(Transform):
         """
         return (self.norm_func(tensors[0]), *tensors[1:])
+
+    def __eq__(self, other: Any) -> bool:
+        """
+        Implementation of the comparison operator.
+
+        This implementation just checks that self and other are of
+        the same class.
+
+        Parameters
+        ----------
+        other: Any
+
+        Returns
+        -------
+        bool
+        """
+        return self.__class__ == other.__class__
 class MaxNormTransform(NormTransform):
     """
...
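Illustration (not part of the commit): a minimal sketch of how the class-based __eq__ above behaves. The Transform base class and the MaxNormTransform constructor are assumptions made only to keep the example self-contained; the __eq__ body is taken from the diff.

from typing import Any, Callable
from torch import Tensor

class Transform:
    # stand-in for the repository's base class (assumption)
    pass

class NormTransform(Transform):
    def __init__(self, norm_func: Callable[[Tensor], Tensor]):
        self.norm_func = norm_func

    def __eq__(self, other: Any) -> bool:
        # equality depends only on the class, not on norm_func
        return self.__class__ == other.__class__

class MaxNormTransform(NormTransform):
    def __init__(self):
        # hypothetical max normalization, for illustration only
        super().__init__(lambda t: t / t.max())

assert MaxNormTransform() == MaxNormTransform()           # same class -> equal
assert NormTransform(lambda t: t) != MaxNormTransform()   # different classes -> not equal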
-from typing import List
+from typing import Any, List
 import torch
 import torch.nn as nn
@@ -68,6 +68,33 @@ class WeightedLoss(nn.Module):
             map(lambda x: f"{x[0]:.3f} * {x[1]}", zip(self.weights, self.losses))
         )
+
+    def __eq__(self, other: Any) -> bool:
+        """
+        Implementation of the comparison operator.
+
+        This implementation makes sure that both weighted losses consist
+        of the same loss classes with the same weights. Note that this
+        is not a `smart` implementation as the order matters. For example,
+        `L2+GDL` is not equal to `GDL+L2`.
+
+        Parameters
+        ----------
+        other: Any
+
+        Returns
+        -------
+        bool
+        """
+        if self.__class__ != other.__class__:
+            return False
+        if len(self.losses) != len(other.losses):
+            return False
+
+        loss_classes_self = tuple(map(lambda loss: loss.__class__, self.losses))
+        loss_classes_other = tuple(map(lambda loss: loss.__class__, other.losses))
+        return loss_classes_self == loss_classes_other and self.weights == other.weights
     @classmethod
     def from_str(cls, loss_func_str: str):
         """
...
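A hedged usage sketch of the order-sensitive equality implemented above. The WeightedLoss constructor shown here (a list of losses plus matching weights) is an assumption made for illustration; only the __eq__ logic mirrors the diff.

from typing import Any, List
import torch.nn as nn

class WeightedLoss(nn.Module):
    # illustrative stand-in; the real constructor may differ (assumption)
    def __init__(self, losses: List[nn.Module], weights: List[float]):
        super().__init__()
        self.losses = losses
        self.weights = weights

    def __eq__(self, other: Any) -> bool:
        # same logic as the commit: equal classes, equal loss classes in
        # the same order, and equal weights
        if self.__class__ != other.__class__:
            return False
        if len(self.losses) != len(other.losses):
            return False
        loss_classes_self = tuple(map(lambda loss: loss.__class__, self.losses))
        loss_classes_other = tuple(map(lambda loss: loss.__class__, other.losses))
        return loss_classes_self == loss_classes_other and self.weights == other.weights

a = WeightedLoss([nn.MSELoss(), nn.L1Loss()], [1.0, 0.5])
b = WeightedLoss([nn.MSELoss(), nn.L1Loss()], [1.0, 0.5])
c = WeightedLoss([nn.L1Loss(), nn.MSELoss()], [0.5, 1.0])
assert a == b   # same loss classes in the same order, same weights
assert a != c   # same components, different order -> not equal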