@@ -17,6 +17,7 @@ class DiscriminatorParams(TrainingParams):
"""
Wrap training parameters to always carry the name 'Discriminator'.
"""
def__init__(
self,
model:torch.nn.Module,
...
...
@@ -30,10 +31,12 @@ class DiscriminatorParams(TrainingParams):
lr_scheduler=lr_scheduler,
)
classGeneratorParams(TrainingParams):
"""
Wrap training parameters to always carry the name 'Generator'.
"""
def__init__(
self,
model:torch.nn.Module,
...
...
@@ -48,11 +51,26 @@ class GeneratorParams(TrainingParams):
)
classcGANTraining(AbstractTraining):
"""
Implementation of a conditional generative adversarial network training.
To see all parameters, have a look at AbstractTraining.
Parameters
----------
params_generator: GeneratorParams
training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator
params_discriminator: DiscriminatorParams
training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator
loss_func_dist: WeightedLoss
distance loss function for the generator
weight_criterion_dist: float
weight of the distance loss when training the generator
weight_criterion_adv: float
weight of the adversarial loss when training the generator
"""
def__init__(
self,
epochs:int,
...
...
@@ -66,17 +84,18 @@ class cGANTraining(AbstractTraining):
loss_func_dist:WeightedLoss,
weight_criterion_dist:float,
weight_criterion_adv:float,
early_stopping:Optional[int],
logger:Optional[Logger]=None,
):
"""
:param params_generator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator
:param params_discriminator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator
:param loss_func_dist: distance loss function for the generator
:param weight_criterion_dist: weight of the distance loss when training the generator
:param weight_criterion_adv: weight of the adversarial loss when training the generator