Commit 249fa333 authored by Tamino Huxohl

add gdl loss to default training

parent 450d9f96
@@ -4,6 +4,7 @@ from typing import Dict
 import torch
 
 from mu_map.logging import get_logger
+from mu_map.training.loss import GradientDifferenceLoss
 
 
 class Training:
@@ -39,7 +40,12 @@ class Training:
             self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
         )
         # self.loss_func = torch.nn.MSELoss(reduction="mean")
-        self.loss_func = torch.nn.L1Loss(reduction="mean")
+        # self.loss_func = torch.nn.L1Loss(reduction="mean")
+        _loss1 = torch.nn.MSELoss()
+        _loss2 = GradientDifferenceLoss()
+        def _loss_func(outputs, targets):
+            return _loss1(outputs, targets) + _loss2(outputs, targets)
+        self.loss_func = _loss_func
 
     def run(self):
         for epoch in range(1, self.epochs + 1):
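For context: the commit replaces the plain L1 loss with the sum of an MSE term and a gradient difference loss (GDL). The actual GradientDifferenceLoss in mu_map.training.loss is not shown in this diff; below is only a minimal sketch of what a 2D gradient difference loss can look like in PyTorch, in the spirit of Mathieu et al. (2016). The real implementation may differ, e.g. it may operate on 3D volumes or use a configurable exponent.

import torch
import torch.nn as nn


class GradientDifferenceLoss(nn.Module):
    """Sketch of a gradient difference loss: penalizes mismatches between
    the spatial gradients (finite differences) of prediction and target."""

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Absolute finite differences along the last two (spatial) axes.
        def grads(x):
            dy = torch.abs(x[..., 1:, :] - x[..., :-1, :])
            dx = torch.abs(x[..., :, 1:] - x[..., :, :-1])
            return dy, dx

        out_dy, out_dx = grads(outputs)
        tgt_dy, tgt_dx = grads(targets)
        return torch.mean(torch.abs(out_dy - tgt_dy)) + torch.mean(
            torch.abs(out_dx - tgt_dx)
        )


# Usage mirroring the commit: combine MSE and GDL into a single loss.
# Shapes here are illustrative dummies, not taken from the repository.
mse = nn.MSELoss()
gdl = GradientDifferenceLoss()
outputs = torch.rand(1, 1, 64, 64, requires_grad=True)
targets = torch.rand(1, 1, 64, 64)
loss = mse(outputs, targets) + gdl(outputs, targets)
loss.backward()

Summing the two terms keeps pixel-wise fidelity (MSE) while the GDL term penalizes edge mismatches, which plain MSE tends to blur.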