Newer
Older

Tamino Huxohl
committed
from typing import Any, List
import torch
import torch.nn as nn
class GradientDifferenceLoss(nn.Module):
    """
    Gradient Difference Loss (GDL) inspired by https://github.com/mmany/pytorch-GDL/blob/main/custom_loss_functions.py.

    Designed for 5D tensors (batch_size, channels, z, y, x), but works for any
    tensor with at least one spatial dimension after (batch_size, channels),
    e.g. 4D (batch_size, channels, y, x).

    For every spatial dimension, the first-order finite differences of the
    inputs and of the targets are compared. The squared differences are summed
    over all dimensions and normalized by the number of elements of ``inputs``.
    """

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
        """
        :param inputs: predicted tensor of shape (batch_size, channels, *spatial)
        :param targets: target tensor with the same shape as ``inputs``
        :returns: the summed squared gradient difference divided by ``inputs.numel()``
        """
        # Every dimension after (batch, channels) is spatial. Iterating from the
        # last one reproduces the original accumulation order x + y + z for 5D input.
        spatial_dims = reversed(range(2, inputs.dim()))
        gradient_diff = sum(
            (inputs.diff(dim=dim) - targets.diff(dim=dim)).pow(2).sum()
            for dim in spatial_dims
        )
        return gradient_diff / inputs.numel()
def loss_by_string(loss_str: str) -> nn.Module:
    """
    Look up a loss module by a case-insensitive name.

    Names are matched by substring, checked in a fixed order:

    - a string containing ``l1`` yields ``nn.L1Loss``;
    - otherwise, one containing ``l2`` or ``mse`` yields ``nn.MSELoss``;
    - otherwise, one containing ``gdl`` or ``gradientdifferenceloss`` yields
      ``GradientDifferenceLoss``.

    :param loss_str: loss function defined as a string, e.g. ``L1``
    :returns: a newly constructed instance of the matching loss module
    :raises ValueError: if none of the known names occurs in ``loss_str``
    """
    loss_str = loss_str.lower()

    # Ordered table: the first entry with an alias occurring in the string wins.
    # Factories are lambdas so that nothing is constructed until a match is found.
    candidates = (
        (("l1",), lambda: nn.L1Loss(reduction="mean")),
        (("l2", "mse"), lambda: nn.MSELoss(reduction="mean")),
        (("gdl", "gradientdifferenceloss"), lambda: GradientDifferenceLoss()),
    )
    for aliases, build in candidates:
        if any(alias in loss_str for alias in aliases):
            return build()

    raise ValueError(f"Unknown loss function: {loss_str}")
class WeightedLoss(nn.Module):
    """
    Definition of a weighted loss consisting of a number of losses
    with according weights.

    :param losses: the losses to be summed and weighted
    :param weights: weights for each loss function
    """

    def __init__(self, losses: List[nn.Module], weights: List[float]):
        super().__init__()
        assert len(losses) == len(
            weights
        ), f"There is a different number of losses {len(losses)} compared to weights {len(weights)}"
        # Kept as plain Python lists (nn.Module does not register list
        # elements as submodules).
        self.losses = losses
        self.weights = weights

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor):
        """
        Compute ``sum(weight_i * loss_i(outputs, targets))``.

        :param outputs: predictions passed to every loss function
        :param targets: targets passed to every loss function
        :returns: the weighted sum of all losses (``0.0`` if there are none)
        """
        loss = 0.0
        for loss_func, weight in zip(self.losses, self.weights):
            loss += weight * loss_func(outputs, targets)
        return loss

    def __repr__(self):
        return " + ".join(
            map(lambda x: f"{x[0]:.3f} * {x[1]}", zip(self.weights, self.losses))
        )

    def __eq__(self, other: Any) -> bool:
        """
        Implementation of the comparison operator.

        This implementation makes sure that both weighted losses consist
        of the same loss classes with the same weights. Note that this
        is not a `smart` implementation as the order matters. For example,
        `L2+GDL` is not equal to `GDL+L2`.

        Parameters
        ----------
        other: Any

        Returns
        -------
        bool
        """
        if self.__class__ != other.__class__:
            return False

        if len(self.losses) != len(other.losses):
            return False

        loss_classes_self = tuple(map(lambda loss: loss.__class__, self.losses))
        loss_classes_other = tuple(map(lambda loss: loss.__class__, other.losses))
        return loss_classes_self == loss_classes_other and self.weights == other.weights

    def __hash__(self) -> int:
        """
        Hash consistent with ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to ``None``, which makes
        instances unhashable. ``nn.Module.named_modules()`` (and therefore
        ``modules()``, ``parameters()``, ``zero_grad()``) adds the module to a
        set, so it would raise ``TypeError``. The hash uses exactly the data that
        ``__eq__`` compares: the ordered loss classes and the weights.

        Returns
        -------
        int
        """
        loss_classes = tuple(loss.__class__ for loss in self.losses)
        return hash((loss_classes, tuple(self.weights)))

    @classmethod
    def from_str(cls, loss_func_str: str):
        """
        Parse a weighted loss function from a string.

        E.g.: 0.1*gdl+0.9*l2

        Addends are separated by ``+``. Each addend is either ``weight*name`` or
        just ``name``, in which case the weight is ``1.0``. Names are resolved
        with ``loss_by_string``.

        NOTE: If an addend contains more than one ``*``, only the first two
        factors are used.

        :param loss_func_str: the loss description
        :returns: a new weighted loss
        :raises ValueError: if a weight is not a number or a loss name is unknown
        """
        losses, weights = [], []
        for addend in loss_func_str.split("+"):
            factors = addend.strip().split("*")

            if len(factors) == 1:
                weights.append(1.0)
                losses.append(loss_by_string(factors[0]))
            else:
                weights.append(float(factors[0]))
                losses.append(loss_by_string(factors[1]))

        return cls(losses, weights)
def main() -> None:
    """
    Small command-line demo: build a ``WeightedLoss`` from a string given as
    argument, evaluate it on a toy example and print the result.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Test building a loss function from a string"
    )
    parser.add_argument(
        "loss", type=str, help="description of a loss function, e.g., 0.75*L1+0.25*GDL"
    )
    args = parser.parse_args()

    criterion = WeightedLoss.from_str(args.loss)

    # A 2x2 example lifted to a 5D tensor of shape (1, 1, 1, 2, 2) so that it
    # matches the (batch_size, channels, z, y, x) layout used by the GDL.
    inputs = torch.Tensor([[1, 2], [3, 4]])
    inputs = inputs.unsqueeze(dim=0).unsqueeze(dim=0).unsqueeze(dim=0)
    targets = torch.ones(inputs.shape)

    loss = criterion(inputs, targets)
    print(f"Inputs:\n{inputs}")
    print(f"Targets:\n{targets}")
    print(f"Criterion: {criterion}")
    print(f"Loss: {float(loss):.6f}")


# Guard so that importing this module does not parse sys.argv or run the demo.
if __name__ == "__main__":
    main()