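    """
    Training script for a conditional GAN (cGAN) that predicts attenuation maps
    (μ-maps) from reconstructions. The generator is a UNet and the discriminator
    rates (reconstruction, μ-map) pairs; the generator is trained with a weighted
    sum of an adversarial loss and a distance loss (WeightedLoss).
    """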
    from dataclasses import dataclass
    
    import os
    
    from typing import Dict, Optional
    import sys
    
    
    import torch
    from torch import Tensor
    
    
    from mu_map.training.loss import WeightedLoss
    
    from mu_map.logging import get_logger
    
    # Establish convention for real and fake labels during training
    LABEL_REAL = 1.0
    LABEL_FAKE = 0.0
    
    
    @dataclass
    class TrainingParams:
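        """Bundle of a model together with its optimizer and optional LR scheduler."""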
        model: torch.nn.Module
        optimizer: torch.optim.Optimizer
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
    
    
    
    class cGANTraining:
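        """
        Training routine for the cGAN: in every epoch the discriminator and the
        generator are updated alternately per batch, the generator is evaluated on
        the train and validation splits, and snapshots are stored at a fixed epoch
        interval as well as whenever the validation loss reaches a new minimum.
        """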
        def __init__(
            self,
            data_loaders: Dict[str, torch.utils.data.DataLoader],
            epochs: int,
            device: torch.device,
            snapshot_dir: str,
            snapshot_epoch: int,
    
            params_generator: TrainingParams,
            params_discriminator: TrainingParams,
            loss_func_dist: WeightedLoss,
            weight_criterion_dist: float,
            weight_criterion_adv: float,
    
            logger=None,
        ):
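            """
            :param data_loaders: data loaders for the "train" and "validation" splits
            :param epochs: number of epochs to train
            :param device: device on which training is performed
            :param snapshot_dir: directory in which snapshots are stored
            :param snapshot_epoch: frequency in epochs at which snapshots are stored
            :param params_generator: model, optimizer and LR scheduler of the generator
            :param params_discriminator: model, optimizer and LR scheduler of the discriminator
            :param loss_func_dist: distance loss used for the generator
            :param weight_criterion_dist: weight of the distance loss in the generator loss
            :param weight_criterion_adv: weight of the adversarial loss in the generator loss
            :param logger: optional logger; a default one is created if omitted
            """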
            self.data_loaders = data_loaders
            self.epochs = epochs
            self.device = device
    
            self.snapshot_dir = snapshot_dir
            self.snapshot_epoch = snapshot_epoch
            self.logger = logger if logger is not None else get_logger()
    
    
            self.params_g = params_generator
            self.params_d = params_discriminator
    
            self.weight_criterion_dist = weight_criterion_dist
            self.weight_criterion_adv = weight_criterion_adv
    
            self.criterion_adv = torch.nn.MSELoss(reduction="mean")
    
            self.criterion_dist = loss_func_dist
    
    
        def run(self):
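            """
            Run the full training and return the minimal validation loss (average L1
            distance between generated and real μ-maps) observed over all epochs.
            """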
    
            loss_val_min = sys.maxsize
    
            for epoch in range(1, self.epochs + 1):
    
                str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
    
                self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
    
                self._train_epoch()
                loss_train = self._eval_epoch("train")
    
                self.logger.info(
    
                    f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
                )
                loss_val = self._eval_epoch("validation")
    
                self.logger.info(
    
                    f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
                )
    
                if loss_val < loss_val_min:
                    loss_val_min = loss_val
    
                    self.logger.info(
    
                        f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
    
                    )
                    self.store_snapshot("val_min")
    
                if epoch % self.snapshot_epoch == 0:
    
                    self._store_snapshot(epoch)
    
                if self.params_d.lr_scheduler is not None:
    
                    self.logger.debug("Step LR scheduler of discriminator")
    
                    self.params_d.lr_scheduler.step()
                if self.params_g.lr_scheduler is not None:
    
                    self.logger.debug("Step LR scheduler of generator")
    
                    self.params_g.lr_scheduler.step()
    
            return loss_val_min
    
    
        def _train_epoch(self):
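            """
            Train the discriminator and the generator for a single epoch. For each
            batch, the discriminator is first updated on real and generated (detached)
            μ-maps, then the generator is updated with the weighted sum of adversarial
            and distance loss.
            """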
    
            # setup training mode
            torch.set_grad_enabled(True)
            self.params_d.model.train()
            self.params_g.model.train()
    
    
            data_loader = self.data_loaders["train"]
    
            for i, (recons, mu_maps_real) in enumerate(data_loader):
    
                print(
                    f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                    end="\r",
                )
    
                batch_size = recons.shape[0]
    
    
                recons = recons.to(self.device)
    
                mu_maps_real = mu_maps_real.to(self.device)
    
                self.params_d.optimizer.zero_grad()
                self.params_g.optimizer.zero_grad()
    
                # compute fake mu maps with generator
    
                mu_maps_fake = self.params_g.model(recons)
    
                # compute discriminator loss for fake mu maps
                inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
    
                outputs_d_fake = self.params_d.model(
                    inputs_d_fake.detach()
                )  # note the detach, so that gradients are not computed for the generator
                labels_fake = torch.full(
                    outputs_d_fake.shape, LABEL_FAKE, device=self.device
                )
                loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake)
    
    
                # compute discriminator loss for real mu maps
    
                inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
    
                outputs_d_real = self.params_d.model(inputs_d_real)
                labels_real = torch.full(
                    outputs_d_real.shape, LABEL_REAL, device=self.device
                )
                loss_d_real = self.criterion_adv(outputs_d_real, labels_real)
    
    
                # update discriminator
                loss_d = 0.5 * (loss_d_fake + loss_d_real)
                loss_d.backward()  # compute gradients
    
                self.params_d.optimizer.step()
    
    
                # update generator
    
                inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
    
                outputs_d_fake = self.params_d.model(inputs_d_fake)
                loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real)
                loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real)
                loss_g = (
                    self.weight_criterion_adv * loss_g_adv
                    + self.weight_criterion_dist * loss_g_dist
                )
    
                loss_g.backward()
    
                self.params_g.optimizer.step()
    
        def _eval_epoch(self, split_name):
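            """
            Compute the average L1 loss between generated and real μ-maps over all
            batches of the given split ("train" or "validation").
            """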
    
            torch.set_grad_enabled(False)
    
            self.params_d.model = self.params_d.model.eval()
            self.params_g.model = self.params_g.model.eval()
    
            data_loader = self.data_loaders[split_name]
    
    
            loss = 0.0
            updates = 0
            for i, (recons, mu_maps) in enumerate(data_loader):
                print(
                    f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                    end="\r",
                )
                recons = recons.to(self.device)
                mu_maps = mu_maps.to(self.device)
    
    
                outputs = self.params_g.model(recons)
    
    
                loss += torch.nn.functional.l1_loss(outputs, mu_maps).item()
                updates += 1

            # average L1 loss over all batches of this split
            return loss / updates
    
        def _store_snapshot(self, epoch):
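            """Store a snapshot using the zero-padded epoch number as prefix."""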
            prefix = f"{epoch:0{len(str(self.epochs))}d}"
            self.store_snapshot(prefix)
    
        def store_snapshot(self, prefix: str):
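            """Store the current discriminator and generator weights with the given prefix."""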
            snapshot_file_d = os.path.join(self.snapshot_dir, f"{prefix}_discriminator.pth")
            snapshot_file_g = os.path.join(self.snapshot_dir, f"{prefix}_generator.pth")
    
            self.logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}")
            torch.save(self.params_d.model.state_dict(), snapshot_file_d)
            torch.save(self.params_g.model.state_dict(), snapshot_file_g)
    
    
    
    if __name__ == "__main__":
        import argparse
        import random
        import sys
    
        import numpy as np
    
        from mu_map.dataset.patches import MuMapPatchDataset
        from mu_map.dataset.normalization import (
            MeanNormTransform,
            MaxNormTransform,
            GaussianNormTransform,
        )
    
        from mu_map.dataset.transform import PadCropTranform, SequenceTransform
    
        from mu_map.logging import add_logging_args, get_logger_by_args
        from mu_map.models.unet import UNet
    
        from mu_map.models.discriminator import Discriminator, PatchDiscriminator
    
    
        parser = argparse.ArgumentParser(
            description="Train a UNet model to predict μ-maps from reconstructed scatter images",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    
        # Model Args
        parser.add_argument(
            "--features",
            type=int,
            nargs="+",
            default=[64, 128, 256, 512],
            help="number of features in the layers of the UNet structure",
        )
    
        # Dataset Args
        parser.add_argument(
            "--dataset_dir",
            type=str,
    
            default="data/second/",
    
            help="the directory where the dataset for training is found",
        )
        parser.add_argument(
            "--input_norm",
            type=str,
            choices=["none", "mean", "max", "gaussian"],
            default="mean",
            help="type of normalization applied to the reconstructions",
        )
        parser.add_argument(
            "--patch_size",
            type=int,
            default=32,
            help="the size of patches extracted for each reconstruction",
        )
        parser.add_argument(
            "--patch_offset",
            type=int,
            default=20,
            help="offset to ignore the border of the image",
        )
        parser.add_argument(
            "--number_of_patches",
            type=int,
    
            help="number of patches extracted for each image",
        )
        parser.add_argument(
            "--no_shuffle",
            action="store_true",
            help="do not shuffle patches in the dataset",
        )
    
            "--scatter_correction",
    
            action="store_true",
            help="use the scatter corrected reconstructions in the dataset",
        )
    
    
        # Training Args
        parser.add_argument(
            "--seed",
            type=int,
            help="seed used for random number generation",
        )
        parser.add_argument(
            "--batch_size",
            type=int,
    
            help="the batch size used for training",
        )
        parser.add_argument(
            "--output_dir",
            type=str,
            default="train_data",
            help="directory in which results (snapshots and logs) of this training are saved",
        )
        parser.add_argument(
            "--epochs",
            type=int,
            default=100,
            help="the number of epochs for which the model is trained",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda:0" if torch.cuda.is_available() else "cpu",
            help="the device (cpu or gpu) with which the training is performed",
        )
    
        parser.add_argument(
    
            "--dist_loss_func",
            type=str,
            default="l1",
            help="define the loss function used as the distance loss of the generator , e.g. 0.75*l2+0.25*gdl",
    
        )
        parser.add_argument(
            "--dist_loss_weight",
            type=float,
            default=100.0,
            help="weight for the distance loss of the generator",
        )
        parser.add_argument(
            "--adv_loss_weight",
            type=float,
    
            help="weight for the Adversarial-Loss of the generator",
        )
    
        parser.add_argument(
            "--lr", type=float, default=0.001, help="the initial learning rate for training"
        )
    
        parser.add_argument(
            "--decay_lr",
            action="store_true",
            help="decay the learning rate",
        )
    
        parser.add_argument(
            "--lr_decay_factor",
            type=float,
            default=0.99,
            help="decay factor for the learning rate",
        )
        parser.add_argument(
            "--lr_decay_epoch",
            type=int,
            default=1,
            help="frequency in epochs at which the learning rate is decayed",
        )
        parser.add_argument(
            "--snapshot_dir",
            type=str,
            default="snapshots",
            help="directory under --output_dir where snapshots are stored",
        )
        parser.add_argument(
            "--snapshot_epoch",
            type=int,
            default=10,
            help="frequency in epochs at which snapshots are stored",
        )
    
        parser.add_argument(
            "--generator_weights",
            type=str,
            help="use pre-trained weights for the generator",
        )
    
    
        # Logging Args
        add_logging_args(parser, defaults={"--logfile": "train.log"})
    
        args = parser.parse_args()
    
        if not os.path.exists(args.output_dir):
            os.mkdir(args.output_dir)
    
        args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
        if not os.path.exists(args.snapshot_dir):
            os.mkdir(args.snapshot_dir)
        else:
            if len(os.listdir(args.snapshot_dir)) > 0:
                print(
                    f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
                )
                print(f"           Exit so that data is not accidentally overwritten!")
                exit(1)
    
        args.logfile = os.path.join(args.output_dir, args.logfile)
    
        device = torch.device(args.device)
        logger = get_logger_by_args(args)
        logger.info(args)
    
    
        args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
    
        logger.info(f"Seed: {args.seed}")
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
    
        transform_normalization = None
        if args.input_norm == "mean":
            transform_normalization = MeanNormTransform()
        elif args.input_norm == "max":
            transform_normalization = MaxNormTransform()
        elif args.input_norm == "gaussian":
            transform_normalization = GaussianNormTransform()
    
        # drop the normalization transform if --input_norm is "none"
        transforms = [transform_normalization, PadCropTranform(dim=3, size=32)]
        transforms = [t for t in transforms if t is not None]
        transform_normalization = SequenceTransform(transforms)
    
    
        data_loaders = {}
        for split in ["train", "validation"]:
            dataset = MuMapPatchDataset(
                args.dataset_dir,
                patches_per_image=args.number_of_patches,
                patch_size=args.patch_size,
                patch_offset=args.patch_offset,
                shuffle=not args.no_shuffle,
                split_name=split,
                transform_normalization=transform_normalization,
    
                scatter_correction=args.scatter_correction,
    
                logger=logger,
            )
            data_loader = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=1,
            )
            data_loaders[split] = data_loader
    
    
        discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
        discriminator = discriminator.to(device)
        optimizer = torch.optim.Adam(
            discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999)
        )
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
    
                optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
    
            )
            if args.decay_lr
            else None
        )
        params_d = TrainingParams(
            model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )
    
        generator = UNet(in_channels=1, features=args.features)
        generator = generator.to(device)
        if args.generator_weights:
            logger.debug(f"Load generator weights from {args.generator_weights}")
            generator.load_state_dict(
                torch.load(args.generator_weights, map_location=device)
            )
        optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
            )
            if args.decay_lr
            else None
        )
        params_g = TrainingParams(
            model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
        )
    
        dist_criterion = WeightedLoss.from_str(args.dist_loss_func)
    
        logger.debug(f"Use distance criterion: {dist_criterion}")
    
        training = cGANTraining(
            data_loaders=data_loaders,
            epochs=args.epochs,
            device=device,
            snapshot_dir=args.snapshot_dir,
            snapshot_epoch=args.snapshot_epoch,
            logger=logger,
    
            params_generator=params_g,
            params_discriminator=params_d,
            loss_func_dist=dist_criterion,
            weight_criterion_dist=args.dist_loss_weight,
        weight_criterion_adv=args.adv_loss_weight,
    )

    training.run()