Skip to content
Snippets Groups Projects
Commit 5d9d695e authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

update distance training to use early stopping

parent b08f5d08
No related branches found
No related tags found
No related merge requests found
...@@ -12,7 +12,17 @@ class DistanceTraining(AbstractTraining): ...@@ -12,7 +12,17 @@ class DistanceTraining(AbstractTraining):
""" """
Implementation of a distance training: a model predicts a mu map Implementation of a distance training: a model predicts a mu map
from a reconstruction by optimizing a distance loss (e.g. L1). from a reconstruction by optimizing a distance loss (e.g. L1).
To see all parameters, have a look at AbstractTraining.
Parameters
----------
params: TrainingParams
    training parameters containing a model, a corresponding optimizer, and optionally a learning rate scheduler
loss_func: WeightedLoss
the distance loss function
""" """
def __init__( def __init__(
self, self,
epochs: int, epochs: int,
...@@ -23,13 +33,19 @@ class DistanceTraining(AbstractTraining): ...@@ -23,13 +33,19 @@ class DistanceTraining(AbstractTraining):
snapshot_epoch: int, snapshot_epoch: int,
params: TrainingParams, params: TrainingParams,
loss_func: WeightedLoss, loss_func: WeightedLoss,
early_stopping: Optional[int] = None,
logger: Optional[Logger] = None, logger: Optional[Logger] = None,
): ):
""" super().__init__(
:param params: training parameters containing a model an according optimizer and optionally a learning rate scheduler epochs=epochs,
:param loss_func: the distance loss function dataset=dataset,
""" batch_size=batch_size,
super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger) device=device,
snapshot_dir=snapshot_dir,
snapshot_epoch=snapshot_epoch,
early_stopping=early_stopping,
logger=logger,
)
self.training_params.append(params) self.training_params.append(params)
self.loss_func = loss_func self.loss_func = loss_func
...@@ -140,6 +156,11 @@ if __name__ == "__main__": ...@@ -140,6 +156,11 @@ if __name__ == "__main__":
default=100, default=100,
help="the number of epochs for which the model is trained", help="the number of epochs for which the model is trained",
) )
parser.add_argument(
"--early_stopping",
type=int,
    help="early stopping patience: the minimum number of epochs within which the validation loss must improve",
)
parser.add_argument( parser.add_argument(
"--device", "--device",
type=str, type=str,
...@@ -216,7 +237,6 @@ if __name__ == "__main__": ...@@ -216,7 +237,6 @@ if __name__ == "__main__":
torch.manual_seed(args.seed) torch.manual_seed(args.seed)
np.random.seed(args.seed) np.random.seed(args.seed)
transform_normalization = None transform_normalization = None
if args.input_norm == "mean": if args.input_norm == "mean":
transform_normalization = MeanNormTransform() transform_normalization = MeanNormTransform()
...@@ -244,7 +264,9 @@ if __name__ == "__main__": ...@@ -244,7 +264,9 @@ if __name__ == "__main__":
if args.decay_lr if args.decay_lr
else None else None
) )
params = TrainingParams(name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler) params = TrainingParams(
name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
)
criterion = WeightedLoss.from_str(args.loss_func) criterion = WeightedLoss.from_str(args.loss_func)
...@@ -257,6 +279,7 @@ if __name__ == "__main__": ...@@ -257,6 +279,7 @@ if __name__ == "__main__":
snapshot_epoch=args.snapshot_epoch, snapshot_epoch=args.snapshot_epoch,
params=params, params=params,
loss_func=criterion, loss_func=criterion,
early_stopping=args.early_stopping,
logger=logger, logger=logger,
) )
training.run() training.run()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment