from logging import Logger
from typing import Optional

import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.training.lib import TrainingParams, AbstractTraining
from mu_map.training.loss import WeightedLoss


# Establish convention for real and fake labels during training
LABEL_REAL = 1.0
LABEL_FAKE = 0.0


class DiscriminatorParams(TrainingParams):
    """
    Wrap training parameters to always carry the name 'Discriminator'.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    ):
        super().__init__(
            name="Discriminator",
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )


class GeneratorParams(TrainingParams):
    """
    Wrap training parameters to always carry the name 'Generator'.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    ):
        super().__init__(
            name="Generator",
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )


class cGANTraining(AbstractTraining):
    """
    Implementation of a conditional generative adversarial network training.
    """

    def __init__(
        self,
        epochs: int,
        dataset: MuMapDataset,
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        params_generator: GeneratorParams,
        params_discriminator: DiscriminatorParams,
        loss_func_dist: WeightedLoss,
        weight_criterion_dist: float,
        weight_criterion_adv: float,
        logger: Optional[Logger] = None,
    ):
        super().__init__(
            epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
        )
        self.training_params.append(params_generator)
        self.training_params.append(params_discriminator)

        self.generator = params_generator.model
        self.discriminator = params_discriminator.model

        self.optim_g = params_generator.optimizer
        self.optim_d = params_discriminator.optimizer

        self.weight_criterion_dist = weight_criterion_dist
        self.weight_criterion_adv = weight_criterion_adv

        self.criterion_adv = torch.nn.MSELoss(reduction="mean")
        self.criterion_dist = loss_func_dist

    def _after_train_batch(self):
        """
        Overwrite calling step on all optimizers as this needs to be done
        separately for the generator and discriminator during the training
        of a batch.
""" pass def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float: mu_maps_real = mu_maps # rename real mu maps for clarification # compute fake mu maps with generator mu_maps_fake = self.generator(recons) # note: the batch size may differ for the last batch which is why self.batch_size is not reliable batch_size = recons.shape[0] labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device) labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device) # prepare inputs for the discriminator inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1) inputs_d_real = torch.cat((recons, mu_maps_real), dim=1) # ======================= Discriminator ===================================== # compute discriminator loss for fake mu maps # detach is called so that gradients are not computed for the generator outputs_d_fake = self.discriminator(inputs_d_fake.detach()) loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake) # compute discriminator loss for real mu maps outputs_d_real = self.discriminator(inputs_d_real) loss_d_real = self.criterion_adv(outputs_d_real, labels_real) # update discriminator loss_d = 0.5 * (loss_d_fake + loss_d_real) loss_d.backward() # compute gradients self.optim_d.step() # =========================================================================== # ======================= Generator ========================================= outputs_d_fake = self.discriminator(inputs_d_fake) # this time no detach loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real) loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real) loss_g = ( self.weight_criterion_adv * loss_g_adv + self.weight_criterion_dist * loss_g_dist ) loss_g.backward() self.optim_g.step() # =========================================================================== return loss_g.item() def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float: mu_maps_fake = self.generator(recons) loss = torch.nn.functional.l1_loss(mu_maps_fake, mu_maps) return loss.item() if __name__ == "__main__": import argparse import os import random import sys import numpy as np from mu_map.dataset.patches import MuMapPatchDataset from mu_map.dataset.normalization import ( MeanNormTransform, MaxNormTransform, GaussianNormTransform, ) from mu_map.dataset.transform import PadCropTranform, SequenceTransform from mu_map.logging import add_logging_args, get_logger_by_args from mu_map.models.unet import UNet from mu_map.models.discriminator import Discriminator, PatchDiscriminator parser = argparse.ArgumentParser( description="Train a UNet model to predict μ-maps from reconstructed images", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) # Model Args parser.add_argument( "--features", type=int, nargs="+", default=[64, 128, 256, 512], help="number of features in the layers of the UNet structure", ) # Dataset Args parser.add_argument( "--dataset_dir", type=str, default="data/second/", help="the directory where the dataset for training is found", ) parser.add_argument( "--input_norm", type=str, choices=["none", "mean", "max", "gaussian"], default="mean", help="type of normalization applied to the reconstructions", ) parser.add_argument( "--patch_size", type=int, default=32, help="the size of patches extracted for each reconstruction", ) parser.add_argument( "--patch_offset", type=int, default=20, help="offset to ignore the border of the image", ) parser.add_argument( "--number_of_patches", type=int, default=100, help="number of patches extracted for each image", ) 
parser.add_argument( "--no_shuffle", action="store_true", help="do not shuffle patches in the dataset", ) parser.add_argument( "--scatter_correction", action="store_true", help="use the scatter corrected reconstructions in the dataset", ) # Training Args parser.add_argument( "--seed", type=int, help="seed used for random number generation", ) parser.add_argument( "--batch_size", type=int, default=64, help="the batch size used for training", ) parser.add_argument( "--output_dir", type=str, default="train_data", help="directory in which results (snapshots and logs) of this training are saved", ) parser.add_argument( "--epochs", type=int, default=100, help="the number of epochs for which the model is trained", ) parser.add_argument( "--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="the device (cpu or gpu) with which the training is performed", ) parser.add_argument( "--dist_loss_func", type=str, default="l1", help="define the loss function used as the distance loss of the generator , e.g. 0.75*l2+0.25*gdl", ) parser.add_argument( "--dist_loss_weight", type=float, default=100.0, help="weight for the distance loss of the generator", ) parser.add_argument( "--adv_loss_weight", type=float, default=1.0, help="weight for the Adversarial-Loss of the generator", ) parser.add_argument( "--lr", type=float, default=0.001, help="the initial learning rate for training" ) parser.add_argument( "--decay_lr", action="store_true", help="decay the learning rate", ) parser.add_argument( "--lr_decay_factor", type=float, default=0.99, help="decay factor for the learning rate", ) parser.add_argument( "--lr_decay_epoch", type=int, default=1, help="frequency in epochs at which the learning rate is decayed", ) parser.add_argument( "--snapshot_dir", type=str, default="snapshots", help="directory under --output_dir where snapshots are stored", ) parser.add_argument( "--snapshot_epoch", type=int, default=10, help="frequency in epochs at which snapshots are stored", ) # Logging Args add_logging_args(parser, defaults={"--logfile": "train.log"}) args = parser.parse_args() if not os.path.exists(args.output_dir): os.mkdir(args.output_dir) args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir) if not os.path.exists(args.snapshot_dir): os.mkdir(args.snapshot_dir) else: if len(os.listdir(args.snapshot_dir)) > 0: print( f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!" 
) print(f" Exit so that data is not accidentally overwritten!") exit(1) args.logfile = os.path.join(args.output_dir, args.logfile) logger = get_logger_by_args(args) logger.info(args) device = torch.device(args.device) args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {args.seed}") random.seed(args.seed) torch.manual_seed(args.seed) np.random.seed(args.seed) transform_normalization = None if args.input_norm == "mean": transform_normalization = MeanNormTransform() elif args.input_norm == "max": transform_normalization = MaxNormTransform() elif args.input_norm == "gaussian": transform_normalization = GaussianNormTransform() transform_normalization = SequenceTransform( [transform_normalization, PadCropTranform(dim=3, size=32)] ) dataset = MuMapPatchDataset( args.dataset_dir, patches_per_image=args.number_of_patches, patch_size=args.patch_size, patch_offset=args.patch_offset, shuffle=not args.no_shuffle, transform_normalization=transform_normalization, scatter_correction=args.scatter_correction, logger=logger, ) discriminator = Discriminator(in_channels=2, input_size=args.patch_size) discriminator = discriminator.to(device) optimizer = torch.optim.Adam( discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999) ) lr_scheduler = ( torch.optim.lr_scheduler.StepLR( optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor ) if args.decay_lr else None ) params_d = DiscriminatorParams( model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler ) generator = UNet(in_channels=1, features=args.features) generator = generator.to(device) optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999)) lr_scheduler = ( torch.optim.lr_scheduler.StepLR( optimizer, step_size=args.lr_decay_factor, gamma=args.lr_decay_factor ) if args.decay_lr else None ) params_g = GeneratorParams( model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler ) dist_criterion = WeightedLoss.from_str(args.dist_loss_func) logger.debug(f"Use distance criterion: {dist_criterion}") training = cGANTraining( epochs=args.epochs, dataset=dataset, batch_size=args.batch_size, device=device, snapshot_dir=args.snapshot_dir, snapshot_epoch=args.snapshot_epoch, params_generator=params_g, params_discriminator=params_d, loss_func_dist=dist_criterion, weight_criterion_dist=args.dist_loss_weight, weight_criterion_adv=args.adv_loss_weight, logger=logger, ) training.run()