diff --git a/mu_map/training/distance.py b/mu_map/training/distance.py index 78c94a69e856923e878fdfdc2cd4bac7775ba614..5e29323665a93b07e5ce6a274346f225f81ee4b4 100644 --- a/mu_map/training/distance.py +++ b/mu_map/training/distance.py @@ -12,7 +12,17 @@ class DistanceTraining(AbstractTraining): """ Implementation of a distance training: a model predicts a mu map from a reconstruction by optimizing a distance loss (e.g. L1). + + To see all parameters, have a look at AbstractTraining. + + Parameters + ---------- + params: TrainingParams + training parameters containing a model an according optimizer and optionally a learning rate scheduler + loss_func: WeightedLoss + the distance loss function """ + def __init__( self, epochs: int, @@ -23,13 +33,19 @@ class DistanceTraining(AbstractTraining): snapshot_epoch: int, params: TrainingParams, loss_func: WeightedLoss, + early_stopping: Optional[int] = None, logger: Optional[Logger] = None, ): - """ - :param params: training parameters containing a model an according optimizer and optionally a learning rate scheduler - :param loss_func: the distance loss function - """ - super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger) + super().__init__( + epochs=epochs, + dataset=dataset, + batch_size=batch_size, + device=device, + snapshot_dir=snapshot_dir, + snapshot_epoch=snapshot_epoch, + early_stopping=early_stopping, + logger=logger, + ) self.training_params.append(params) self.loss_func = loss_func @@ -140,6 +156,11 @@ if __name__ == "__main__": default=100, help="the number of epochs for which the model is trained", ) + parser.add_argument( + "--early_stopping", + type=int, + help="define early stopping as the least amount of epochs in which the validation loss must improve", + ) parser.add_argument( "--device", type=str, @@ -216,7 +237,6 @@ if __name__ == "__main__": torch.manual_seed(args.seed) np.random.seed(args.seed) - transform_normalization = None if args.input_norm == "mean": transform_normalization = MeanNormTransform() @@ -244,7 +264,9 @@ if __name__ == "__main__": if args.decay_lr else None ) - params = TrainingParams(name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler) + params = TrainingParams( + name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler + ) criterion = WeightedLoss.from_str(args.loss_func) @@ -257,6 +279,7 @@ if __name__ == "__main__": snapshot_epoch=args.snapshot_epoch, params=params, loss_func=criterion, + early_stopping=args.early_stopping, logger=logger, ) training.run()