From 92f48bfbedf70cbd75aebe5363e8af68d585d811 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Wed, 5 Oct 2022 15:55:49 +0200 Subject: [PATCH] implement conditional GAN training --- mu_map/training/cgan.py | 457 ++++++++++++++++++++++++++++++++++++++++ mu_map/training/loss.py | 27 +++ 2 files changed, 484 insertions(+) create mode 100644 mu_map/training/cgan.py create mode 100644 mu_map/training/loss.py diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py new file mode 100644 index 0000000..aa93a31 --- /dev/null +++ b/mu_map/training/cgan.py @@ -0,0 +1,457 @@ +import os +from typing import Dict + +import torch +from torch import Tensor + +from mu_map.training.loss import GradientDifferenceLoss +from mu_map.logging import get_logger + +# Establish convention for real and fake labels during training +LABEL_REAL = 1.0 +LABEL_FAKE = 0.0 + + +class GeneratorLoss(torch.nn.Module): + def __init__( + self, l2_weight: float = 1.0, gdl_weight: float = 1.0, adv_weight: float = 20.0 + ): + super().__init__() + + self.l2 = torch.nn.MSELoss(reduction="mean") + self.l2_weight = l2_weight + + self.gdl = GradientDifferenceLoss() + self.gdl_weight = gdl_weight + + self.adv = torch.nn.MSELoss(reduction="mean") + self.adv_weight = adv_weight + + def forward( + self, + mu_maps_real: Tensor, + outputs_g: Tensor, + targets_d: Tensor, + outputs_d: Tensor, + ): + loss_l2 = self.l2(outputs_g, mu_maps_real) + loss_gdl = self.gdl(outputs_g, mu_maps_real) + loss_adv = self.adv(outputs_d, targets_d) + + return ( + self.l2_weight * loss_l2 + + self.gdl_weight * loss_gdl + + self.adv_weight * loss_adv + ) + + +class cGANTraining: + def __init__( + self, + generator: torch.nn.Module, + discriminator: torch.nn.Module, + data_loaders: Dict[str, torch.utils.data.DataLoader], + epochs: int, + device: torch.device, + lr_d: float, + lr_decay_factor_d: float, + lr_decay_epoch_d: int, + lr_g: float, + lr_decay_factor_g: float, + lr_decay_epoch_g: int, + l2_weight: float, + gdl_weight: float, + adv_weight: float, + snapshot_dir: str, + snapshot_epoch: int, + logger=None, + ): + self.generator = generator + self.discriminator = discriminator + + self.data_loaders = data_loaders + self.epochs = epochs + self.device = device + + self.snapshot_dir = snapshot_dir + self.snapshot_epoch = snapshot_epoch + + self.logger = logger if logger is not None else get_logger() + + self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d) + self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g) + + self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR( + self.optimizer_d, + step_size=lr_decay_epoch_d, + gamma=lr_decay_factor_d, + ) + self.lr_scheduler_g = torch.optim.lr_scheduler.StepLR( + self.optimizer_g, + step_size=lr_decay_epoch_g, + gamma=lr_decay_factor_g, + ) + + self.criterion_d = torch.nn.MSELoss(reduction="mean") + self.criterion_g = GeneratorLoss( + l2_weight=l2_weight, gdl_weight=gdl_weight, adv_weight=adv_weight + ) + + def run(self): + losses_d = [] + losses_g = [] + for epoch in range(1, self.epochs + 1): + logger.debug( + f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..." + ) + _losses_d, _losses_g = self._train_epoch() + losses_d.extend(_losses_d) + losses_g.extend(_losses_g) + + self._eval_epoch(epoch, "train") + self._eval_epoch(epoch, "validation") + + self.lr_scheduler_d.step() + self.lr_scheduler_g.step() + + if epoch % self.snapshot_epoch == 0: + self.store_snapshot(epoch) + + logger.debug( + f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}" + ) + return losses_d, losses_g + + def _train_epoch(self): + logger.debug(f"Train epoch") + torch.set_grad_enabled(True) + + self.discriminator = self.discriminator.train() + self.generator = self.generator.train() + + losses_d = [] + losses_g = [] + + data_loader = self.data_loaders["train"] + for i, (recons, mu_maps) in enumerate(data_loader): + print( + f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", + end="\r", + ) + recons = recons.to(self.device) + mu_maps = mu_maps.to(self.device) + + loss_d_real, loss_d_fake, loss_g = self._step(recons, mu_maps) + + losses_d.append(loss_d_real + loss_d_fake) + losses_g.append(loss_g) + return losses_d, losses_g + + def _step(self, recons, mu_maps_real): + batch_size = recons.shape[0] + + self.optimizer_d.zero_grad() + self.optimizer_g.zero_grad() + + labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device) + labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device) + + with torch.set_grad_enabled(True): + # compute fake mu maps with generator + mu_maps_fake = self.generator(recons) + + # update discriminator based on real mu maps + outputs_d = self.discriminator(mu_maps_real) + loss_d_real = self.criterion_d(outputs_d, labels_real) + loss_d_real.backward() # compute gradients + # update discriminator based on fake mu maps + outputs_d = self.discriminator( + mu_maps_fake.detach() + ) # note the detach, so that gradients are not computed for the generator + loss_d_fake = self.criterion_d(outputs_d, labels_fake) + loss_d_fake.backward() # compute gradients + self.optimizer_d.step() # update discriminator based on gradients + + # update generator + outputs_d = self.discriminator(mu_maps_fake) + loss_g = self.criterion_g(mu_maps_real, mu_maps_fake, labels_real, outputs_d) + loss_g.backward() + self.optimizer_g.step() + + return loss_d_real.item(), loss_d_fake.item(), loss_g.item() + + def _eval_epoch(self, epoch, split_name): + logger.debug(f"Evaluate epoch on split {split_name}") + torch.set_grad_enabled(False) + + self.discriminator = self.discriminator.eval() + self.generator = self.generator.eval() + + loss = 0.0 + updates = 0 + + data_loader = self.data_loaders[split_name] + for i, (recons, mu_maps) in enumerate(data_loader): + print( + f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", + end="\r", + ) + recons = recons.to(self.device) + mu_maps = mu_maps.to(self.device) + + outputs = self.generator(recons) + + loss += torch.nn.functional.l1_loss(outputs, mu_maps) + updates += 1 + loss = loss / updates + logger.info( + f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss {split_name}: {loss:.6f}" + ) + + def store_snapshot(self, epoch): + snapshot_file_d = f"{epoch:0{len(str(self.epochs))}d}_discriminator.pth" + snapshot_file_d = os.path.join(self.snapshot_dir, snapshot_file_d) + + snapshot_file_g = f"{epoch:0{len(str(self.epochs))}d}_generator.pth" + snapshot_file_g = os.path.join(self.snapshot_dir, snapshot_file_g) + logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}") + torch.save(self.discriminator.state_dict(), snapshot_file_d) + torch.save(self.generator.state_dict(), snapshot_file_g) + + +if __name__ == "__main__": + import argparse + import random + import sys + + import numpy as np + + from mu_map.dataset.patches import MuMapPatchDataset + from mu_map.dataset.normalization import ( + MeanNormTransform, + MaxNormTransform, + GaussianNormTransform, + ) + from mu_map.dataset.transform import ScaleTransform + from mu_map.logging import add_logging_args, get_logger_by_args + from mu_map.models.unet import UNet + from mu_map.models.discriminator import Discriminator + + parser = argparse.ArgumentParser( + description="Train a UNet model to predict μ-maps from reconstructed scatter images", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Model Args + parser.add_argument( + "--features", + type=int, + nargs="+", + default=[64, 128, 256, 512], + help="number of features in the layers of the UNet structure", + ) + + # Dataset Args + parser.add_argument( + "--dataset_dir", + type=str, + default="data/initial/", + help="the directory where the dataset for training is found", + ) + parser.add_argument( + "--output_scale", + type=float, + default=1.0, + help="scale the attenuation map by this coefficient", + ) + parser.add_argument( + "--input_norm", + type=str, + choices=["none", "mean", "max", "gaussian"], + default="mean", + help="type of normalization applied to the reconstructions", + ) + parser.add_argument( + "--patch_size", + type=int, + default=32, + help="the size of patches extracted for each reconstruction", + ) + parser.add_argument( + "--patch_offset", + type=int, + default=20, + help="offset to ignore the border of the image", + ) + parser.add_argument( + "--number_of_patches", + type=int, + default=1, + help="number of patches extracted for each image", + ) + parser.add_argument( + "--no_shuffle", + action="store_true", + help="do not shuffle patches in the dataset", + ) + + # Training Args + parser.add_argument( + "--seed", + type=int, + help="seed used for random number generation", + ) + parser.add_argument( + "--batch_size", + type=int, + default=8, + help="the batch size used for training", + ) + parser.add_argument( + "--output_dir", + type=str, + default="train_data", + help="directory in which results (snapshots and logs) of this training are saved", + ) + parser.add_argument( + "--epochs", + type=int, + default=100, + help="the number of epochs for which the model is trained", + ) + parser.add_argument( + "--device", + type=str, + default="cuda:0" if torch.cuda.is_available() else "cpu", + help="the device (cpu or gpu) with which the training is performed", + ) + parser.add_argument( + "--lr", type=float, default=0.001, help="the initial learning rate for training" + ) + parser.add_argument( + "--lr_decay_factor", + type=float, + default=0.99, + help="decay factor for the learning rate", + ) + parser.add_argument( + "--lr_decay_epoch", + type=int, + default=1, + help="frequency in epochs at which the learning rate is decayed", + ) + parser.add_argument( + "--snapshot_dir", + type=str, + default="snapshots", + help="directory under --output_dir where snapshots are stored", + ) + parser.add_argument( + "--snapshot_epoch", + type=int, + default=10, + help="frequency in epochs at which snapshots are stored", + ) + + # Logging Args + add_logging_args(parser, defaults={"--logfile": "train.log"}) + + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.mkdir(args.output_dir) + + args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir) + if not os.path.exists(args.snapshot_dir): + os.mkdir(args.snapshot_dir) + else: + if len(os.listdir(args.snapshot_dir)) > 0: + print( + f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!" + ) + print(f" Exit so that data is not accidentally overwritten!") + exit(1) + + args.logfile = os.path.join(args.output_dir, args.logfile) + + device = torch.device(args.device) + logger = get_logger_by_args(args) + logger.info(args) + + args.seed = args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1) + logger.info(f"Seed: {args.seed}") + random.seed(args.seed) + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + discriminator = Discriminator(in_channels=1, input_size=args.patch_size) + discriminator = discriminator.to(device) + + generator = UNet(in_channels=1, features=args.features) + generator = generator.to(device) + + transform_normalization = None + if args.input_norm == "mean": + transform_normalization = MeanNormTransform() + elif args.input_norm == "max": + transform_normalization = MaxNormTransform() + elif args.input_norm == "gaussian": + transform_normalization = GaussianNormTransform() + + transform_augmentation = ScaleTransform(scale_outputs=args.output_scale) + + data_loaders = {} + for split in ["train", "validation"]: + dataset = MuMapPatchDataset( + args.dataset_dir, + patches_per_image=args.number_of_patches, + patch_size=args.patch_size, + patch_offset=args.patch_offset, + shuffle=not args.no_shuffle, + split_name=split, + transform_normalization=transform_normalization, + transform_augmentation=transform_augmentation, + logger=logger, + ) + data_loader = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=args.batch_size, + shuffle=True, + pin_memory=True, + num_workers=1, + ) + data_loaders[split] = data_loader + + training = cGANTraining( + discriminator=discriminator, + generator=generator, + data_loaders=data_loaders, + epochs=args.epochs, + device=device, + lr_d=0.0005, + lr_decay_factor_d=0.99, + lr_decay_epoch_d=1, + lr_g=0.001, + lr_decay_factor_g=0.99, + lr_decay_epoch_g=1, + l2_weight=1.0, + gdl_weight=1.0, + adv_weight=20.0, + snapshot_dir=args.snapshot_dir, + snapshot_epoch=args.snapshot_epoch, + logger=logger, + ) + losses_d, losses_g = training.run() + + import matplotlib.pyplot as plt + + fig, axs = plt.subplots(1, 2, figsize=(10, 5)) + axs[0].plot(losses_d) + axs[0].set_title("Discriminator Loss") + axs[0].set_xlabel("Iteration") + axs[0].set_ylabel("Loss") + axs[1].plot(losses_g, label="Generator") + axs[1].set_title("Generator Loss") + axs[1].set_xlabel("Iteration") + axs[1].set_ylabel("Loss") + plt.savefig("losses.png") diff --git a/mu_map/training/loss.py b/mu_map/training/loss.py new file mode 100644 index 0000000..b62c07e --- /dev/null +++ b/mu_map/training/loss.py @@ -0,0 +1,27 @@ +import torch +import torch.nn as nn + + +class GradientDifferenceLoss(nn.Module): + """ + Gradient Difference Loss (GDL) inspired by https://github.com/mmany/pytorch-GDL/blob/main/custom_loss_functions.py. + It is modified to deal with 5D tensors (batch_size, channels, z, y, x). + """ + + def forward(self, inputs: torch.Tensor, targets: torch.Tensor): + gradient_diff_z = (inputs.diff(dim=2) - targets.diff(axis=2)).pow(2).sum() + gradient_diff_y = (inputs.diff(dim=3) - targets.diff(axis=3)).pow(2).sum() + gradient_diff_x = (inputs.diff(dim=4) - targets.diff(axis=4)).pow(2).sum() + + gradient_diff = gradient_diff_x + gradient_diff_y + gradient_diff_z + return gradient_diff / inputs.numel() + + +if __name__ == "__main__": + torch.manual_seed(10) + inputs = torch.rand((4, 1, 32, 64, 64)) + targets = torch.rand((4, 1, 32, 64, 64)) + + criterion = GradientDifferenceLoss() + loss = criterion(inputs, targets) + print(f"Loss: {loss.item():.6f}") -- GitLab