diff --git a/mu_map/training/loss.py b/mu_map/training/loss.py index b62c07efad68fa72b2b9ea8e4a0b88f4e241f28e..51f62ef777c76d989f1cb6667646d7db799fc758 100644 --- a/mu_map/training/loss.py +++ b/mu_map/training/loss.py @@ -1,3 +1,5 @@ +from typing import List + import torch import torch.nn as nn @@ -17,11 +19,97 @@ class GradientDifferenceLoss(nn.Module): return gradient_diff / inputs.numel() +def loss_by_string(loss_str: str) -> nn.Module: + """ + Retrieve a loss function defined by a string. + E.g., L1 returns the torch module of the l1 loss function. + + :param loss_str: loss function defined as a string + :returns: an executable loss function + """ + loss_str = loss_str.lower() + if loss_str == "l1": + return nn.L1Loss(reduction="mean") + elif loss_str == "l2" or loss_str == "mse": + return nn.MSELoss(reduction="mean") + elif loss_str == "gdl": + return GradientDifferenceLoss() + else: + raise ValueError(f"Unknown loss function: {loss_str}") + + +class WeightedLoss(nn.Module): + """ + Definition of a weighted loss consisting of a number of losses + with according weights. + + :param losses: the losses to be summed and weighted + :param weights: weights for each loss function + """ + + def __init__(self, losses: List[nn.Module], weights: List[float]): + super().__init__() + + assert len(losses) == len( + weights + ), f"There is a different number of losses {len(losses)} compared to weights {len(weights)}" + + self.losses = losses + self.weights = weights + + def forward(self, outputs: torch.Tensor, targets: torch.Tensor): + loss = 0.0 + for loss_func, weight in zip(self.losses, self.weights): + loss += weight * loss_func(outputs, targets) + return loss + + def __repr__(self): + return " + ".join( + map(lambda x: f"{x[0]:.3f} * {x[1]}", zip(self.weights, self.losses)) + ) + + @classmethod + def from_str(cls, loss_func_str: str): + """ + Parse a weighted loss function from a string. + E.g.: 0.1*gdl+0.9*l2 + """ + addends = loss_func_str.split("+") + + losses, weights = [], [] + for addend in loss_func_str.split("+"): + factors = addend.strip().split("*") + + if len(factors) == 1: + weights.append(1.0) + losses.append(loss_by_string(factors[0])) + else: + weights.append(float(factors[0])) + losses.append(loss_by_string(factors[1])) + + return cls(losses, weights) + + if __name__ == "__main__": - torch.manual_seed(10) - inputs = torch.rand((4, 1, 32, 64, 64)) - targets = torch.rand((4, 1, 32, 64, 64)) + import argparse + + parser = argparse.ArgumentParser( + description="Test building a loss function from a string" + ) + parser.add_argument( + "loss", type=str, help="description of a loss function, e.g., 0.75*L1+0.25*GDL" + ) + args = parser.parse_args() + + criterion = WeightedLoss.from_str(args.loss) + + inputs = torch.Tensor([[1, 2], [3, 4]]) + inputs = inputs.unsqueeze(dim=0).unsqueeze(dim=0).unsqueeze(dim=0) + targets = torch.ones(inputs.shape) - criterion = GradientDifferenceLoss() loss = criterion(inputs, targets) + + print(f"Inputs:\n{inputs}") + print(f"Targets:\n{targets}") + print(f"Criterion: {criterion}") print(f"Loss: {loss.item():.6f}")