Skip to content
Snippets Groups Projects
Commit 249fa333 authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

add gdl loss to default training

parent 450d9f96
No related merge requests found
......@@ -4,6 +4,7 @@ from typing import Dict
import torch
from mu_map.logging import get_logger
from mu_map.training.loss import GradientDifferenceLoss
class Training:
......@@ -39,7 +40,12 @@ class Training:
self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
)
# self.loss_func = torch.nn.MSELoss(reduction="mean")
self.loss_func = torch.nn.L1Loss(reduction="mean")
# self.loss_func = torch.nn.L1Loss(reduction="mean")
_loss1 = torch.nn.MSELoss()
_loss2 = GradientDifferenceLoss()
def _loss_func(outputs, targets):
return _loss1(outputs, targets) + _loss2(outputs, targets)
self.loss_func = _loss_func
def run(self):
for epoch in range(1, self.epochs + 1):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment