:param params_generator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator
:param params_discriminator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator
:param loss_func_dist: distance loss function for the generator
:param weight_criterion_dist: weight of the distance loss when training the generator
:param weight_criterion_adv: weight of the adversarial loss when training the generator
Implementation of a conditional generative adversarial network training.
"""
def__init__(
self,
epochs:int,
dataset:MuMapDataset,
batch_size:int,
device:torch.device,
snapshot_dir:str,
snapshot_epoch:int,
params_generator:GeneratorParams,
params_discriminator:DiscriminatorParams,
loss_func_dist:WeightedLoss,
weight_criterion_dist:float,
weight_criterion_adv:float,
logger:Optional[Logger]=None,
):
"""
:param params_generator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator
:param params_discriminator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator
:param loss_func_dist: distance loss function for the generator
:param weight_criterion_dist: weight of the distance loss when training the generator
:param weight_criterion_adv: weight of the adversarial loss when training the generator