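    """Conditional GAN training (cgan2.py).

    A UNet generator predicts μ-maps from reconstructed images, while a
    discriminator judges concatenated (reconstruction, μ-map) pairs; the
    __main__ block at the bottom provides the command line interface.
    """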
    from logging import Logger
    from typing import Optional
    
    import torch
    
    from mu_map.dataset.default import MuMapDataset
    from mu_map.training.lib import TrainingParams, AbstractTraining
    from mu_map.training.loss import WeightedLoss
    
    
    # Establish convention for real and fake labels during training
    LABEL_REAL = 1.0
    LABEL_FAKE = 0.0
    
    
    class DiscriminatorParams(TrainingParams):
        """
        Wrap training parameters to always carry the name 'Discriminator'.
        """
        def __init__(
            self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
        ):
            super().__init__(
                name="Discriminator",
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )
    
    class GeneratorParams(TrainingParams):
        """
        Wrap training parameters to always carry the name 'Generator'.
        """
        def __init__(
            self,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
        ):
            super().__init__(
                name="Generator",
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )
    
    
    
    class cGANTraining(AbstractTraining):
        """
        Implementation of conditional generative adversarial network (cGAN) training.
        """
        def __init__(
            self,
            epochs: int,
            dataset: MuMapDataset,
            batch_size: int,
            device: torch.device,
            snapshot_dir: str,
            snapshot_epoch: int,
            params_generator: GeneratorParams,
            params_discriminator: DiscriminatorParams,
            loss_func_dist: WeightedLoss,
            weight_criterion_dist: float,
            weight_criterion_adv: float,
            logger: Optional[Logger] = None,
        ):
    
            """
            :param params_generator: training parameters containing a model, a corresponding optimizer and optionally a learning rate scheduler for the generator
            :param params_discriminator: training parameters containing a model, a corresponding optimizer and optionally a learning rate scheduler for the discriminator
            :param loss_func_dist: distance loss function for the generator
            :param weight_criterion_dist: weight of the distance loss when training the generator
            :param weight_criterion_adv: weight of the adversarial loss when training the generator
            """
    
            super().__init__(
                epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
            )
            self.training_params.append(params_generator)
            self.training_params.append(params_discriminator)
    
            self.generator = params_generator.model
            self.discriminator = params_discriminator.model
    
            self.optim_g = params_generator.optimizer
            self.optim_d = params_discriminator.optimizer
    
            self.weight_criterion_dist = weight_criterion_dist
            self.weight_criterion_adv = weight_criterion_adv
    
            self.criterion_adv = torch.nn.MSELoss(reduction="mean")
            self.criterion_dist = loss_func_dist
    
        def _after_train_batch(self):
            """
            Overwrite calling step on all optimizers as this needs to be done
            separately for the generator and discriminator during the training of
            a batch.
            """
            pass
    
        def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
            mu_maps_real = mu_maps  # rename real mu maps for clarity
            # compute fake mu maps with generator
            mu_maps_fake = self.generator(recons)
    
            # note: the last batch may be smaller than the others, so self.batch_size is not reliable here
            batch_size = recons.shape[0]
            labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device)
            labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device)
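            # note: the (batch_size, 1) label shape assumes the discriminator returns a
            # single prediction per sample; a patch-based discriminator would require
            # labels matching its output shape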
    
            # prepare inputs for the discriminator
            inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
            inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
    
            # ======================= Discriminator =====================================
            # clear discriminator gradients; the generator step below also backpropagates
            # through the discriminator
            self.optim_d.zero_grad()

            # compute discriminator loss for fake mu maps
            # detach is called so that gradients are not computed for the generator
            outputs_d_fake = self.discriminator(inputs_d_fake.detach())
            loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake)
    
            # compute discriminator loss for real mu maps
            outputs_d_real = self.discriminator(inputs_d_real)
            loss_d_real = self.criterion_adv(outputs_d_real, labels_real)
    
            # update discriminator
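            # the 0.5 factor slows the discriminator's updates relative to the generator
            # (the same convention is used in pix2pix)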
            loss_d = 0.5 * (loss_d_fake + loss_d_real)
            loss_d.backward()  # compute gradients
            self.optim_d.step()
            # ===========================================================================
    
            # ======================= Generator =========================================
            self.optim_g.zero_grad()
            # no detach this time: gradients flow back into the generator
            outputs_d_fake = self.discriminator(inputs_d_fake)
            # the generator tries to get its fake mu maps classified as real
            loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real)
            loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real)
            loss_g = (
                self.weight_criterion_adv * loss_g_adv
                + self.weight_criterion_dist * loss_g_dist
            )
            loss_g.backward()
            self.optim_g.step()
            # ===========================================================================
    
            return loss_g.item()
    
        def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
            mu_maps_fake = self.generator(recons)
            loss = torch.nn.functional.l1_loss(mu_maps_fake, mu_maps)
            return loss.item()
    
    
    if __name__ == "__main__":
        import argparse
        import os
        import random
        import sys
    
        import numpy as np
    
        from mu_map.dataset.patches import MuMapPatchDataset
        from mu_map.dataset.normalization import (
            MeanNormTransform,
            MaxNormTransform,
            GaussianNormTransform,
        )
        from mu_map.dataset.transform import PadCropTranform, SequenceTransform
        from mu_map.logging import add_logging_args, get_logger_by_args
        from mu_map.models.unet import UNet
        from mu_map.models.discriminator import Discriminator, PatchDiscriminator
    
        parser = argparse.ArgumentParser(
            description="Train a UNet model to predict μ-maps from reconstructed images",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    
        # Model Args
        parser.add_argument(
            "--features",
            type=int,
            nargs="+",
            default=[64, 128, 256, 512],
            help="number of features in the layers of the UNet structure",
        )
    
        # Dataset Args
        parser.add_argument(
            "--dataset_dir",
            type=str,
            default="data/second/",
            help="the directory where the dataset for training is found",
        )
        parser.add_argument(
            "--input_norm",
            type=str,
            choices=["none", "mean", "max", "gaussian"],
            default="mean",
            help="type of normalization applied to the reconstructions",
        )
        parser.add_argument(
            "--patch_size",
            type=int,
            default=32,
            help="the size of patches extracted for each reconstruction",
        )
        parser.add_argument(
            "--patch_offset",
            type=int,
            default=20,
            help="offset to ignore the border of the image",
        )
        parser.add_argument(
            "--number_of_patches",
            type=int,
            default=100,
            help="number of patches extracted for each image",
        )
        parser.add_argument(
            "--no_shuffle",
            action="store_true",
            help="do not shuffle patches in the dataset",
        )
        parser.add_argument(
            "--scatter_correction",
            action="store_true",
            help="use the scatter corrected reconstructions in the dataset",
        )
    
        # Training Args
        parser.add_argument(
            "--seed",
            type=int,
            help="seed used for random number generation",
        )
        parser.add_argument(
            "--batch_size",
            type=int,
            default=64,
            help="the batch size used for training",
        )
        parser.add_argument(
            "--output_dir",
            type=str,
            default="train_data",
            help="directory in which results (snapshots and logs) of this training are saved",
        )
        parser.add_argument(
            "--epochs",
            type=int,
            default=100,
            help="the number of epochs for which the model is trained",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda:0" if torch.cuda.is_available() else "cpu",
            help="the device (cpu or gpu) with which the training is performed",
        )
        parser.add_argument(
            "--dist_loss_func",
            type=str,
            default="l1",
            help="define the loss function used as the distance loss of the generator , e.g. 0.75*l2+0.25*gdl",
        )
        parser.add_argument(
            "--dist_loss_weight",
            type=float,
            default=100.0,
            help="weight for the distance loss of the generator",
        )
        parser.add_argument(
            "--adv_loss_weight",
            type=float,
            default=1.0,
            help="weight for the Adversarial-Loss of the generator",
        )
        parser.add_argument(
            "--lr", type=float, default=0.001, help="the initial learning rate for training"
        )
        parser.add_argument(
            "--decay_lr",
            action="store_true",
            help="decay the learning rate",
        )
        parser.add_argument(
            "--lr_decay_factor",
            type=float,
            default=0.99,
            help="decay factor for the learning rate",
        )
        parser.add_argument(
            "--lr_decay_epoch",
            type=int,
            default=1,
            help="frequency in epochs at which the learning rate is decayed",
        )
        parser.add_argument(
            "--snapshot_dir",
            type=str,
            default="snapshots",
            help="directory under --output_dir where snapshots are stored",
        )
        parser.add_argument(
            "--snapshot_epoch",
            type=int,
            default=10,
            help="frequency in epochs at which snapshots are stored",
        )
    
        # Logging Args
        add_logging_args(parser, defaults={"--logfile": "train.log"})
    
        args = parser.parse_args()
    
        if not os.path.exists(args.output_dir):
            os.mkdir(args.output_dir)
    
        args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
        if not os.path.exists(args.snapshot_dir):
            os.mkdir(args.snapshot_dir)
        else:
            if len(os.listdir(args.snapshot_dir)) > 0:
                print(
                    f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
                )
                print(f"           Exit so that data is not accidentally overwritten!")
                exit(1)
    
        args.logfile = os.path.join(args.output_dir, args.logfile)
    
        logger = get_logger_by_args(args)
        logger.info(args)
    
        device = torch.device(args.device)
    
        args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
        logger.info(f"Seed: {args.seed}")
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
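        # note: seeding alone does not guarantee determinism on GPU; cuDNN would also
        # need torch.backends.cudnn.deterministic = True for fully reproducible runs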
    
        transform_normalization = None
        if args.input_norm == "mean":
            transform_normalization = MeanNormTransform()
        elif args.input_norm == "max":
            transform_normalization = MaxNormTransform()
        elif args.input_norm == "gaussian":
            transform_normalization = GaussianNormTransform()
        # chain normalization (if any) with pad/crop; with --input_norm none only pad/crop is applied
        transforms = [PadCropTranform(dim=3, size=32)]
        if transform_normalization is not None:
            transforms.insert(0, transform_normalization)
        transform_normalization = SequenceTransform(transforms)
    
        dataset = MuMapPatchDataset(
            args.dataset_dir,
            patches_per_image=args.number_of_patches,
            patch_size=args.patch_size,
            patch_offset=args.patch_offset,
            shuffle=not args.no_shuffle,
            transform_normalization=transform_normalization,
            scatter_correction=args.scatter_correction,
            logger=logger,
        )
    
        discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
        discriminator = discriminator.to(device)
        optimizer = torch.optim.Adam(
            discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999)
        )
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
            )
            if args.decay_lr
            else None
        )
        params_d = DiscriminatorParams(
            model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )
    
        generator = UNet(in_channels=1, features=args.features)
        generator = generator.to(device)
        optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
            )
            if args.decay_lr
            else None
        )
        params_g = GeneratorParams(
            model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )
    
        dist_criterion = WeightedLoss.from_str(args.dist_loss_func)
        logger.debug(f"Use distance criterion: {dist_criterion}")
    
        training = cGANTraining(
            epochs=args.epochs,
            dataset=dataset,
            batch_size=args.batch_size,
            device=device,
            snapshot_dir=args.snapshot_dir,
            snapshot_epoch=args.snapshot_epoch,
            params_generator=params_g,
            params_discriminator=params_d,
            loss_func_dist=dist_criterion,
            weight_criterion_dist=args.dist_loss_weight,
            weight_criterion_adv=args.adv_loss_weight,
            logger=logger,
        )
        training.run()