import os from typing import Dict import torch from torch import Tensor from mu_map.training.loss import GradientDifferenceLoss from mu_map.logging import get_logger # Establish convention for real and fake labels during training LABEL_REAL = 1.0 LABEL_FAKE = 0.0 # class GeneratorLoss(torch.nn.Module): # def __init__( # self, # # l2_weight: float = 1.0, # # gdl_weight: float = 1.0, # # adv_weight: float = 20.0, # # logger=None, # ): # super().__init__() # # self.l2 = torch.nn.MSELoss(reduction="mean") # self.l2 = torch.nn.L1Loss(reduction="mean") # self.l2_weight = l2_weight # self.gdl = GradientDifferenceLoss() # self.gdl_weight = gdl_weight # self.adv = torch.nn.MSELoss(reduction="mean") # self.adv_weight = adv_weight # if logger: # logger.debug(f"GeneratorLoss: {self}") # def __repr__(self): # return f"{self.l2_weight:.3f} * MSELoss + {self.gdl_weight:.3f} * GDLLoss + {self.adv_weight:.3f} * AdversarialLoss" # def forward( # self, # mu_maps_real: Tensor, # outputs_g: Tensor, # targets_d: Tensor, # outputs_d: Tensor, # ): # loss_l2 = self.l2(outputs_g, mu_maps_real) # loss_gdl = self.gdl(outputs_g, mu_maps_real) # loss_adv = self.adv(outputs_d, targets_d) # return ( # self.l2_weight * loss_l2 # + self.gdl_weight * loss_gdl # + self.adv_weight * loss_adv # ) class cGANTraining: def __init__( self, generator: torch.nn.Module, discriminator: torch.nn.Module, data_loaders: Dict[str, torch.utils.data.DataLoader], epochs: int, device: torch.device, lr_d: float, lr_decay_factor_d: float, lr_decay_epoch_d: int, lr_g: float, lr_decay_factor_g: float, lr_decay_epoch_g: int, l2_weight: float, gdl_weight: float, adv_weight: float, snapshot_dir: str, snapshot_epoch: int, logger=None, ): self.generator = generator self.discriminator = discriminator self.data_loaders = data_loaders self.epochs = epochs self.device = device self.snapshot_dir = snapshot_dir self.snapshot_epoch = snapshot_epoch self.logger = logger if logger is not None else get_logger() self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999)) self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g, betas=(0.5, 0.999)) # self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR( # self.optimizer_d, # step_size=lr_decay_epoch_d, # gamma=lr_decay_factor_d, # ) # self.lr_scheduler_g = torch.optim.lr_scheduler.StepLR( # self.optimizer_g, # step_size=lr_decay_epoch_g, # gamma=lr_decay_factor_g, # ) self.criterion_d = torch.nn.MSELoss(reduction="mean") # self.criterion_g = GeneratorLoss( # l2_weight=l2_weight, # gdl_weight=gdl_weight, # adv_weight=adv_weight, # logger=self.logger, # ) self.criterion_l1 = torch.nn.L1Loss(reduction="mean") def run(self): losses_d = [] losses_g = [] for epoch in range(1, self.epochs + 1): logger.debug( f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..." ) _losses_d, _losses_g = self._train_epoch() losses_d.extend(_losses_d) losses_g.extend(_losses_g) self._eval_epoch(epoch, "train") self._eval_epoch(epoch, "validation") # self.lr_scheduler_d.step() # self.lr_scheduler_g.step() if epoch % self.snapshot_epoch == 0: self.store_snapshot(epoch) logger.debug( f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}" ) return losses_d, losses_g def _train_epoch(self): logger.debug(f"Train epoch") torch.set_grad_enabled(True) self.discriminator = self.discriminator.train() self.generator = self.generator.train() losses_d = [] losses_g = [] data_loader = self.data_loaders["train"] for i, (recons, mu_maps) in enumerate(data_loader): print( f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r", ) recons = recons.to(self.device) mu_maps = mu_maps.to(self.device) loss_d_real, loss_d_fake, loss_g = self._step(recons, mu_maps) losses_d.append(loss_d_real + loss_d_fake) losses_g.append(loss_g) return losses_d, losses_g def _step(self, recons, mu_maps_real): batch_size = recons.shape[0] with torch.set_grad_enabled(True): self.optimizer_d.zero_grad() self.optimizer_g.zero_grad() # compute fake mu maps with generator mu_maps_fake = self.generator(recons) # compute discriminator loss for fake mu maps inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1) outputs_d_fake = self.discriminator(inputs_d_fake.detach()) # note the detach, so that gradients are not computed for the generator labels_fake = torch.full((outputs_d_fake.shape), LABEL_FAKE, device=self.device) loss_d_fake = self.criterion_d(outputs_d_fake, labels_fake) # compute discriminator loss for real mu maps inputs_d_real = torch.cat((recons, mu_maps_real), dim=1) outputs_d_real = self.discriminator(inputs_d_real) # note the detach, so that gradients are not computed for the generator labels_real = torch.full((outputs_d_fake.shape), LABEL_REAL, device=self.device) loss_d_real = self.criterion_d(outputs_d_real, labels_real) # update discriminator loss_d = 0.5 * (loss_d_fake + loss_d_real) loss_d.backward() # compute gradients self.optimizer_d.step() # update generator inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1) outputs_d_fake = self.discriminator(inputs_d_fake) loss_g_adv = self.criterion_d(outputs_d_fake, labels_real) loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real) loss_g = loss_g_adv + 100.0 * loss_g_l1 loss_g.backward() self.optimizer_g.step() return loss_d_real.item(), loss_d_fake.item(), loss_g.item() def _eval_epoch(self, epoch, split_name): logger.debug(f"Evaluate epoch on split {split_name}") torch.set_grad_enabled(False) self.discriminator = self.discriminator.eval() self.generator = self.generator.eval() loss = 0.0 updates = 0 data_loader = self.data_loaders[split_name] for i, (recons, mu_maps) in enumerate(data_loader): print( f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r", ) recons = recons.to(self.device) mu_maps = mu_maps.to(self.device) outputs = self.generator(recons) loss += torch.nn.functional.l1_loss(outputs, mu_maps) updates += 1 loss = loss / updates logger.info( f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss {split_name}: {loss:.6f}" ) def store_snapshot(self, epoch): snapshot_file_d = f"{epoch:0{len(str(self.epochs))}d}_discriminator.pth" snapshot_file_d = os.path.join(self.snapshot_dir, snapshot_file_d) snapshot_file_g = f"{epoch:0{len(str(self.epochs))}d}_generator.pth" snapshot_file_g = os.path.join(self.snapshot_dir, snapshot_file_g) logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}") torch.save(self.discriminator.state_dict(), snapshot_file_d) torch.save(self.generator.state_dict(), snapshot_file_g) if __name__ == "__main__": import argparse import random import sys import numpy as np from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.normalization import ( MeanNormTransform, MaxNormTransform, GaussianNormTransform, ) from mu_map.dataset.transform import ScaleTransform from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.models.unet import UNet from mu_map.models.discriminator import Discriminator, PatchDiscriminator parser = argparse.ArgumentParser( description="Train a UNet model to predict μ-maps from reconstructed scatter images", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) # Model Args parser.add_argument( "--features", type=int, nargs="+", default=[64, 128, 256, 512], help="number of features in the layers of the UNet structure", ) # Dataset Args parser.add_argument( "--dataset_dir", type=str, default="data/initial/", help="the directory where the dataset for training is found", ) parser.add_argument( "--input_norm", type=str, choices=["none", "mean", "max", "gaussian"], default="mean", help="type of normalization applied to the reconstructions", ) parser.add_argument( "--patch_size", type=int, default=32, help="the size of patches extracted for each reconstruction", ) parser.add_argument( "--patch_offset", type=int, default=20, help="offset to ignore the border of the image", ) parser.add_argument( "--number_of_patches", type=int, default=100, help="number of patches extracted for each image", ) parser.add_argument( "--no_shuffle", action="store_true", help="do not shuffle patches in the dataset", ) # Training Args parser.add_argument( "--seed", type=int, help="seed used for random number generation", ) parser.add_argument( "--batch_size", type=int, default=64, help="the batch size used for training", ) parser.add_argument( "--output_dir", type=str, default="train_data", help="directory in which results (snapshots and logs) of this training are saved", ) parser.add_argument( "--epochs", type=int, default=100, help="the number of epochs for which the model is trained", ) parser.add_argument( "--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="the device (cpu or gpu) with which the training is performed", ) parser.add_argument( "--mse_loss_weight", type=float, default=1.0, help="weight for the L2-Loss of the generator", ) parser.add_argument( "--gdl_loss_weight", type=float, default=1.0, help="weight for the Gradient-Difference-Loss of the generator", ) parser.add_argument( "--adv_loss_weight", type=float, default=20.0, help="weight for the Adversarial-Loss of the generator", ) parser.add_argument( "--lr", type=float, default=0.001, help="the initial learning rate for training" ) parser.add_argument( "--lr_decay_factor", type=float, default=0.99, help="decay factor for the learning rate", ) parser.add_argument( "--lr_decay_epoch", type=int, default=1, help="frequency in epochs at which the learning rate is decayed", ) parser.add_argument( "--snapshot_dir", type=str, default="snapshots", help="directory under --output_dir where snapshots are stored", ) parser.add_argument( "--snapshot_epoch", type=int, default=10, help="frequency in epochs at which snapshots are stored", ) parser.add_argument( "--generator_weights", type=str, help="use pre-trained weights for the generator", ) # Logging Args add_logging_args(parser, defaults={"--logfile": "train.log"}) args = parser.parse_args() if not os.path.exists(args.output_dir): os.mkdir(args.output_dir) args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir) if not os.path.exists(args.snapshot_dir): os.mkdir(args.snapshot_dir) else: if len(os.listdir(args.snapshot_dir)) > 0: print( f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!" ) print(f" Exit so that data is not accidentally overwritten!") exit(1) args.logfile = os.path.join(args.output_dir, args.logfile) device = torch.device(args.device) logger = get_logger_by_args(args) logger.info(args) args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {args.seed}") random.seed(args.seed) torch.manual_seed(args.seed) np.random.seed(args.seed) # discriminator = Discriminator(in_channels=2, input_size=args.patch_size) discriminator = PatchDiscriminator(in_channels=2, input_size=args.patch_size) discriminator = discriminator.to(device) generator = UNet(in_channels=1, features=args.features) generator = generator.to(device) if args.generator_weights: logger.debug(f"Load generator weights from {args.generator_weights}") generator.load_state_dict(torch.load(args.generator_weights, map_location=device)) transform_normalization = None if args.input_norm == "mean": transform_normalization = MeanNormTransform() elif args.input_norm == "max": transform_normalization = MaxNormTransform() elif args.input_norm == "gaussian": transform_normalization = GaussianNormTransform() data_loaders = {} for split in ["train", "validation"]: dataset = MuMapPatchDataset( args.dataset_dir, patches_per_image=args.number_of_patches, patch_size=args.patch_size, patch_offset=args.patch_offset, shuffle=not args.no_shuffle, split_name=split, transform_normalization=transform_normalization, logger=logger, ) data_loader = torch.utils.data.DataLoader( dataset=dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=1, ) data_loaders[split] = data_loader training = cGANTraining( discriminator=discriminator, generator=generator, data_loaders=data_loaders, epochs=args.epochs, device=device, lr_d=0.0002, lr_decay_factor_d=0.99, lr_decay_epoch_d=1, lr_g=0.0002, lr_decay_factor_g=0.99, lr_decay_epoch_g=1, l2_weight=args.mse_loss_weight, gdl_weight=args.gdl_loss_weight, adv_weight=args.adv_loss_weight, snapshot_dir=args.snapshot_dir, snapshot_epoch=args.snapshot_epoch, logger=logger, ) losses_d, losses_g = training.run() import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2, figsize=(10, 5)) axs[0].plot(losses_d) axs[0].set_title("Discriminator Loss") axs[0].set_xlabel("Iteration") axs[0].set_ylabel("Loss") axs[1].plot(losses_g, label="Generator") axs[1].set_title("Generator Loss") axs[1].set_xlabel("Iteration") axs[1].set_ylabel("Loss") plt.savefig("losses.png")