Skip to content
Snippets Groups Projects
Commit de30c17e authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

implement parsing a weighted loss function form a string for CLI usage

parent ce61a834
No related branches found
No related tags found
No related merge requests found
from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -17,11 +19,97 @@ class GradientDifferenceLoss(nn.Module): ...@@ -17,11 +19,97 @@ class GradientDifferenceLoss(nn.Module):
return gradient_diff / inputs.numel() return gradient_diff / inputs.numel()
def loss_by_string(loss_str: str) -> nn.Module:
"""
Retrieve a loss function defined by a string.
E.g., L1 returns the torch module of the l1 loss function.
:param loss_str: loss function defined as a string
:returns: an executable loss function
"""
loss_str = loss_str.lower()
if loss_str == "l1":
return nn.L1Loss(reduction="mean")
elif loss_str == "l2" or loss_str == "mse":
return nn.MSELoss(reduction="mean")
elif loss_str == "gdl":
return GradientDifferenceLoss()
else:
raise ValueError(f"Unknown loss function: {loss_str}")
class WeightedLoss(nn.Module):
"""
Definition of a weighted loss consisting of a number of losses
with according weights.
:param losses: the losses to be summed and weighted
:param weights: weights for each loss function
"""
def __init__(self, losses: List[nn.Module], weights: List[float]):
super().__init__()
assert len(losses) == len(
weights
), f"There is a different number of losses {len(losses)} compared to weights {len(weights)}"
self.losses = losses
self.weights = weights
def forward(self, outputs: torch.Tensor, targets: torch.Tensor):
loss = 0.0
for loss_func, weight in zip(self.losses, self.weights):
loss += weight * loss_func(outputs, targets)
return loss
def __repr__(self):
return " + ".join(
map(lambda x: f"{x[0]:.3f} * {x[1]}", zip(self.weights, self.losses))
)
@classmethod
def from_str(cls, loss_func_str: str):
"""
Parse a weighted loss function from a string.
E.g.: 0.1*gdl+0.9*l2
"""
addends = loss_func_str.split("+")
losses, weights = [], []
for addend in loss_func_str.split("+"):
factors = addend.strip().split("*")
if len(factors) == 1:
weights.append(1.0)
losses.append(loss_by_string(factors[0]))
else:
weights.append(float(factors[0]))
losses.append(loss_by_string(factors[1]))
return cls(losses, weights)
if __name__ == "__main__": if __name__ == "__main__":
torch.manual_seed(10) import argparse
inputs = torch.rand((4, 1, 32, 64, 64))
targets = torch.rand((4, 1, 32, 64, 64)) parser = argparse.ArgumentParser(
description="Test building a loss function from a string"
)
parser.add_argument(
"loss", type=str, help="description of a loss function, e.g., 0.75*L1+0.25*GDL"
)
args = parser.parse_args()
criterion = WeightedLoss.from_str(args.loss)
inputs = torch.Tensor([[1, 2], [3, 4]])
inputs = inputs.unsqueeze(dim=0).unsqueeze(dim=0).unsqueeze(dim=0)
targets = torch.ones(inputs.shape)
criterion = GradientDifferenceLoss()
loss = criterion(inputs, targets) loss = criterion(inputs, targets)
print(f"Inputs:\n{inputs}")
print(f"Targets:\n{targets}")
print(f"Criterion: {criterion}")
print(f"Loss: {loss.item():.6f}") print(f"Loss: {loss.item():.6f}")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment