Commit e3e6d7d5 authored by Tamino Huxohl

Document training implementations

parent 2e31fc27
@@ -68,6 +68,13 @@ class cGANTraining(AbstractTraining):
weight_criterion_adv: float,
logger: Optional[Logger] = None,
):
"""
:param params_generator: training parameters containing a model, an associated optimizer, and optionally a learning rate scheduler for the generator
:param params_discriminator: training parameters containing a model, an associated optimizer, and optionally a learning rate scheduler for the discriminator
:param loss_func_dist: distance loss function for the generator
:param weight_criterion_dist: weight of the distance loss when training the generator
:param weight_criterion_adv: weight of the adversarial loss when training the generator
"""
super().__init__(
epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
)
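Note: the two weights documented above combine the distance and adversarial terms into the generator objective. Below is a minimal sketch of that weighting, assuming L1 as the distance criterion and BCE as the adversarial criterion; the actual criteria and names used in the repository may differ.

import torch
import torch.nn as nn

# Assumed criteria for illustration; the repository wires these up via its own loss module.
criterion_dist = nn.L1Loss()
criterion_adv = nn.BCELoss()

weight_criterion_dist = 100.0  # illustrative weight values
weight_criterion_adv = 1.0

def generator_loss(prediction, target, discriminator_output):
    # distance between the predicted and the real mu map
    loss_dist = criterion_dist(prediction, target)
    # adversarial term: the generator wants the discriminator to label its output as real (1)
    labels_real = torch.ones_like(discriminator_output)
    loss_adv = criterion_adv(discriminator_output, labels_real)
    # weighted sum as described by the constructor docstring
    return weight_criterion_dist * loss_dist + weight_criterion_adv * loss_adv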
@@ -9,6 +9,10 @@ from mu_map.training.loss import WeightedLoss
class DistanceTraining(AbstractTraining):
"""
Implementation of distance training: a model predicts a mu map
from a reconstruction by optimizing a distance loss (e.g., L1).
"""
def __init__(
self,
epochs: int,
@@ -21,6 +25,10 @@ class DistanceTraining(AbstractTraining):
loss_func: WeightedLoss,
logger: Optional[Logger] = None,
):
"""
:param params: training parameters containing a model, an associated optimizer, and optionally a learning rate scheduler
:param loss_func: the distance loss function
"""
super().__init__(epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger)
self.training_params.append(params)
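To make the "training parameters" wording concrete, the sketch below bundles a model, an optimizer, and an optional learning rate scheduler as the docstring describes and hands the bundle to DistanceTraining. The TrainingParams container, the stand-in model, and all constructor arguments are assumptions for illustration and may not match the names used in the repository.

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class TrainingParams:
    # assumed container matching the docstring: model, optimizer, optional LR scheduler
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None

model = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)  # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
params = TrainingParams(model=model, optimizer=optimizer, lr_scheduler=scheduler)

# Hypothetical call (dataset, loss_func and the import of DistanceTraining omitted):
# training = DistanceTraining(
#     epochs=100,
#     dataset=dataset,              # a reconstruction/mu-map dataset instance
#     batch_size=32,
#     device=torch.device("cuda"),
#     snapshot_dir="snapshots",
#     snapshot_epoch=10,
#     params=params,
#     loss_func=loss_func,          # a WeightedLoss, e.g. wrapping L1
# )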