Commit e0e94b14 authored by Tamino Huxohl

make use of the weighted loss function in default training and log floats in more detail

parent de30c17e
@@ -4,7 +4,7 @@ from typing import Dict
 import torch

 from mu_map.logging import get_logger
-from mu_map.training.loss import GradientDifferenceLoss
+from mu_map.training.loss import WeightedLoss


 class Training:
@@ -14,6 +14,7 @@ class Training:
         data_loaders: Dict[str, torch.utils.data.DataLoader],
         epochs: int,
         device: torch.device,
+        loss_func: WeightedLoss,
         lr: float,
         lr_decay_factor: float,
         lr_decay_epoch: int,
@@ -39,13 +40,7 @@ class Training:
         self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
             self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
         )
-        # self.loss_func = torch.nn.MSELoss(reduction="mean")
-        # self.loss_func = torch.nn.L1Loss(reduction="mean")
-        _loss1 = torch.nn.MSELoss()
-        _loss2 = GradientDifferenceLoss()
-
-        def _loss_func(outputs, targets):
-            return _loss1(outputs, targets) + _loss2(outputs, targets)
+        self.loss_func = loss_func

     def run(self):
         for epoch in range(1, self.epochs + 1):
@@ -56,20 +51,19 @@ class Training:
             loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
             logger.info(
-                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
+                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss train: {loss_training:.6f}"
             )
             loss_validation = self._run_epoch(
                 self.data_loaders["validation"], phase="val"
             )
             logger.info(
-                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
+                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss validation: {loss_validation:.6f}"
             )

             # ToDo: log outputs and time
             _previous = self.lr_scheduler.get_last_lr()[0]
             self.lr_scheduler.step()
             logger.debug(
-                f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
+                f"Update learning rate from {_previous:.6f} to {self.lr_scheduler.get_last_lr()[0]:.6f}"
             )

             if epoch % self.snapshot_epoch == 0:
@@ -212,6 +206,12 @@ if __name__ == "__main__":
         default="cuda:0" if torch.cuda.is_available() else "cpu",
         help="the device (cpu or gpu) with which the training is performed",
     )
+    parser.add_argument(
+        "--loss_func",
+        type=str,
+        default="l1",
+        help="define the loss function used for training, e.g. 0.75*l1+0.25*gdl",
+    )
     parser.add_argument(
         "--lr", type=float, default=0.001, help="the initial learning rate for training"
     )
@@ -303,11 +303,15 @@ if __name__ == "__main__":
         )
         data_loaders[split] = data_loader

+    criterion = WeightedLoss.from_str(args.loss_func)
+    logger.debug(f"Criterion: {criterion}")
+
     training = Training(
         model=model,
         data_loaders=data_loaders,
         epochs=args.epochs,
         device=device,
+        loss_func=criterion,
         lr=args.lr,
         lr_decay_factor=args.lr_decay_factor,
         lr_decay_epoch=args.lr_decay_epoch,
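Note: the loss module itself (mu_map/training/loss.py) is not part of this commit, so WeightedLoss.from_str is only visible at its call sites above. The following is a minimal sketch of how a parser for spec strings such as 0.75*l1+0.25*gdl could look, assuming the names l1 and mse map to torch.nn.L1Loss and torch.nn.MSELoss and gdl to the repository's GradientDifferenceLoss; the class name and the spec grammar come from the diff, everything else is illustrative.

    import torch

    class WeightedLoss(torch.nn.Module):
        """Weighted sum of loss terms, e.g. 0.75 * L1 + 0.25 * GDL."""

        def __init__(self, losses, weights):
            super().__init__()
            self.losses = torch.nn.ModuleList(losses)
            self.weights = weights

        def forward(self, outputs, targets):
            # Sum every component loss scaled by its weight.
            return sum(
                weight * loss(outputs, targets)
                for weight, loss in zip(self.weights, self.losses)
            )

        @classmethod
        def from_str(cls, spec: str) -> "WeightedLoss":
            # Hypothetical registry: "gdl" would additionally map to the
            # repository's GradientDifferenceLoss, omitted here so the
            # sketch stays self-contained.
            known = {"l1": torch.nn.L1Loss, "mse": torch.nn.MSELoss}
            losses, weights = [], []
            for term in spec.split("+"):
                head, star, name = term.strip().partition("*")
                if star:  # explicit weight, e.g. "0.75*l1"
                    weights.append(float(head))
                    losses.append(known[name.strip()]())
                else:  # bare name, e.g. "l1" defaults to weight 1.0
                    weights.append(1.0)
                    losses.append(known[head]())
            return cls(losses, weights)

        def __repr__(self):
            # Readable form for the logger.debug(f"Criterion: {criterion}")
            # call added in the diff, e.g. "0.75 * L1Loss + 0.25 * MSELoss".
            return " + ".join(
                f"{weight} * {type(loss).__name__}"
                for weight, loss in zip(self.weights, self.losses)
            )

With a parser along these lines, the default --loss_func value of l1 yields plain L1 loss, while a spec such as 0.75*l1+0.25*gdl replaces the MSE + GradientDifferenceLoss sum previously hard-coded in Training.__init__ with a configurable weighted combination.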