from typing import Any, List

import torch
import torch.nn as nn


class GradientDifferenceLoss(nn.Module):
    """
    Gradient Difference Loss (GDL) inspired by
    https://github.com/mmany/pytorch-GDL/blob/main/custom_loss_functions.py.

    It is modified to deal with 5D tensors (batch_size, channels, z, y, x) and,
    more generally, with any tensor laid out as (batch_size, channels, *spatial).
    For every spatial dimension the first-order finite differences of inputs and
    targets are compared; the squared deviations are summed and normalized by
    the number of elements of ``inputs``.
    """

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the gradient difference loss.

        :param inputs: predicted tensor of shape (batch_size, channels, *spatial)
        :param targets: target tensor with the same shape as ``inputs``
        :returns: scalar tensor containing the loss
        :raises ValueError: if the tensors have no spatial dimension
        """
        if inputs.dim() < 3:
            raise ValueError(
                "Expected a tensor of shape (batch_size, channels, *spatial), "
                f"got shape {tuple(inputs.shape)}"
            )
        gradient_diff = 0.0
        # Walk from the last spatial dim (x) down to dim 2 (z for 5D inputs) so
        # the accumulation order matches the original x + y + z summation.
        for dim in range(inputs.dim() - 1, 1, -1):
            gradient_diff = (
                gradient_diff
                + (inputs.diff(dim=dim) - targets.diff(dim=dim)).pow(2).sum()
            )
        return gradient_diff / inputs.numel()


def loss_by_string(loss_str: str) -> nn.Module:
    """
    Retrieve a loss function defined by a string.

    E.g., L1 returns the torch module of the l1 loss function. Matching is
    case-insensitive and substring-based, so surrounding whitespace or suffixes
    such as ``L1Loss`` are accepted.

    :param loss_str: loss function defined as a string
    :returns: an executable loss function
    :raises ValueError: if no known loss name is contained in ``loss_str``
    """
    loss_str = loss_str.lower()
    if "l1" in loss_str:
        return nn.L1Loss(reduction="mean")
    elif "l2" in loss_str or "mse" in loss_str:
        return nn.MSELoss(reduction="mean")
    elif "gdl" in loss_str or "gradientdifferenceloss" in loss_str:
        return GradientDifferenceLoss()
    else:
        raise ValueError(f"Unknown loss function: {loss_str}")


class WeightedLoss(nn.Module):
    """
    Definition of a weighted loss consisting of a number of losses with according weights.

    :param losses: the losses to be summed and weighted
    :param weights: weights for each loss function
    :raises ValueError: if the number of losses and weights differ
    """

    def __init__(self, losses: List[nn.Module], weights: List[float]):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(losses) != len(weights):
            raise ValueError(
                f"There is a different number of losses {len(losses)} compared to weights {len(weights)}"
            )
        # nn.ModuleList registers the losses as submodules, so .to(device),
        # .train()/.eval() and .modules() propagate to them.
        self.losses = nn.ModuleList(losses)
        self.weights = list(weights)

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor):
        """
        Compute ``sum(weight_i * loss_i(outputs, targets))``.

        :param outputs: model predictions
        :param targets: ground truth
        :returns: the weighted sum of all losses
        """
        loss = 0.0
        for loss_func, weight in zip(self.losses, self.weights):
            loss += weight * loss_func(outputs, targets)
        return loss

    def __repr__(self):
        return " + ".join(
            f"{weight:.3f} * {loss_func}"
            for weight, loss_func in zip(self.weights, self.losses)
        )

    def __eq__(self, other: Any) -> bool:
        """
        Implementation of the comparison operator.

        This implementation makes sure that both weighted losses consist of the
        same loss classes with the same weights. Note that this is not a `smart`
        implementation as the order matters. For example, `L2+GDL` is not equal
        to `GDL+L2`.

        :param other: object to compare against
        :returns: True if ``other`` is a weighted loss with the same loss classes
            and weights in the same order
        """
        if self.__class__ != other.__class__:
            return False

        if len(self.losses) != len(other.losses):
            return False

        loss_classes_self = tuple(loss.__class__ for loss in self.losses)
        loss_classes_other = tuple(loss.__class__ for loss in other.losses)
        return loss_classes_self == loss_classes_other and self.weights == other.weights

    def __hash__(self) -> int:
        """
        Hash consistent with ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to None, which makes the
        module unhashable and breaks ``nn.Module`` helpers that store modules in
        a set (e.g. ``.modules()`` and ``.parameters()``).
        """
        return hash(
            (
                self.__class__,
                tuple(loss.__class__ for loss in self.losses),
                tuple(self.weights),
            )
        )

    @classmethod
    def from_str(cls, loss_func_str: str) -> "WeightedLoss":
        """
        Parse a weighted loss function from a string.

        E.g.: 0.1*gdl+0.9*l2
        An addend without a factor (e.g. ``l1``) gets the weight 1.0.

        :param loss_func_str: ``+``-separated addends of the form ``weight*name`` or ``name``
        :returns: the parsed weighted loss
        :raises ValueError: if an addend is malformed or names an unknown loss
        """
        losses, weights = [], []
        for addend in loss_func_str.split("+"):
            factors = addend.strip().split("*")
            if len(factors) == 1:
                weight, name = 1.0, factors[0]
            elif len(factors) == 2:
                weight, name = float(factors[0]), factors[1]
            else:
                raise ValueError(
                    f"Invalid addend '{addend.strip()}' in loss function: {loss_func_str}"
                )
            weights.append(weight)
            losses.append(loss_by_string(name))
        return cls(losses, weights)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Test building a loss function from a string"
    )
    parser.add_argument(
        "loss", type=str, help="description of a loss function, e.g., 0.75*L1+0.25*GDL"
    )
    args = parser.parse_args()

    criterion = WeightedLoss.from_str(args.loss)

    # 5D input of shape (batch_size=1, channels=1, z=1, y=2, x=2).
    inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    inputs = inputs.unsqueeze(dim=0).unsqueeze(dim=0).unsqueeze(dim=0)
    targets = torch.ones(inputs.shape)
    loss = criterion(inputs, targets)

    print(f"Inputs:\n{inputs}")
    print(f"Targets:\n{targets}")
    print(f"Criterion: {criterion}")
    print(f"Loss: {loss.item():.6f}")