Skip to content
Snippets Groups Projects
Commit afdee20b authored by Tamino Huxohl's avatar Tamino Huxohl
Browse files

make loss weights in cgan loss configurable

parent e0e94b14
No related branches found
No related tags found
No related merge requests found
......@@ -14,7 +14,11 @@ LABEL_FAKE = 0.0
class GeneratorLoss(torch.nn.Module):
def __init__(
self, l2_weight: float = 1.0, gdl_weight: float = 1.0, adv_weight: float = 20.0
self,
l2_weight: float = 1.0,
gdl_weight: float = 1.0,
adv_weight: float = 20.0,
logger=None,
):
super().__init__()
......@@ -27,6 +31,12 @@ class GeneratorLoss(torch.nn.Module):
self.adv = torch.nn.MSELoss(reduction="mean")
self.adv_weight = adv_weight
if logger:
logger.debug(f"GeneratorLoss: {self}")
def __repr__(self):
    """Return a human-readable summary of the combined generator loss.

    Shows each weighted term (MSE, gradient-difference, adversarial) with its
    weight formatted to three decimals, e.g.
    ``1.000 * MSELoss + 1.000 * GDLLoss + 20.000 * AdversarialLoss``.
    Used for logging the configured weights at construction time.
    """
    return f"{self.l2_weight:.3f} * MSELoss + {self.gdl_weight:.3f} * GDLLoss + {self.adv_weight:.3f} * AdversarialLoss"
def forward(
self,
mu_maps_real: Tensor,
......@@ -94,7 +104,10 @@ class cGANTraining:
self.criterion_d = torch.nn.MSELoss(reduction="mean")
self.criterion_g = GeneratorLoss(
l2_weight=l2_weight, gdl_weight=gdl_weight, adv_weight=adv_weight
l2_weight=l2_weight,
gdl_weight=gdl_weight,
adv_weight=adv_weight,
logger=self.logger,
)
def run(self):
......@@ -174,7 +187,9 @@ class cGANTraining:
# update generator
outputs_d = self.discriminator(mu_maps_fake)
loss_g = self.criterion_g(mu_maps_real, mu_maps_fake, labels_real, outputs_d)
loss_g = self.criterion_g(
mu_maps_real, mu_maps_fake, labels_real, outputs_d
)
loss_g.backward()
self.optimizer_g.step()
......@@ -189,7 +204,7 @@ class cGANTraining:
loss = 0.0
updates = 0
data_loader = self.data_loaders[split_name]
for i, (recons, mu_maps) in enumerate(data_loader):
print(
......@@ -319,6 +334,24 @@ if __name__ == "__main__":
default="cuda:0" if torch.cuda.is_available() else "cpu",
help="the device (cpu or gpu) with which the training is performed",
)
# CLI flags exposing the three GeneratorLoss term weights; defaults mirror the
# GeneratorLoss constructor defaults (1.0 / 1.0 / 20.0).
parser.add_argument(
    "--mse_loss_weight",
    type=float,
    default=1.0,
    help="weight for the L2-Loss of the generator",
)
parser.add_argument(
    "--gdl_loss_weight",
    type=float,
    default=1.0,
    help="weight for the Gradient-Difference-Loss of the generator",
)
parser.add_argument(
    "--adv_loss_weight",
    type=float,
    default=20.0,
    help="weight for the Adversarial-Loss of the generator",
)
parser.add_argument(
"--lr", type=float, default=0.001, help="the initial learning rate for training"
)
......@@ -372,7 +405,7 @@ if __name__ == "__main__":
logger = get_logger_by_args(args)
logger.info(args)
args.seed = args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1)
args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {args.seed}")
random.seed(args.seed)
torch.manual_seed(args.seed)
......@@ -425,9 +458,9 @@ if __name__ == "__main__":
lr_g=0.001,
lr_decay_factor_g=0.99,
lr_decay_epoch_g=1,
l2_weight=0.25,
gdl_weight=0.25,
adv_weight=0.5,
l2_weight=args.mse_loss_weight,
gdl_weight=args.gdl_loss_weight,
adv_weight=args.adv_loss_weight,
snapshot_dir=args.snapshot_dir,
snapshot_epoch=args.snapshot_epoch,
logger=logger,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment