From 92f48bfbedf70cbd75aebe5363e8af68d585d811 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Wed, 5 Oct 2022 15:55:49 +0200
Subject: [PATCH] implement conditional GAN training

---
 mu_map/training/cgan.py | 457 ++++++++++++++++++++++++++++++++++++++++
 mu_map/training/loss.py |  27 +++
 2 files changed, 484 insertions(+)
 create mode 100644 mu_map/training/cgan.py
 create mode 100644 mu_map/training/loss.py

diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
new file mode 100644
index 0000000..aa93a31
--- /dev/null
+++ b/mu_map/training/cgan.py
@@ -0,0 +1,457 @@
+import os
+from typing import Dict
+
+import torch
+from torch import Tensor
+
+from mu_map.training.loss import GradientDifferenceLoss
+from mu_map.logging import get_logger
+
+# Establish convention for real and fake labels during training
+LABEL_REAL = 1.0
+LABEL_FAKE = 0.0
+
+
+class GeneratorLoss(torch.nn.Module):
+    def __init__(
+        self, l2_weight: float = 1.0, gdl_weight: float = 1.0, adv_weight: float = 20.0
+    ):
+        super().__init__()
+
+        self.l2 = torch.nn.MSELoss(reduction="mean")
+        self.l2_weight = l2_weight
+
+        self.gdl = GradientDifferenceLoss()
+        self.gdl_weight = gdl_weight
+
+        self.adv = torch.nn.MSELoss(reduction="mean")
+        self.adv_weight = adv_weight
+
+    def forward(
+        self,
+        mu_maps_real: Tensor,
+        outputs_g: Tensor,
+        targets_d: Tensor,
+        outputs_d: Tensor,
+    ):
+        loss_l2 = self.l2(outputs_g, mu_maps_real)
+        loss_gdl = self.gdl(outputs_g, mu_maps_real)
+        loss_adv = self.adv(outputs_d, targets_d)
+
+        return (
+            self.l2_weight * loss_l2
+            + self.gdl_weight * loss_gdl
+            + self.adv_weight * loss_adv
+        )
+
+
+class cGANTraining:
+    def __init__(
+        self,
+        generator: torch.nn.Module,
+        discriminator: torch.nn.Module,
+        data_loaders: Dict[str, torch.utils.data.DataLoader],
+        epochs: int,
+        device: torch.device,
+        lr_d: float,
+        lr_decay_factor_d: float,
+        lr_decay_epoch_d: int,
+        lr_g: float,
+        lr_decay_factor_g: float,
+        lr_decay_epoch_g: int,
+        l2_weight: float,
+        gdl_weight: float,
+        adv_weight: float,
+        snapshot_dir: str,
+        snapshot_epoch: int,
+        logger=None,
+    ):
+        self.generator = generator
+        self.discriminator = discriminator
+
+        self.data_loaders = data_loaders
+        self.epochs = epochs
+        self.device = device
+
+        self.snapshot_dir = snapshot_dir
+        self.snapshot_epoch = snapshot_epoch
+
+        self.logger = logger if logger is not None else get_logger()
+
+        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)
+        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g)
+
+        self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR(
+            self.optimizer_d,
+            step_size=lr_decay_epoch_d,
+            gamma=lr_decay_factor_d,
+        )
+        self.lr_scheduler_g = torch.optim.lr_scheduler.StepLR(
+            self.optimizer_g,
+            step_size=lr_decay_epoch_g,
+            gamma=lr_decay_factor_g,
+        )
+
+        self.criterion_d = torch.nn.MSELoss(reduction="mean")
+        self.criterion_g = GeneratorLoss(
+            l2_weight=l2_weight, gdl_weight=gdl_weight, adv_weight=adv_weight
+        )
+
+    def run(self):
+        losses_d = []
+        losses_g = []
+        for epoch in range(1, self.epochs + 1):
+            logger.debug(
+                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
+            )
+            _losses_d, _losses_g = self._train_epoch()
+            losses_d.extend(_losses_d)
+            losses_g.extend(_losses_g)
+
+            self._eval_epoch(epoch, "train")
+            self._eval_epoch(epoch, "validation")
+
+            self.lr_scheduler_d.step()
+            self.lr_scheduler_g.step()
+
+            if epoch % self.snapshot_epoch == 0:
+                self.store_snapshot(epoch)
+
+            logger.debug(
+                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
+            )
+        return losses_d, losses_g
+
+    def _train_epoch(self):
+        logger.debug(f"Train epoch")
+        torch.set_grad_enabled(True)
+
+        self.discriminator = self.discriminator.train()
+        self.generator = self.generator.train()
+
+        losses_d = []
+        losses_g = []
+
+        data_loader = self.data_loaders["train"]
+        for i, (recons, mu_maps) in enumerate(data_loader):
+            print(
+                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
+                end="\r",
+            )
+            recons = recons.to(self.device)
+            mu_maps = mu_maps.to(self.device)
+
+            loss_d_real, loss_d_fake, loss_g = self._step(recons, mu_maps)
+
+            losses_d.append(loss_d_real + loss_d_fake)
+            losses_g.append(loss_g)
+        return losses_d, losses_g
+
+    def _step(self, recons, mu_maps_real):
+        batch_size = recons.shape[0]
+
+        self.optimizer_d.zero_grad()
+        self.optimizer_g.zero_grad()
+
+        labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device)
+        labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device)
+
+        with torch.set_grad_enabled(True):
+            # compute fake mu maps with generator
+            mu_maps_fake = self.generator(recons)
+
+            # update discriminator based on real mu maps
+            outputs_d = self.discriminator(mu_maps_real)
+            loss_d_real = self.criterion_d(outputs_d, labels_real)
+            loss_d_real.backward()  # compute gradients
+            # update discriminator based on fake mu maps
+            outputs_d = self.discriminator(
+                mu_maps_fake.detach()
+            )  # note the detach, so that gradients are not computed for the generator
+            loss_d_fake = self.criterion_d(outputs_d, labels_fake)
+            loss_d_fake.backward()  # compute gradients
+            self.optimizer_d.step()  # update discriminator based on gradients
+
+            # update generator
+            outputs_d = self.discriminator(mu_maps_fake)
+            loss_g = self.criterion_g(mu_maps_real, mu_maps_fake, labels_real, outputs_d)
+            loss_g.backward()
+            self.optimizer_g.step()
+
+        return loss_d_real.item(), loss_d_fake.item(), loss_g.item()
+
+    def _eval_epoch(self, epoch, split_name):
+        logger.debug(f"Evaluate epoch on split {split_name}")
+        torch.set_grad_enabled(False)
+
+        self.discriminator = self.discriminator.eval()
+        self.generator = self.generator.eval()
+
+        loss = 0.0
+        updates = 0
+        
+        data_loader = self.data_loaders[split_name]
+        for i, (recons, mu_maps) in enumerate(data_loader):
+            print(
+                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
+                end="\r",
+            )
+            recons = recons.to(self.device)
+            mu_maps = mu_maps.to(self.device)
+
+            outputs = self.generator(recons)
+
+            loss += torch.nn.functional.l1_loss(outputs, mu_maps)
+            updates += 1
+        loss = loss / updates
+        logger.info(
+            f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss {split_name}: {loss:.6f}"
+        )
+
+    def store_snapshot(self, epoch):
+        snapshot_file_d = f"{epoch:0{len(str(self.epochs))}d}_discriminator.pth"
+        snapshot_file_d = os.path.join(self.snapshot_dir, snapshot_file_d)
+
+        snapshot_file_g = f"{epoch:0{len(str(self.epochs))}d}_generator.pth"
+        snapshot_file_g = os.path.join(self.snapshot_dir, snapshot_file_g)
+        logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}")
+        torch.save(self.discriminator.state_dict(), snapshot_file_d)
+        torch.save(self.generator.state_dict(), snapshot_file_g)
+
+
+if __name__ == "__main__":
+    import argparse
+    import random
+    import sys
+
+    import numpy as np
+
+    from mu_map.dataset.patches import MuMapPatchDataset
+    from mu_map.dataset.normalization import (
+        MeanNormTransform,
+        MaxNormTransform,
+        GaussianNormTransform,
+    )
+    from mu_map.dataset.transform import ScaleTransform
+    from mu_map.logging import add_logging_args, get_logger_by_args
+    from mu_map.models.unet import UNet
+    from mu_map.models.discriminator import Discriminator
+
+    parser = argparse.ArgumentParser(
+        description="Train a UNet model to predict μ-maps from reconstructed scatter images",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Model Args
+    parser.add_argument(
+        "--features",
+        type=int,
+        nargs="+",
+        default=[64, 128, 256, 512],
+        help="number of features in the layers of the UNet structure",
+    )
+
+    # Dataset Args
+    parser.add_argument(
+        "--dataset_dir",
+        type=str,
+        default="data/initial/",
+        help="the directory where the dataset for training is found",
+    )
+    parser.add_argument(
+        "--output_scale",
+        type=float,
+        default=1.0,
+        help="scale the attenuation map by this coefficient",
+    )
+    parser.add_argument(
+        "--input_norm",
+        type=str,
+        choices=["none", "mean", "max", "gaussian"],
+        default="mean",
+        help="type of normalization applied to the reconstructions",
+    )
+    parser.add_argument(
+        "--patch_size",
+        type=int,
+        default=32,
+        help="the size of patches extracted for each reconstruction",
+    )
+    parser.add_argument(
+        "--patch_offset",
+        type=int,
+        default=20,
+        help="offset to ignore the border of the image",
+    )
+    parser.add_argument(
+        "--number_of_patches",
+        type=int,
+        default=1,
+        help="number of patches extracted for each image",
+    )
+    parser.add_argument(
+        "--no_shuffle",
+        action="store_true",
+        help="do not shuffle patches in the dataset",
+    )
+
+    # Training Args
+    parser.add_argument(
+        "--seed",
+        type=int,
+        help="seed used for random number generation",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8,
+        help="the batch size used for training",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="train_data",
+        help="directory in which results (snapshots and logs) of this training are saved",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=100,
+        help="the number of epochs for which the model is trained",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0" if torch.cuda.is_available() else "cpu",
+        help="the device (cpu or gpu) with which the training is performed",
+    )
+    parser.add_argument(
+        "--lr", type=float, default=0.001, help="the initial learning rate for training"
+    )
+    parser.add_argument(
+        "--lr_decay_factor",
+        type=float,
+        default=0.99,
+        help="decay factor for the learning rate",
+    )
+    parser.add_argument(
+        "--lr_decay_epoch",
+        type=int,
+        default=1,
+        help="frequency in epochs at which the learning rate is decayed",
+    )
+    parser.add_argument(
+        "--snapshot_dir",
+        type=str,
+        default="snapshots",
+        help="directory under --output_dir where snapshots are stored",
+    )
+    parser.add_argument(
+        "--snapshot_epoch",
+        type=int,
+        default=10,
+        help="frequency in epochs at which snapshots are stored",
+    )
+
+    # Logging Args
+    add_logging_args(parser, defaults={"--logfile": "train.log"})
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.output_dir):
+        os.mkdir(args.output_dir)
+
+    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
+    if not os.path.exists(args.snapshot_dir):
+        os.mkdir(args.snapshot_dir)
+    else:
+        if len(os.listdir(args.snapshot_dir)) > 0:
+            print(
+                f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
+            )
+            print(f"           Exit so that data is not accidentally overwritten!")
+            exit(1)
+
+    args.logfile = os.path.join(args.output_dir, args.logfile)
+
+    device = torch.device(args.device)
+    logger = get_logger_by_args(args)
+    logger.info(args)
+
+    args.seed = args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1)
+    logger.info(f"Seed: {args.seed}")
+    random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    np.random.seed(args.seed)
+
+    discriminator = Discriminator(in_channels=1, input_size=args.patch_size)
+    discriminator = discriminator.to(device)
+
+    generator = UNet(in_channels=1, features=args.features)
+    generator = generator.to(device)
+
+    transform_normalization = None
+    if args.input_norm == "mean":
+        transform_normalization = MeanNormTransform()
+    elif args.input_norm == "max":
+        transform_normalization = MaxNormTransform()
+    elif args.input_norm == "gaussian":
+        transform_normalization = GaussianNormTransform()
+
+    transform_augmentation = ScaleTransform(scale_outputs=args.output_scale)
+
+    data_loaders = {}
+    for split in ["train", "validation"]:
+        dataset = MuMapPatchDataset(
+            args.dataset_dir,
+            patches_per_image=args.number_of_patches,
+            patch_size=args.patch_size,
+            patch_offset=args.patch_offset,
+            shuffle=not args.no_shuffle,
+            split_name=split,
+            transform_normalization=transform_normalization,
+            transform_augmentation=transform_augmentation,
+            logger=logger,
+        )
+        data_loader = torch.utils.data.DataLoader(
+            dataset=dataset,
+            batch_size=args.batch_size,
+            shuffle=True,
+            pin_memory=True,
+            num_workers=1,
+        )
+        data_loaders[split] = data_loader
+
+    training = cGANTraining(
+        discriminator=discriminator,
+        generator=generator,
+        data_loaders=data_loaders,
+        epochs=args.epochs,
+        device=device,
+        lr_d=0.0005,
+        lr_decay_factor_d=0.99,
+        lr_decay_epoch_d=1,
+        lr_g=0.001,
+        lr_decay_factor_g=0.99,
+        lr_decay_epoch_g=1,
+        l2_weight=1.0,
+        gdl_weight=1.0,
+        adv_weight=20.0,
+        snapshot_dir=args.snapshot_dir,
+        snapshot_epoch=args.snapshot_epoch,
+        logger=logger,
+    )
+    losses_d, losses_g = training.run()
+
+    import matplotlib.pyplot as plt
+
+    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
+    axs[0].plot(losses_d)
+    axs[0].set_title("Discriminator Loss")
+    axs[0].set_xlabel("Iteration")
+    axs[0].set_ylabel("Loss")
+    axs[1].plot(losses_g, label="Generator")
+    axs[1].set_title("Generator Loss")
+    axs[1].set_xlabel("Iteration")
+    axs[1].set_ylabel("Loss")
+    plt.savefig("losses.png")
diff --git a/mu_map/training/loss.py b/mu_map/training/loss.py
new file mode 100644
index 0000000..b62c07e
--- /dev/null
+++ b/mu_map/training/loss.py
@@ -0,0 +1,27 @@
+import torch
+import torch.nn as nn
+
+
+class GradientDifferenceLoss(nn.Module):
+    """
+    Gradient Difference Loss (GDL) inspired by https://github.com/mmany/pytorch-GDL/blob/main/custom_loss_functions.py.
+    It is modified to deal with 5D tensors (batch_size, channels, z, y, x).
+    """
+
+    def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
+        gradient_diff_z = (inputs.diff(dim=2) - targets.diff(axis=2)).pow(2).sum()
+        gradient_diff_y = (inputs.diff(dim=3) - targets.diff(axis=3)).pow(2).sum()
+        gradient_diff_x = (inputs.diff(dim=4) - targets.diff(axis=4)).pow(2).sum()
+
+        gradient_diff = gradient_diff_x + gradient_diff_y + gradient_diff_z
+        return gradient_diff / inputs.numel()
+
+
+if __name__ == "__main__":
+    torch.manual_seed(10)
+    inputs = torch.rand((4, 1, 32, 64, 64))
+    targets = torch.rand((4, 1, 32, 64, 64))
+
+    criterion = GradientDifferenceLoss()
+    loss = criterion(inputs, targets)
+    print(f"Loss: {loss.item():.6f}")
-- 
GitLab