Commit 92f48bfb authored by Tamino Huxohl

implement conditional GAN training

parent 26bc3d50
import os
from typing import Dict
import torch
from torch import Tensor
from mu_map.training.loss import GradientDifferenceLoss
from mu_map.logging import get_logger
# Establish convention for real and fake labels during training
LABEL_REAL = 1.0
LABEL_FAKE = 0.0
class GeneratorLoss(torch.nn.Module):
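    """
    Combined loss for the generator of the conditional GAN: a weighted sum of a
    voxel-wise L2 (MSE) loss, a gradient difference loss (GDL) and an adversarial
    MSE loss between the discriminator's outputs for generated μ-maps and the
    discriminator targets (the label for real μ-maps).
    """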
def __init__(
self, l2_weight: float = 1.0, gdl_weight: float = 1.0, adv_weight: float = 20.0
):
super().__init__()
self.l2 = torch.nn.MSELoss(reduction="mean")
self.l2_weight = l2_weight
self.gdl = GradientDifferenceLoss()
self.gdl_weight = gdl_weight
self.adv = torch.nn.MSELoss(reduction="mean")
self.adv_weight = adv_weight
def forward(
self,
mu_maps_real: Tensor,
outputs_g: Tensor,
targets_d: Tensor,
outputs_d: Tensor,
):
loss_l2 = self.l2(outputs_g, mu_maps_real)
loss_gdl = self.gdl(outputs_g, mu_maps_real)
loss_adv = self.adv(outputs_d, targets_d)
return (
self.l2_weight * loss_l2
+ self.gdl_weight * loss_gdl
+ self.adv_weight * loss_adv
)
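# Minimal, illustrative sketch of how GeneratorLoss is called (not used by the
# training itself). The tensor shapes and the helper name are assumptions made
# only for this example: a batch of two 32^3 patches and a discriminator that
# outputs a single score per patch.
def _example_generator_loss():
    criterion = GeneratorLoss(l2_weight=1.0, gdl_weight=1.0, adv_weight=20.0)
    mu_maps_real = torch.rand((2, 1, 32, 32, 32))  # ground truth attenuation maps
    outputs_g = torch.rand((2, 1, 32, 32, 32))  # generator predictions
    outputs_d = torch.rand((2, 1))  # discriminator scores for the predictions
    targets_d = torch.full((2, 1), LABEL_REAL)  # the generator wants them judged as real
    return criterion(mu_maps_real, outputs_g, targets_d, outputs_d)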
class cGANTraining:
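    """
    Conditional GAN training of a generator that predicts μ-maps from
    reconstructions and a discriminator that judges whether a μ-map is real or
    generated. Both models use Adam optimizers with StepLR learning rate decay;
    the discriminator is trained with an MSE (least-squares GAN) criterion and
    the generator with the combined GeneratorLoss defined above.
    """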
def __init__(
self,
generator: torch.nn.Module,
discriminator: torch.nn.Module,
data_loaders: Dict[str, torch.utils.data.DataLoader],
epochs: int,
device: torch.device,
lr_d: float,
lr_decay_factor_d: float,
lr_decay_epoch_d: int,
lr_g: float,
lr_decay_factor_g: float,
lr_decay_epoch_g: int,
l2_weight: float,
gdl_weight: float,
adv_weight: float,
snapshot_dir: str,
snapshot_epoch: int,
logger=None,
):
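        """
        :param generator: model mapping reconstructions to μ-maps
        :param discriminator: model judging whether a μ-map is real or generated
        :param data_loaders: data loaders for the "train" and "validation" splits
        :param epochs: number of epochs to train
        :param device: device on which the training is performed
        :param lr_d: initial learning rate of the discriminator
        :param lr_decay_factor_d: learning rate decay factor of the discriminator
        :param lr_decay_epoch_d: step size in epochs of the discriminator's StepLR scheduler
        :param lr_g: initial learning rate of the generator
        :param lr_decay_factor_g: learning rate decay factor of the generator
        :param lr_decay_epoch_g: step size in epochs of the generator's StepLR scheduler
        :param l2_weight: weight of the L2 term in the generator loss
        :param gdl_weight: weight of the gradient difference term in the generator loss
        :param adv_weight: weight of the adversarial term in the generator loss
        :param snapshot_dir: directory in which model snapshots are stored
        :param snapshot_epoch: frequency in epochs at which snapshots are stored
        :param logger: optional logger; a default logger is created if omitted
        """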
self.generator = generator
self.discriminator = discriminator
self.data_loaders = data_loaders
self.epochs = epochs
self.device = device
self.snapshot_dir = snapshot_dir
self.snapshot_epoch = snapshot_epoch
self.logger = logger if logger is not None else get_logger()
self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)
self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g)
self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR(
self.optimizer_d,
step_size=lr_decay_epoch_d,
gamma=lr_decay_factor_d,
)
self.lr_scheduler_g = torch.optim.lr_scheduler.StepLR(
self.optimizer_g,
step_size=lr_decay_epoch_g,
gamma=lr_decay_factor_g,
)
self.criterion_d = torch.nn.MSELoss(reduction="mean")
self.criterion_g = GeneratorLoss(
l2_weight=l2_weight, gdl_weight=gdl_weight, adv_weight=adv_weight
)
def run(self):
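        """
        Run the full training: for every epoch, train on the training split,
        evaluate on the training and validation splits, step both learning rate
        schedulers and periodically store snapshots. Returns the per-batch
        discriminator and generator losses collected over all epochs.
        """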
losses_d = []
losses_g = []
for epoch in range(1, self.epochs + 1):
self.logger.debug(
f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
)
_losses_d, _losses_g = self._train_epoch()
losses_d.extend(_losses_d)
losses_g.extend(_losses_g)
self._eval_epoch(epoch, "train")
self._eval_epoch(epoch, "validation")
self.lr_scheduler_d.step()
self.lr_scheduler_g.step()
if epoch % self.snapshot_epoch == 0:
self.store_snapshot(epoch)
self.logger.debug(
f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
)
return losses_d, losses_g
def _train_epoch(self):
self.logger.debug("Train epoch")
torch.set_grad_enabled(True)
self.discriminator = self.discriminator.train()
self.generator = self.generator.train()
losses_d = []
losses_g = []
data_loader = self.data_loaders["train"]
for i, (recons, mu_maps) in enumerate(data_loader):
print(
f"Batch {str(i + 1):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
recons = recons.to(self.device)
mu_maps = mu_maps.to(self.device)
loss_d_real, loss_d_fake, loss_g = self._step(recons, mu_maps)
losses_d.append(loss_d_real + loss_d_fake)
losses_g.append(loss_g)
return losses_d, losses_g
def _step(self, recons, mu_maps_real):
batch_size = recons.shape[0]
self.optimizer_d.zero_grad()
self.optimizer_g.zero_grad()
labels_real = torch.full((batch_size, 1), LABEL_REAL, device=self.device)
labels_fake = torch.full((batch_size, 1), LABEL_FAKE, device=self.device)
with torch.set_grad_enabled(True):
# compute fake mu maps with generator
mu_maps_fake = self.generator(recons)
# update discriminator based on real mu maps
outputs_d = self.discriminator(mu_maps_real)
loss_d_real = self.criterion_d(outputs_d, labels_real)
loss_d_real.backward() # compute gradients
# update discriminator based on fake mu maps
outputs_d = self.discriminator(
mu_maps_fake.detach()
) # note the detach, so that gradients are not computed for the generator
loss_d_fake = self.criterion_d(outputs_d, labels_fake)
loss_d_fake.backward() # compute gradients
self.optimizer_d.step() # update discriminator based on gradients
# update generator
outputs_d = self.discriminator(mu_maps_fake)
loss_g = self.criterion_g(mu_maps_real, mu_maps_fake, labels_real, outputs_d)
loss_g.backward()
self.optimizer_g.step()
return loss_d_real.item(), loss_d_fake.item(), loss_g.item()
def _eval_epoch(self, epoch, split_name):
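        """
        Evaluate the generator on a data split by computing the average L1
        distance between predicted and real μ-maps and logging the result.
        """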
self.logger.debug(f"Evaluate epoch on split {split_name}")
torch.set_grad_enabled(False)
self.discriminator = self.discriminator.eval()
self.generator = self.generator.eval()
loss = 0.0
updates = 0
data_loader = self.data_loaders[split_name]
for i, (recons, mu_maps) in enumerate(data_loader):
print(
f"Batch {str(i + 1):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
recons = recons.to(self.device)
mu_maps = mu_maps.to(self.device)
outputs = self.generator(recons)
loss += torch.nn.functional.l1_loss(outputs, mu_maps).item()
updates += 1
loss = loss / updates
self.logger.info(
f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss {split_name}: {loss:.6f}"
)
def store_snapshot(self, epoch):
snapshot_file_d = f"{epoch:0{len(str(self.epochs))}d}_discriminator.pth"
snapshot_file_d = os.path.join(self.snapshot_dir, snapshot_file_d)
snapshot_file_g = f"{epoch:0{len(str(self.epochs))}d}_generator.pth"
snapshot_file_g = os.path.join(self.snapshot_dir, snapshot_file_g)
self.logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}")
torch.save(self.discriminator.state_dict(), snapshot_file_d)
torch.save(self.generator.state_dict(), snapshot_file_g)
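# Illustrative sketch (not used by the training) of why the generator output is
# detached before the discriminator update in cGANTraining._step: gradients of the
# discriminator loss must not flow back into the generator. The tiny linear models
# and tensors below are assumptions made only for this example.
def _example_detach_effect():
    generator = torch.nn.Linear(4, 4)
    discriminator = torch.nn.Linear(4, 1)
    recons = torch.rand((2, 4))
    fake = generator(recons)
    loss_d_fake = discriminator(fake.detach()).mean()
    loss_d_fake.backward()
    # because of the detach, only the discriminator received gradients
    assert generator.weight.grad is None
    assert discriminator.weight.grad is not None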
if __name__ == "__main__":
import argparse
import random
import sys
import numpy as np
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import (
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.dataset.transform import ScaleTransform
from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.models.unet import UNet
from mu_map.models.discriminator import Discriminator
parser = argparse.ArgumentParser(
description="Train a UNet generator with a conditional GAN to predict μ-maps from reconstructed scatter images",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Model Args
parser.add_argument(
"--features",
type=int,
nargs="+",
default=[64, 128, 256, 512],
help="number of features in the layers of the UNet structure",
)
# Dataset Args
parser.add_argument(
"--dataset_dir",
type=str,
default="data/initial/",
help="the directory where the dataset for training is found",
)
parser.add_argument(
"--output_scale",
type=float,
default=1.0,
help="scale the attenuation map by this coefficient",
)
parser.add_argument(
"--input_norm",
type=str,
choices=["none", "mean", "max", "gaussian"],
default="mean",
help="type of normalization applied to the reconstructions",
)
parser.add_argument(
"--patch_size",
type=int,
default=32,
help="the size of patches extracted for each reconstruction",
)
parser.add_argument(
"--patch_offset",
type=int,
default=20,
help="offset to ignore the border of the image",
)
parser.add_argument(
"--number_of_patches",
type=int,
default=1,
help="number of patches extracted for each image",
)
parser.add_argument(
"--no_shuffle",
action="store_true",
help="do not shuffle patches in the dataset",
)
# Training Args
parser.add_argument(
"--seed",
type=int,
help="seed used for random number generation",
)
parser.add_argument(
"--batch_size",
type=int,
default=8,
help="the batch size used for training",
)
parser.add_argument(
"--output_dir",
type=str,
default="train_data",
help="directory in which results (snapshots and logs) of this training are saved",
)
parser.add_argument(
"--epochs",
type=int,
default=100,
help="the number of epochs for which the model is trained",
)
parser.add_argument(
"--device",
type=str,
default="cuda:0" if torch.cuda.is_available() else "cpu",
help="the device (cpu or gpu) with which the training is performed",
)
parser.add_argument(
"--lr", type=float, default=0.001, help="the initial learning rate for training"
)
parser.add_argument(
"--lr_decay_factor",
type=float,
default=0.99,
help="decay factor for the learning rate",
)
parser.add_argument(
"--lr_decay_epoch",
type=int,
default=1,
help="frequency in epochs at which the learning rate is decayed",
)
parser.add_argument(
"--snapshot_dir",
type=str,
default="snapshots",
help="directory under --output_dir where snapshots are stored",
)
parser.add_argument(
"--snapshot_epoch",
type=int,
default=10,
help="frequency in epochs at which snapshots are stored",
)
# Logging Args
add_logging_args(parser, defaults={"--logfile": "train.log"})
args = parser.parse_args()
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir)
else:
if len(os.listdir(args.snapshot_dir)) > 0:
print(
f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
)
print("Exiting so that data is not accidentally overwritten!")
exit(1)
args.logfile = os.path.join(args.output_dir, args.logfile)
device = torch.device(args.device)
logger = get_logger_by_args(args)
logger.info(args)
args.seed = args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1)
logger.info(f"Seed: {args.seed}")
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
discriminator = Discriminator(in_channels=1, input_size=args.patch_size)
discriminator = discriminator.to(device)
generator = UNet(in_channels=1, features=args.features)
generator = generator.to(device)
transform_normalization = None
if args.input_norm == "mean":
transform_normalization = MeanNormTransform()
elif args.input_norm == "max":
transform_normalization = MaxNormTransform()
elif args.input_norm == "gaussian":
transform_normalization = GaussianNormTransform()
transform_augmentation = ScaleTransform(scale_outputs=args.output_scale)
data_loaders = {}
for split in ["train", "validation"]:
dataset = MuMapPatchDataset(
args.dataset_dir,
patches_per_image=args.number_of_patches,
patch_size=args.patch_size,
patch_offset=args.patch_offset,
shuffle=not args.no_shuffle,
split_name=split,
transform_normalization=transform_normalization,
transform_augmentation=transform_augmentation,
logger=logger,
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=1,
)
data_loaders[split] = data_loader
training = cGANTraining(
discriminator=discriminator,
generator=generator,
data_loaders=data_loaders,
epochs=args.epochs,
device=device,
lr_d=0.0005,
lr_decay_factor_d=0.99,
lr_decay_epoch_d=1,
lr_g=0.001,
lr_decay_factor_g=0.99,
lr_decay_epoch_g=1,
l2_weight=1.0,
gdl_weight=1.0,
adv_weight=20.0,
snapshot_dir=args.snapshot_dir,
snapshot_epoch=args.snapshot_epoch,
logger=logger,
)
losses_d, losses_g = training.run()
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].plot(losses_d)
axs[0].set_title("Discriminator Loss")
axs[0].set_xlabel("Iteration")
axs[0].set_ylabel("Loss")
axs[1].plot(losses_g, label="Generator")
axs[1].set_title("Generator Loss")
axs[1].set_xlabel("Iteration")
axs[1].set_ylabel("Loss")
plt.savefig("losses.png")
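mu_map/training/loss.py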
import torch
import torch.nn as nn
class GradientDifferenceLoss(nn.Module):
"""
Gradient Difference Loss (GDL) inspired by https://github.com/mmany/pytorch-GDL/blob/main/custom_loss_functions.py.
It is modified to deal with 5D tensors (batch_size, channels, z, y, x).
"""
def forward(self, inputs: torch.Tensor, targets: torch.Tensor):
gradient_diff_z = (inputs.diff(dim=2) - targets.diff(dim=2)).pow(2).sum()
gradient_diff_y = (inputs.diff(dim=3) - targets.diff(dim=3)).pow(2).sum()
gradient_diff_x = (inputs.diff(dim=4) - targets.diff(dim=4)).pow(2).sum()
gradient_diff = gradient_diff_x + gradient_diff_y + gradient_diff_z
return gradient_diff / inputs.numel()
if __name__ == "__main__":
torch.manual_seed(10)
inputs = torch.rand((4, 1, 32, 64, 64))
targets = torch.rand((4, 1, 32, 64, 64))
criterion = GradientDifferenceLoss()
loss = criterion(inputs, targets)
print(f"Loss: {loss.item():.6f}")
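    # For illustration: Tensor.diff(dim=d) is the forward finite difference along
    # dimension d, i.e. the gradients used by the loss above are simple neighbour
    # differences. The check below only demonstrates this equivalence.
    assert torch.allclose(inputs.diff(dim=2), inputs[:, :, 1:] - inputs[:, :, :-1])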