Skip to content
Snippets Groups Projects
loss.py 4.46 KiB
Newer Older
  • Learn to ignore specific revisions
    from typing import Any, List

    import torch
    import torch.nn as nn
    
    
    class GradientDifferenceLoss(nn.Module):
        """
        Gradient Difference Loss (GDL) inspired by https://github.com/mmany/pytorch-GDL/blob/main/custom_loss_functions.py.

        For every axis after the batch and channel axes, the finite differences
        (``Tensor.diff``) of inputs and targets are subtracted, squared and summed.
        The total over all of these axes is divided by the number of elements of
        the inputs.

        It was modified to deal with 5D tensors (batch_size, channels, z, y, x) and
        accepts any tensor with at least one axis after (batch_size, channels),
        e.g. 4D tensors (batch_size, channels, y, x).
        """

        def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
            """
            Compute the gradient difference loss between inputs and targets.

            :param inputs: tensor with axes (batch_size, channels, *spatial)
            :param targets: tensor to compare against, indexed along the same axes
            :returns: summed squared gradient differences divided by ``inputs.numel()``
            :raises ValueError: if inputs has no axis after (batch_size, channels)
            """
            n_dims = inputs.dim()
            if n_dims < 3:
                raise ValueError(
                    f"Expected at least 3 dimensions (batch_size, channels, ...), got {n_dims}"
                )

            # Iterate from the last axis down to the first spatial axis so that
            # 5D inputs are accumulated as x + y + z, the same order as before.
            gradient_diff = sum(
                (inputs.diff(dim=dim) - targets.diff(dim=dim)).pow(2).sum()
                for dim in range(n_dims - 1, 1, -1)
            )
            return gradient_diff / inputs.numel()
    
    
    
    def loss_by_string(loss_str: str) -> nn.Module:
        """
        Retrieve a loss function defined by a string.
        E.g., L1 returns the torch module of the l1 loss function.

        Matching is case-insensitive and substring-based; the first matching
        entry in the order l1, l2/mse, gdl/gradientdifferenceloss wins.

        :param loss_str: loss function defined as a string
        :returns: an executable loss function
        :raises ValueError: if the string matches none of the known losses
        """
        loss_str = loss_str.lower()

        if "l1" in loss_str:
            return nn.L1Loss(reduction="mean")

        if "l2" in loss_str or "mse" in loss_str:
            return nn.MSELoss(reduction="mean")

        if "gdl" in loss_str or "gradientdifferenceloss" in loss_str:
            return GradientDifferenceLoss()

        raise ValueError(f"Unknown loss function: {loss_str}")
    
    
    class WeightedLoss(nn.Module):
        """
        Definition of a weighted loss consisting of a number of losses
        with according weights.

        The result of ``forward`` is the sum of ``weight * loss(outputs, targets)``
        over all configured losses.

        :param losses: the losses to be summed and weighted
        :param weights: weights for each loss function
        """

        def __init__(self, losses: List[nn.Module], weights: List[float]):
            super().__init__()

            assert len(losses) == len(
                weights
            ), f"There is a different number of losses {len(losses)} compared to weights {len(weights)}"

            self.losses = losses
            self.weights = weights

        def forward(self, outputs: torch.Tensor, targets: torch.Tensor):
            """
            Evaluate every loss on (outputs, targets) and return the weighted sum.

            :param outputs: predictions passed to each loss function
            :param targets: ground truth passed to each loss function
            :returns: the weighted sum of all losses (0.0 if there are no losses)
            """
            return sum(
                (
                    weight * loss_func(outputs, targets)
                    for loss_func, weight in zip(self.losses, self.weights)
                ),
                0.0,
            )

        def __repr__(self):
            """Render the loss as `<weight> * <loss> + ...` with three decimals per weight."""
            return " + ".join(
                f"{weight:.3f} * {loss_func}"
                for weight, loss_func in zip(self.weights, self.losses)
            )

        def __eq__(self, other: Any) -> bool:
            """
            Implementation of the comparison operator.

            This implementation makes sure that both weighted losses consist
            of the same loss classes with the same weights. Note that this
            is not a `smart` implementation as the order matters. For example,
            `L2+GDL` is not equal to `GDL+L2`.

            Parameters
            ----------
            other: Any

            Returns
            -------
            bool
            """
            if self.__class__ != other.__class__:
                return False

            if len(self.losses) != len(other.losses):
                return False

            loss_classes_self = tuple(loss.__class__ for loss in self.losses)
            loss_classes_other = tuple(loss.__class__ for loss in other.losses)
            return loss_classes_self == loss_classes_other and self.weights == other.weights

        def __hash__(self) -> int:
            """
            Hash over the loss classes and weights, consistent with `__eq__`.

            Defining `__eq__` alone sets `__hash__` to None, which makes the
            module unhashable and breaks `nn.Module` helpers that collect
            modules in a set (e.g. `modules()` and `parameters()`).
            """
            loss_classes = tuple(loss.__class__ for loss in self.losses)
            return hash((loss_classes, tuple(self.weights)))

        @classmethod
        def from_str(cls, loss_func_str: str):
            """
            Parse a weighted loss function from a string.
            E.g.: 0.1*gdl+0.9*l2

            Addends are separated by `+`; each addend is either `<loss>`
            (weight 1.0) or `<weight>*<loss>`.

            :param loss_func_str: textual description of the weighted loss
            :returns: a new instance built from the parsed losses and weights
            """
            losses, weights = [], []
            for addend in loss_func_str.split("+"):
                factors = addend.strip().split("*")

                if len(factors) == 1:
                    # No explicit weight given: the loss counts fully.
                    weights.append(1.0)
                    losses.append(loss_by_string(factors[0]))
                else:
                    weights.append(float(factors[0]))
                    losses.append(loss_by_string(factors[1]))

            return cls(losses, weights)
    
    
    
    if __name__ == "__main__":
        # Small command-line demo: build a criterion from its textual
        # description and evaluate it on a fixed 2x2 example.
        import argparse

        cli = argparse.ArgumentParser(
            description="Test building a loss function from a string"
        )
        cli.add_argument(
            "loss", type=str, help="description of a loss function, e.g., 0.75*L1+0.25*GDL"
        )
        cli_args = cli.parse_args()

        criterion = WeightedLoss.from_str(cli_args.loss)

        # Prepend three singleton axes so the 2x2 sample becomes 5D.
        inputs = torch.Tensor([[1, 2], [3, 4]])
        for _ in range(3):
            inputs = inputs.unsqueeze(dim=0)
        targets = torch.ones(inputs.size())

        loss = criterion(inputs, targets)

        print(f"Inputs:\n{inputs}")
        print(f"Targets:\n{targets}")
        print(f"Criterion: {criterion}")

        print(f"Loss: {loss.item():.6f}")