# cgan.py: cGAN training (cGANTraining class and command-line entry point).
import os
from typing import Dict

import torch
from torch import Tensor

from mu_map.training.loss import GradientDifferenceLoss
from mu_map.logging import get_logger

# Establish convention for real and fake labels during training
LABEL_REAL = 1.0
LABEL_FAKE = 0.0


# class GeneratorLoss(torch.nn.Module):
    # def __init__(
        # self,
        # # l2_weight: float = 1.0,
        # # gdl_weight: float = 1.0,
        # # adv_weight: float = 20.0,
        # # logger=None,
    # ):
        # super().__init__()

        # # self.l2 = torch.nn.MSELoss(reduction="mean")
        # self.l2 = torch.nn.L1Loss(reduction="mean")
        # self.l2_weight = l2_weight

        # self.gdl = GradientDifferenceLoss()
        # self.gdl_weight = gdl_weight

        # self.adv = torch.nn.MSELoss(reduction="mean")
        # self.adv_weight = adv_weight

        # if logger:
            # logger.debug(f"GeneratorLoss: {self}")

    # def __repr__(self):
        # return f"{self.l2_weight:.3f} * MSELoss + {self.gdl_weight:.3f} * GDLLoss + {self.adv_weight:.3f} * AdversarialLoss"

    # def forward(
        # self,
        # mu_maps_real: Tensor,
        # outputs_g: Tensor,
        # targets_d: Tensor,
        # outputs_d: Tensor,
    # ):
        # loss_l2 = self.l2(outputs_g, mu_maps_real)
        # loss_gdl = self.gdl(outputs_g, mu_maps_real)
        # loss_adv = self.adv(outputs_d, targets_d)

        # return (
            # self.l2_weight * loss_l2
            # + self.gdl_weight * loss_gdl
            # + self.adv_weight * loss_adv
        # )


