diff --git a/mu_map/training/default.py b/mu_map/training/default.py index 1b24b4cb8c760afb5126ff2923cfbbf31dd3cd1e..fe8aa382f1bb32503f93eee4b947bdb1dd3e3333 100644 --- a/mu_map/training/default.py +++ b/mu_map/training/default.py @@ -38,7 +38,8 @@ class Training: self.lr_scheduler = torch.optim.lr_scheduler.StepLR( self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor ) - self.loss_func = torch.nn.MSELoss(reduction="mean") + # self.loss_func = torch.nn.MSELoss(reduction="mean") + self.loss_func = torch.nn.L1Loss(reduction="mean") def run(self): for epoch in range(1, self.epochs + 1):