diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index 3aa83ec2d64990ffe46b7e5593402517648b1cff..f970f8841b8aeefed63f064e60e74ab6fe99c921 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -12,48 +12,48 @@ LABEL_REAL = 1.0 LABEL_FAKE = 0.0 -class GeneratorLoss(torch.nn.Module): - def __init__( - self, - # l2_weight: float = 1.0, - # gdl_weight: float = 1.0, - # adv_weight: float = 20.0, - # logger=None, - ): - super().__init__() - - # self.l2 = torch.nn.MSELoss(reduction="mean") - self.l2 = torch.nn.L1Loss(reduction="mean") - self.l2_weight = l2_weight - - self.gdl = GradientDifferenceLoss() - self.gdl_weight = gdl_weight - - self.adv = torch.nn.MSELoss(reduction="mean") - self.adv_weight = adv_weight - - if logger: - logger.debug(f"GeneratorLoss: {self}") - - def __repr__(self): - return f"{self.l2_weight:.3f} * MSELoss + {self.gdl_weight:.3f} * GDLLoss + {self.adv_weight:.3f} * AdversarialLoss" - - def forward( - self, - mu_maps_real: Tensor, - outputs_g: Tensor, - targets_d: Tensor, - outputs_d: Tensor, - ): - loss_l2 = self.l2(outputs_g, mu_maps_real) - loss_gdl = self.gdl(outputs_g, mu_maps_real) - loss_adv = self.adv(outputs_d, targets_d) - - return ( - self.l2_weight * loss_l2 - + self.gdl_weight * loss_gdl - + self.adv_weight * loss_adv - ) +# class GeneratorLoss(torch.nn.Module): + # def __init__( + # self, + # # l2_weight: float = 1.0, + # # gdl_weight: float = 1.0, + # # adv_weight: float = 20.0, + # # logger=None, + # ): + # super().__init__() + + # # self.l2 = torch.nn.MSELoss(reduction="mean") + # self.l2 = torch.nn.L1Loss(reduction="mean") + # self.l2_weight = l2_weight + + # self.gdl = GradientDifferenceLoss() + # self.gdl_weight = gdl_weight + + # self.adv = torch.nn.MSELoss(reduction="mean") + # self.adv_weight = adv_weight + + # if logger: + # logger.debug(f"GeneratorLoss: {self}") + + # def __repr__(self): + # return f"{self.l2_weight:.3f} * MSELoss + {self.gdl_weight:.3f} * GDLLoss + {self.adv_weight:.3f} * AdversarialLoss" + + # def forward( + # self, + # mu_maps_real: Tensor, + # outputs_g: Tensor, + # targets_d: Tensor, + # outputs_d: Tensor, + # ): + # loss_l2 = self.l2(outputs_g, mu_maps_real) + # loss_gdl = self.gdl(outputs_g, mu_maps_real) + # loss_adv = self.adv(outputs_d, targets_d) + + # return ( + # self.l2_weight * loss_l2 + # + self.gdl_weight * loss_gdl + # + self.adv_weight * loss_adv + # ) class cGANTraining: @@ -104,12 +104,12 @@ class cGANTraining: # ) self.criterion_d = torch.nn.MSELoss(reduction="mean") - self.criterion_g = GeneratorLoss( - l2_weight=l2_weight, - gdl_weight=gdl_weight, - adv_weight=adv_weight, - logger=self.logger, - ) + # self.criterion_g = GeneratorLoss( + # l2_weight=l2_weight, + # gdl_weight=gdl_weight, + # adv_weight=adv_weight, + # logger=self.logger, + # ) self.criterion_l1 = torch.nn.L1Loss(reduction="mean") def run(self): @@ -180,7 +180,7 @@ class cGANTraining: loss_d_fake = self.criterion_d(outputs_d_fake, labels_fake) # compute discriminator loss for real mu maps - inputs_d_real = torch.cat((recons, mu_maps), dim=1) + inputs_d_real = torch.cat((recons, mu_maps_real), dim=1) outputs_d_real = self.discriminator(inputs_d_real) # note the detach, so that gradients are not computed for the generator loss_d_real = self.criterion_d(outputs_d_real, labels_real)