class cGANTraining:
    """
    Training loop for a conditional GAN that predicts mu maps from reconstructions.

    The generator receives a batch of reconstructions and outputs mu maps.
    The discriminator receives the reconstruction concatenated with a (real
    or generated) mu map along dimension 1 and is trained with a
    mean-squared-error loss against constant real/fake labels.
    The generator is trained with the adversarial term plus an L1 term
    weighted by ``L1_WEIGHT``.

    All log output goes through ``self.logger`` so that the class can be used
    when this module is imported and not only when it is run as a script.
    """

    # Weight of the L1 reconstruction term in the generator loss.
    L1_WEIGHT = 100.0

    def __init__(
        self,
        generator: torch.nn.Module,
        discriminator: torch.nn.Module,
        data_loaders: Dict[str, torch.utils.data.DataLoader],
        epochs: int,
        device: torch.device,
        lr_d: float,
        lr_decay_factor_d: float,
        lr_decay_epoch_d: int,
        lr_g: float,
        lr_decay_factor_g: float,
        lr_decay_epoch_g: int,
        l2_weight: float,
        gdl_weight: float,
        adv_weight: float,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger=None,
    ):
        """
        Set up optimizers and loss functions for the training.

        :param generator: model producing mu maps from reconstructions
        :param discriminator: model judging (reconstruction, mu map) pairs
        :param data_loaders: loaders yielding (recons, mu_maps) batches; the
            keys "train" and "validation" are read by this class
        :param epochs: number of epochs executed by `run`
        :param device: device to which every batch is moved
        :param lr_d: learning rate of the discriminator's Adam optimizer
        :param lr_g: learning rate of the generator's Adam optimizer
        :param snapshot_dir: directory in which snapshots are stored
        :param snapshot_epoch: a snapshot is stored every `snapshot_epoch` epochs
        :param logger: logger to use; `get_logger()` is used if it is None

        NOTE: `lr_decay_factor_d`, `lr_decay_epoch_d`, `lr_decay_factor_g`,
        `lr_decay_epoch_g`, `l2_weight`, `gdl_weight` and `adv_weight` are
        accepted to keep the interface stable but are currently unused: no
        learning rate scheduler is created and the generator loss uses the
        fixed `L1_WEIGHT`.
        """
        self.generator = generator
        self.discriminator = discriminator

        self.data_loaders = data_loaders
        self.epochs = epochs
        self.device = device

        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch

        self.logger = logger if logger is not None else get_logger()

        self.optimizer_d = torch.optim.Adam(
            self.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999)
        )
        self.optimizer_g = torch.optim.Adam(
            self.generator.parameters(), lr=lr_g, betas=(0.5, 0.999)
        )

        self.criterion_d = torch.nn.MSELoss(reduction="mean")
        self.criterion_l1 = torch.nn.L1Loss(reduction="mean")

    def run(self):
        """
        Train for `self.epochs` epochs.

        Each epoch trains on the "train" split, evaluates the generator on
        the "train" and "validation" splits and stores a snapshot every
        `self.snapshot_epoch` epochs.

        :return: two lists with the per-batch discriminator and generator
            losses collected over all epochs
        """
        losses_d = []
        losses_g = []
        epoch_width = len(str(self.epochs))
        for epoch in range(1, self.epochs + 1):
            self.logger.debug(
                f"Run epoch {str(epoch):>{epoch_width}}/{self.epochs} ..."
            )
            _losses_d, _losses_g = self._train_epoch()
            losses_d.extend(_losses_d)
            losses_g.extend(_losses_g)

            self._eval_epoch(epoch, "train")
            self._eval_epoch(epoch, "validation")

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(
                f"Finished epoch {str(epoch):>{epoch_width}}/{self.epochs}"
            )
        return losses_d, losses_g

    def _train_epoch(self):
        """
        Run one pass over the "train" split and update both models per batch.

        :return: two lists with one entry per batch: the sum of the
            discriminator's real and fake loss, and the generator loss
        """
        self.logger.debug("Train epoch")

        self.discriminator = self.discriminator.train()
        self.generator = self.generator.train()

        losses_d = []
        losses_g = []

        data_loader = self.data_loaders["train"]
        n_batches = len(data_loader)
        batch_width = len(str(n_batches))
        for i, (recons, mu_maps) in enumerate(data_loader):
            # progress indicator that overwrites itself on the same line
            print(f"Batch {str(i):>{batch_width}}/{n_batches}", end="\r")
            recons = recons.to(self.device)
            mu_maps = mu_maps.to(self.device)

            # gradients are enabled inside `_step`, no global switch is needed
            loss_d_real, loss_d_fake, loss_g = self._step(recons, mu_maps)

            losses_d.append(loss_d_real + loss_d_fake)
            losses_g.append(loss_g)
        return losses_d, losses_g

    def _step(self, recons, mu_maps_real):
        """
        Perform one discriminator update followed by one generator update.

        :param recons: batch of reconstructions (already on `self.device`)
        :param mu_maps_real: batch of real mu maps matching `recons`
        :return: tuple of floats (discriminator loss on real inputs,
            discriminator loss on fake inputs, generator loss); the
            discriminator values are computed before its update
        """
        with torch.set_grad_enabled(True):
            self.optimizer_d.zero_grad()
            self.optimizer_g.zero_grad()

            # compute fake mu maps with generator
            mu_maps_fake = self.generator(recons)
            inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)

            # compute discriminator loss for fake mu maps
            # note the detach, so that gradients are not computed for the generator
            outputs_d_fake = self.discriminator(inputs_d_fake.detach())
            labels_fake = torch.full(
                outputs_d_fake.shape, LABEL_FAKE, device=self.device
            )
            loss_d_fake = self.criterion_d(outputs_d_fake, labels_fake)

            # compute discriminator loss for real mu maps
            inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
            outputs_d_real = self.discriminator(inputs_d_real)
            labels_real = torch.full(
                outputs_d_real.shape, LABEL_REAL, device=self.device
            )
            loss_d_real = self.criterion_d(outputs_d_real, labels_real)

            # update discriminator
            loss_d = 0.5 * (loss_d_fake + loss_d_real)
            loss_d.backward()
            self.optimizer_d.step()

            # update generator: the non-detached input is passed through the
            # updated discriminator so that gradients reach the generator;
            # gradients accumulated in the discriminator here are cleared by
            # `zero_grad` at the start of the next step
            outputs_d_fake = self.discriminator(inputs_d_fake)
            loss_g_adv = self.criterion_d(outputs_d_fake, labels_real)
            loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real)
            loss_g = loss_g_adv + self.L1_WEIGHT * loss_g_l1
            loss_g.backward()
            self.optimizer_g.step()

        return loss_d_real.item(), loss_d_fake.item(), loss_g.item()

    def _eval_epoch(self, epoch, split_name):
        """
        Log the mean L1 loss of the generator over all batches of a split.

        Both models are switched to evaluation mode. Gradients are disabled
        only for the duration of the evaluation so that the global autograd
        state of the caller is left untouched.

        :param epoch: current epoch, only used for the log message
        :param split_name: key of the data loader to evaluate on
        """
        self.logger.debug(f"Evaluate epoch on split {split_name}")

        self.discriminator = self.discriminator.eval()
        self.generator = self.generator.eval()

        data_loader = self.data_loaders[split_name]
        n_batches = len(data_loader)
        batch_width = len(str(n_batches))

        loss = 0.0
        updates = 0
        with torch.no_grad():
            for i, (recons, mu_maps) in enumerate(data_loader):
                print(f"Batch {str(i):>{batch_width}}/{n_batches}", end="\r")
                recons = recons.to(self.device)
                mu_maps = mu_maps.to(self.device)

                outputs = self.generator(recons)

                loss += torch.nn.functional.l1_loss(outputs, mu_maps).item()
                updates += 1

        if updates == 0:
            # an empty split would otherwise cause a division by zero
            self.logger.warning(
                f"Split {split_name} contains no batches - skip evaluation"
            )
            return

        loss = loss / updates
        self.logger.info(
            f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss {split_name}: {loss:.6f}"
        )

    def store_snapshot(self, epoch):
        """
        Store the state dicts of discriminator and generator.

        Files are written to `self.snapshot_dir` and named after the
        zero-padded epoch, e.g. `010_generator.pth` for epoch 10 of 100.

        :param epoch: epoch used as prefix of the snapshot file names
        """
        prefix = f"{epoch:0{len(str(self.epochs))}d}"
        snapshot_file_d = os.path.join(self.snapshot_dir, f"{prefix}_discriminator.pth")
        snapshot_file_g = os.path.join(self.snapshot_dir, f"{prefix}_generator.pth")

        self.logger.debug(
            f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}"
        )
        torch.save(self.discriminator.state_dict(), snapshot_file_d)
        torch.save(self.generator.state_dict(), snapshot_file_g)


