Commit b08f5d08 authored by Tamino Huxohl

add early stopping param to cgan training and update doc

parent 77851a54
@@ -17,6 +17,7 @@ class DiscriminatorParams(TrainingParams):
"""
Wrap training parameters to always carry the name 'Discriminator'.
"""
def __init__(
self,
model: torch.nn.Module,
@@ -30,10 +31,12 @@ class DiscriminatorParams(TrainingParams):
lr_scheduler=lr_scheduler,
)
class GeneratorParams(TrainingParams):
"""
Wrap training parameters to always carry the name 'Generator'.
"""
def __init__(
self,
model: torch.nn.Module,
@@ -48,11 +51,26 @@ class GeneratorParams(TrainingParams):
)
class cGANTraining(AbstractTraining):
"""
Implementation of a conditional generative adversarial network training.
To see all parameters, have a look at AbstractTraining.
Parameters
----------
params_generator: GeneratorParams
training parameters containing a model, a corresponding optimizer, and optionally a learning rate scheduler for the generator
params_discriminator: DiscriminatorParams
training parameters containing a model, a corresponding optimizer, and optionally a learning rate scheduler for the discriminator
loss_func_dist: WeightedLoss
distance loss function for the generator
weight_criterion_dist: float
weight of the distance loss when training the generator
weight_criterion_adv: float
weight of the adversarial loss when training the generator
"""
def __init__(
self,
epochs: int,
@@ -66,17 +84,18 @@ class cGANTraining(AbstractTraining):
loss_func_dist: WeightedLoss,
weight_criterion_dist: float,
weight_criterion_adv: float,
early_stopping: Optional[int],
logger: Optional[Logger] = None,
):
"""
:param params_generator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator
:param params_discriminator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator
:param loss_func_dist: distance loss function for the generator
:param weight_criterion_dist: weight of the distance loss when training the generator
:param weight_criterion_adv: weight of the adversarial loss when training the generator
"""
super().__init__(
epochs=epochs,
dataset=dataset,
batch_size=batch_size,
device=device,
early_stopping=early_stopping,
snapshot_dir=snapshot_dir,
snapshot_epoch=snapshot_epoch,
logger=logger,
)
self.training_params.append(params_generator)
self.training_params.append(params_discriminator)
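
The new early_stopping argument is only forwarded to AbstractTraining here; its implementation is not part of this diff. As a hedged sketch of the usual patience-based semantics (the class and attribute names below are hypothetical, not taken from the repository):

from typing import Optional

class EarlyStopping:
    """Minimal patience-based early stopping (hypothetical helper, not the repo's code)."""

    def __init__(self, patience: Optional[int]):
        self.patience = patience  # epochs to wait for an improvement; None disables
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0

    def step(self, validation_loss: float) -> bool:
        """Record an epoch's validation loss; return True if training should stop."""
        if self.patience is None:
            return False
        if validation_loss < self.best_loss:
            self.best_loss = validation_loss
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience

Under this reading, early_stopping=10 stops training once ten consecutive epochs pass without a new best validation loss.
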
@@ -102,7 +121,7 @@ class cGANTraining(AbstractTraining):
pass
def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
mu_maps_real = mu_maps  # rename real mu maps for clarification
# compute fake mu maps with generator
mu_maps_fake = self.generator(recons)
@@ -111,8 +130,16 @@ class cGANTraining(AbstractTraining):
inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
# prepare labels/targets for the discriminator
labels_fake = torch.full(
self.discriminator.get_output_shape(inputs_d_fake.shape),
LABEL_FAKE,
device=self.device,
)
labels_real = torch.full(
self.discriminator.get_output_shape(inputs_d_real.shape),
LABEL_REAL,
device=self.device,
)
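
get_output_shape itself is not shown in this diff. The labels need it because a patch-based discriminator typically emits a score map rather than a single scalar, so the label tensors must match that map's shape. A hypothetical stand-in (the repository's actual method may differ):

import torch

def get_output_shape(discriminator: torch.nn.Module, input_shape: torch.Size) -> torch.Size:
    # Run a dummy forward pass to discover the discriminator's output shape;
    # purely illustrative, not the repository's implementation.
    with torch.no_grad():
        return discriminator(torch.zeros(*input_shape)).shape
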
# ======================= Discriminator =====================================
# compute discriminator loss for fake mu maps
@@ -131,7 +158,7 @@ class cGANTraining(AbstractTraining):
# ===========================================================================
# ======================= Generator =========================================
outputs_d_fake = self.discriminator(inputs_d_fake)  # this time no detach
loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real)
loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real)
loss_g = (
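
The hunk ends before the two generator terms are combined. Going by the class docstring, the full expression is presumably the weighted sum of the adversarial and distance losses; a standalone illustration (the weights shown are pix2pix-style defaults, not values from this repository):

import torch

weight_criterion_adv, weight_criterion_dist = 1.0, 100.0  # assumed pix2pix-style weighting
loss_g_adv = torch.tensor(0.7)    # discriminator should classify fakes as real
loss_g_dist = torch.tensor(0.05)  # fakes should stay close to the real mu maps
loss_g = weight_criterion_adv * loss_g_adv + weight_criterion_dist * loss_g_dist
print(loss_g)  # tensor(5.7000)
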
@@ -256,6 +283,11 @@ if __name__ == "__main__":
default="cuda:0" if torch.cuda.is_available() else "cpu",
help="the device (cpu or gpu) with which the training is performed",
)
parser.add_argument(
"--early_stopping",
type=int,
help="define early stopping as the least amount of epochs in which the validation loss must improve",
)
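
A hedged usage example (train_cgan.py is a placeholder; the actual entry point is not visible in this diff):

python train_cgan.py --early_stopping 10 --device cuda:0
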
parser.add_argument(
"--dist_loss_func",
type=str,
@@ -400,6 +432,7 @@ if __name__ == "__main__":
dataset=dataset,
batch_size=args.batch_size,
device=device,
early_stopping=args.early_stopping,
snapshot_dir=args.snapshot_dir,
snapshot_epoch=args.snapshot_epoch,
params_generator=params_g,