if __name__ == "__main__":
    # Command-line entry point: parse arguments, prepare the output
    # directories, build models and data loaders, run the cGAN training and
    # plot the collected losses.
    import argparse
    import random
    import sys

    import numpy as np

    from mu_map.dataset.patches import MuMapPatchDataset
    from mu_map.dataset.normalization import (
        MeanNormTransform,
        MaxNormTransform,
        GaussianNormTransform,
    )
    from mu_map.dataset.transform import ScaleTransform
    from mu_map.logging import add_logging_args, get_logger_by_args
    from mu_map.models.unet import UNet
    from mu_map.models.discriminator import Discriminator, PatchDiscriminator

    parser = argparse.ArgumentParser(
        description="Train a UNet model to predict μ-maps from reconstructed scatter images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model Args
    parser.add_argument(
        "--features",
        type=int,
        nargs="+",
        default=[64, 128, 256, 512],
        help="number of features in the layers of the UNet structure",
    )

    # Dataset Args
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default="data/initial/",
        help="the directory where the dataset for training is found",
    )
    parser.add_argument(
        "--input_norm",
        type=str,
        choices=["none", "mean", "max", "gaussian"],
        default="mean",
        help="type of normalization applied to the reconstructions",
    )
    parser.add_argument(
        "--patch_size",
        type=int,
        default=32,
        help="the size of patches extracted for each reconstruction",
    )
    parser.add_argument(
        "--patch_offset",
        type=int,
        default=20,
        help="offset to ignore the border of the image",
    )
    parser.add_argument(
        "--number_of_patches",
        type=int,
        default=100,
        help="number of patches extracted for each image",
    )
    parser.add_argument(
        "--no_shuffle",
        action="store_true",
        help="do not shuffle patches in the dataset",
    )

    # Training Args
    parser.add_argument(
        "--seed",
        type=int,
        help="seed used for random number generation",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=64,
        help="the batch size used for training",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="train_data",
        help="directory in which results (snapshots and logs) of this training are saved",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="the number of epochs for which the model is trained",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="the device (cpu or gpu) with which the training is performed",
    )
    parser.add_argument(
        "--mse_loss_weight",
        type=float,
        default=1.0,
        help="weight for the L2-Loss of the generator",
    )
    parser.add_argument(
        "--gdl_loss_weight",
        type=float,
        default=1.0,
        help="weight for the Gradient-Difference-Loss of the generator",
    )
    parser.add_argument(
        "--adv_loss_weight",
        type=float,
        default=20.0,
        help="weight for the Adversarial-Loss of the generator",
    )
    # NOTE: --lr, --lr_decay_factor and --lr_decay_epoch are parsed but not
    # forwarded to the training below, which uses fixed values instead.
    parser.add_argument(
        "--lr", type=float, default=0.001, help="the initial learning rate for training"
    )
    parser.add_argument(
        "--lr_decay_factor",
        type=float,
        default=0.99,
        help="decay factor for the learning rate",
    )
    parser.add_argument(
        "--lr_decay_epoch",
        type=int,
        default=1,
        help="frequency in epochs at which the learning rate is decayed",
    )
    parser.add_argument(
        "--snapshot_dir",
        type=str,
        default="snapshots",
        help="directory under --output_dir where snapshots are stored",
    )
    parser.add_argument(
        "--snapshot_epoch",
        type=int,
        default=10,
        help="frequency in epochs at which snapshots are stored",
    )
    parser.add_argument(
        "--generator_weights",
        type=str,
        help="use pre-trained weights for the generator",
    )

    # Logging Args
    add_logging_args(parser, defaults={"--logfile": "train.log"})

    args = parser.parse_args()

    # makedirs also creates missing parent directories and does not fail if
    # the directory already exists
    os.makedirs(args.output_dir, exist_ok=True)

    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
    os.makedirs(args.snapshot_dir, exist_ok=True)
    if os.listdir(args.snapshot_dir):
        # refuse to run so that snapshots of a previous training are kept
        print(
            f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
        )
        print("           Exit so that data is not accidentally overwritten!")
        sys.exit(1)

    args.logfile = os.path.join(args.output_dir, args.logfile)

    device = torch.device(args.device)
    logger = get_logger_by_args(args)
    logger.info(args)

    # draw a random seed if none is given and log it for reproducibility
    args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
    logger.info(f"Seed: {args.seed}")
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
    discriminator = PatchDiscriminator(in_channels=2, input_size=args.patch_size)
    discriminator = discriminator.to(device)

    generator = UNet(in_channels=1, features=args.features)
    generator = generator.to(device)
    if args.generator_weights:
        logger.debug(f"Load generator weights from {args.generator_weights}")
        generator.load_state_dict(torch.load(args.generator_weights, map_location=device))

    # map the --input_norm choice to its transform, "none" disables normalization
    normalization_types = {
        "mean": MeanNormTransform,
        "max": MaxNormTransform,
        "gaussian": GaussianNormTransform,
    }
    normalization_type = normalization_types.get(args.input_norm)
    transform_normalization = (
        normalization_type() if normalization_type is not None else None
    )

    data_loaders = {}
    for split in ["train", "validation"]:
        dataset = MuMapPatchDataset(
            args.dataset_dir,
            patches_per_image=args.number_of_patches,
            patch_size=args.patch_size,
            patch_offset=args.patch_offset,
            shuffle=not args.no_shuffle,
            split_name=split,
            transform_normalization=transform_normalization,
            logger=logger,
        )
        data_loaders[split] = torch.utils.data.DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=1,
        )

    training = cGANTraining(
        discriminator=discriminator,
        generator=generator,
        data_loaders=data_loaders,
        epochs=args.epochs,
        device=device,
        lr_d=0.0002,
        lr_decay_factor_d=0.99,
        lr_decay_epoch_d=1,
        lr_g=0.0002,
        lr_decay_factor_g=0.99,
        lr_decay_epoch_g=1,
        l2_weight=args.mse_loss_weight,
        gdl_weight=args.gdl_loss_weight,
        adv_weight=args.adv_loss_weight,
        snapshot_dir=args.snapshot_dir,
        snapshot_epoch=args.snapshot_epoch,
        logger=logger,
    )
    losses_d, losses_g = training.run()

    import matplotlib.pyplot as plt

    # plot the per-batch losses; the figure is written to the current
    # working directory, not to --output_dir
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].plot(losses_d)
    axs[0].set_title("Discriminator Loss")
    axs[0].set_xlabel("Iteration")
    axs[0].set_ylabel("Loss")
    axs[1].plot(losses_g, label="Generator")
    axs[1].set_title("Generator Loss")
    axs[1].set_xlabel("Iteration")
    axs[1].set_ylabel("Loss")
    plt.savefig("losses.png")