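"""Conditional GAN (cGAN) training of a UNet generator that predicts
attenuation maps (μ-maps) from reconstructions.

The setup follows pix2pix: the discriminator is trained with an
LSGAN-style MSE objective on real/fake (reconstruction, μ-map) pairs,
and the generator with an adversarial loss plus an L1 term.
"""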
import os
from typing import Dict
import torch
from torch import Tensor
from mu_map.training.loss import GradientDifferenceLoss
from mu_map.logging import get_logger
# Establish convention for real and fake labels during training
LABEL_REAL = 1.0
LABEL_FAKE = 0.0
# class GeneratorLoss(torch.nn.Module):
# def __init__(
# self,
# # l2_weight: float = 1.0,
# # gdl_weight: float = 1.0,
# # adv_weight: float = 20.0,
# # logger=None,
# ):
# super().__init__()
# # self.l2 = torch.nn.MSELoss(reduction="mean")
# self.l2 = torch.nn.L1Loss(reduction="mean")
# self.l2_weight = l2_weight
# self.gdl = GradientDifferenceLoss()
# self.gdl_weight = gdl_weight
# self.adv = torch.nn.MSELoss(reduction="mean")
# self.adv_weight = adv_weight
# if logger:
# logger.debug(f"GeneratorLoss: {self}")
# def __repr__(self):
# return f"{self.l2_weight:.3f} * MSELoss + {self.gdl_weight:.3f} * GDLLoss + {self.adv_weight:.3f} * AdversarialLoss"
# def forward(
# self,
# mu_maps_real: Tensor,
# outputs_g: Tensor,
# targets_d: Tensor,
# outputs_d: Tensor,
# ):
# loss_l2 = self.l2(outputs_g, mu_maps_real)
# loss_gdl = self.gdl(outputs_g, mu_maps_real)
# loss_adv = self.adv(outputs_d, targets_d)
# return (
# self.l2_weight * loss_l2
# + self.gdl_weight * loss_gdl
# + self.adv_weight * loss_adv
# )
class cGANTraining:
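    """Training loop for a conditional GAN that maps reconstructions to μ-maps.

    The discriminator is conditioned on the input by concatenating the
    reconstruction and a real or generated μ-map along the channel dimension.
    Both networks are optimized with Adam using betas of (0.5, 0.999) as in
    pix2pix.
    """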
def __init__(
self,
generator: torch.nn.Module,
discriminator: torch.nn.Module,
data_loaders: Dict[str, torch.utils.data.DataLoader],
epochs: int,
device: torch.device,
lr_d: float,
lr_decay_factor_d: float,
lr_decay_epoch_d: int,
lr_g: float,
lr_decay_factor_g: float,
lr_decay_epoch_g: int,
l2_weight: float,
gdl_weight: float,
adv_weight: float,
snapshot_dir: str,
snapshot_epoch: int,
logger=None,
):
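        """
        :param generator: the generator model mapping reconstructions to μ-maps
        :param discriminator: the discriminator model classifying (reconstruction, μ-map) pairs as real or fake
        :param data_loaders: data loaders for the "train" and "validation" splits
        :param epochs: the number of epochs to train
        :param device: the device on which training is performed
        :param lr_d: learning rate of the discriminator
        :param lr_decay_factor_d: decay factor of the discriminator's (currently disabled) StepLR scheduler
        :param lr_decay_epoch_d: step size of the discriminator's (currently disabled) StepLR scheduler
        :param lr_g: learning rate of the generator
        :param lr_decay_factor_g: decay factor of the generator's (currently disabled) StepLR scheduler
        :param lr_decay_epoch_g: step size of the generator's (currently disabled) StepLR scheduler
        :param l2_weight: weight of the L2 term of the (currently disabled) GeneratorLoss
        :param gdl_weight: weight of the GDL term of the (currently disabled) GeneratorLoss
        :param adv_weight: weight of the adversarial term of the (currently disabled) GeneratorLoss
        :param snapshot_dir: the directory where snapshots are stored
        :param snapshot_epoch: frequency in epochs at which snapshots are stored
        """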
self.generator = generator
self.discriminator = discriminator
self.data_loaders = data_loaders
self.epochs = epochs
self.device = device
self.snapshot_dir = snapshot_dir
self.snapshot_epoch = snapshot_epoch
self.logger = logger if logger is not None else get_logger()
self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d, betas=(0.5, 0.999))
self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=lr_g, betas=(0.5, 0.999))
# self.lr_scheduler_d = torch.optim.lr_scheduler.StepLR(
# self.optimizer_d,
# step_size=lr_decay_epoch_d,
# gamma=lr_decay_factor_d,
# )
# self.lr_scheduler_g = torch.optim.lr_scheduler.StepLR(
# self.optimizer_g,
# step_size=lr_decay_epoch_g,
# gamma=lr_decay_factor_g,
# )
self.criterion_d = torch.nn.MSELoss(reduction="mean")
# self.criterion_g = GeneratorLoss(
# l2_weight=l2_weight,
# gdl_weight=gdl_weight,
# adv_weight=adv_weight,
# logger=self.logger,
# )
self.criterion_l1 = torch.nn.L1Loss(reduction="mean")
def run(self):
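        """Run the full training: train and evaluate for the configured number
        of epochs and store snapshots periodically.

        :return: a tuple of lists with the per-batch discriminator and generator losses
        """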
losses_d = []
losses_g = []
for epoch in range(1, self.epochs + 1):
            self.logger.debug(
                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
            )
_losses_d, _losses_g = self._train_epoch()
losses_d.extend(_losses_d)
losses_g.extend(_losses_g)
self._eval_epoch(epoch, "train")
self._eval_epoch(epoch, "validation")
# self.lr_scheduler_d.step()
# self.lr_scheduler_g.step()
if epoch % self.snapshot_epoch == 0:
self.store_snapshot(epoch)
            self.logger.debug(
                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
            )
return losses_d, losses_g
def _train_epoch(self):
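        """Train both networks for a single epoch over the training split.

        :return: a tuple of lists with the per-batch discriminator and generator losses
        """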
logger.debug(f"Train epoch")
torch.set_grad_enabled(True)
self.discriminator = self.discriminator.train()
self.generator = self.generator.train()
losses_d = []
losses_g = []
data_loader = self.data_loaders["train"]
for i, (recons, mu_maps) in enumerate(data_loader):
print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
recons = recons.to(self.device)
mu_maps = mu_maps.to(self.device)
loss_d_real, loss_d_fake, loss_g = self._step(recons, mu_maps)
losses_d.append(loss_d_real + loss_d_fake)
losses_g.append(loss_g)
return losses_d, losses_g
def _step(self, recons, mu_maps_real):
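        """Perform a single optimization step on a batch.

        First, the discriminator is updated on generated (fake) and real
        μ-maps; the generated μ-maps are detached so that no gradients flow
        into the generator. Then the generator is updated with an adversarial
        loss (fooling the discriminator into predicting real labels) combined
        with an L1 loss against the real μ-maps, as in pix2pix.

        :return: a tuple of the discriminator's real loss, the discriminator's
                 fake loss, and the generator loss
        """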
with torch.set_grad_enabled(True):
self.optimizer_d.zero_grad()
self.optimizer_g.zero_grad()
# compute fake mu maps with generator
mu_maps_fake = self.generator(recons)
# compute discriminator loss for fake mu maps
inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
outputs_d_fake = self.discriminator(inputs_d_fake.detach()) # note the detach, so that gradients are not computed for the generator
            labels_fake = torch.full(outputs_d_fake.shape, LABEL_FAKE, device=self.device)
loss_d_fake = self.criterion_d(outputs_d_fake, labels_fake)
# compute discriminator loss for real mu maps
inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
            outputs_d_real = self.discriminator(inputs_d_real)  # no detach needed: the real mu maps do not depend on the generator
            labels_real = torch.full(outputs_d_real.shape, LABEL_REAL, device=self.device)
loss_d_real = self.criterion_d(outputs_d_real, labels_real)
# update discriminator
loss_d = 0.5 * (loss_d_fake + loss_d_real)
loss_d.backward() # compute gradients
self.optimizer_d.step()
            # update generator: run the discriminator on the fake mu maps again,
            # this time without detaching, so that gradients flow into the generator
            inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
            outputs_d_fake = self.discriminator(inputs_d_fake)
            # the generator is trained to make the discriminator predict real labels
            loss_g_adv = self.criterion_d(outputs_d_fake, labels_real)
            loss_g_l1 = self.criterion_l1(mu_maps_fake, mu_maps_real)
            loss_g = loss_g_adv + 100.0 * loss_g_l1  # L1 weighted with lambda = 100 as in pix2pix
loss_g.backward()
self.optimizer_g.step()
return loss_d_real.item(), loss_d_fake.item(), loss_g.item()
def _eval_epoch(self, epoch, split_name):
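        """Evaluate the generator on a split by computing the average L1
        distance between generated and real μ-maps.

        :param epoch: the current epoch (for logging)
        :param split_name: the split on which evaluation is performed
        """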
logger.debug(f"Evaluate epoch on split {split_name}")
torch.set_grad_enabled(False)
self.discriminator = self.discriminator.eval()
self.generator = self.generator.eval()
loss = 0.0
updates = 0
data_loader = self.data_loaders[split_name]
for i, (recons, mu_maps) in enumerate(data_loader):
print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
recons = recons.to(self.device)
mu_maps = mu_maps.to(self.device)
outputs = self.generator(recons)
            loss += torch.nn.functional.l1_loss(outputs, mu_maps).item()
updates += 1
loss = loss / updates
        self.logger.info(
            f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss {split_name}: {loss:.6f}"
        )
def store_snapshot(self, epoch):
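        """Store snapshots of the discriminator and generator weights.

        :param epoch: the current epoch, used to name the snapshot files
        """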
snapshot_file_d = f"{epoch:0{len(str(self.epochs))}d}_discriminator.pth"
snapshot_file_d = os.path.join(self.snapshot_dir, snapshot_file_d)
snapshot_file_g = f"{epoch:0{len(str(self.epochs))}d}_generator.pth"
snapshot_file_g = os.path.join(self.snapshot_dir, snapshot_file_g)
logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}")
torch.save(self.discriminator.state_dict(), snapshot_file_d)
torch.save(self.generator.state_dict(), snapshot_file_g)
if __name__ == "__main__":
import argparse
import random
import sys
import numpy as np
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import (
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.dataset.transform import ScaleTransform
from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.models.unet import UNet
from mu_map.models.discriminator import Discriminator, PatchDiscriminator
parser = argparse.ArgumentParser(
        description="Train a UNet generator with a cGAN objective to predict μ-maps from reconstructed scatter images",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Model Args
parser.add_argument(
"--features",
type=int,
nargs="+",
default=[64, 128, 256, 512],
help="number of features in the layers of the UNet structure",
)
# Dataset Args
parser.add_argument(
"--dataset_dir",
type=str,
default="data/initial/",
help="the directory where the dataset for training is found",
)
parser.add_argument(
"--input_norm",
type=str,
choices=["none", "mean", "max", "gaussian"],
default="mean",
help="type of normalization applied to the reconstructions",
)
parser.add_argument(
"--patch_size",
type=int,
default=32,
help="the size of patches extracted for each reconstruction",
)
parser.add_argument(
"--patch_offset",
type=int,
default=20,
help="offset to ignore the border of the image",
)
parser.add_argument(
"--number_of_patches",
type=int,
default=100,
help="number of patches extracted for each image",
)
parser.add_argument(
"--no_shuffle",
action="store_true",
help="do not shuffle patches in the dataset",
)
# Training Args
parser.add_argument(
"--seed",
type=int,
help="seed used for random number generation",
)
parser.add_argument(
"--batch_size",
type=int,
default=64,
help="the batch size used for training",
)
parser.add_argument(
"--output_dir",
type=str,
default="train_data",
help="directory in which results (snapshots and logs) of this training are saved",
)
parser.add_argument(
"--epochs",
type=int,
default=100,
help="the number of epochs for which the model is trained",
)
parser.add_argument(
"--device",
type=str,
default="cuda:0" if torch.cuda.is_available() else "cpu",
help="the device (cpu or gpu) with which the training is performed",
)
parser.add_argument(
"--mse_loss_weight",
type=float,
default=1.0,
help="weight for the L2-Loss of the generator",
)
parser.add_argument(
"--gdl_loss_weight",
type=float,
default=1.0,
help="weight for the Gradient-Difference-Loss of the generator",
)
parser.add_argument(
"--adv_loss_weight",
type=float,
default=20.0,
help="weight for the Adversarial-Loss of the generator",
)
parser.add_argument(
"--lr", type=float, default=0.001, help="the initial learning rate for training"
)
parser.add_argument(
"--lr_decay_factor",
type=float,
default=0.99,
help="decay factor for the learning rate",
)
parser.add_argument(
"--lr_decay_epoch",
type=int,
default=1,
help="frequency in epochs at which the learning rate is decayed",
)
parser.add_argument(
"--snapshot_dir",
type=str,
default="snapshots",
help="directory under --output_dir where snapshots are stored",
)
parser.add_argument(
"--snapshot_epoch",
type=int,
default=10,
help="frequency in epochs at which snapshots are stored",
)
parser.add_argument(
"--generator_weights",
type=str,
help="use pre-trained weights for the generator",
)
# Logging Args
add_logging_args(parser, defaults={"--logfile": "train.log"})
args = parser.parse_args()
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir)
else:
if len(os.listdir(args.snapshot_dir)) > 0:
print(
f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
)
print(f" Exit so that data is not accidentally overwritten!")
exit(1)
args.logfile = os.path.join(args.output_dir, args.logfile)
device = torch.device(args.device)
logger = get_logger_by_args(args)
logger.info(args)
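    # seed all random number generators for reproducibility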
args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {args.seed}")
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
# discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
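    # a PatchGAN-style discriminator classifying patches of the input as real or fake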
discriminator = PatchDiscriminator(in_channels=2, input_size=args.patch_size)
discriminator = discriminator.to(device)
generator = UNet(in_channels=1, features=args.features)
generator = generator.to(device)
if args.generator_weights:
logger.debug(f"Load generator weights from {args.generator_weights}")
generator.load_state_dict(torch.load(args.generator_weights, map_location=device))
transform_normalization = None
if args.input_norm == "mean":
transform_normalization = MeanNormTransform()
elif args.input_norm == "max":
transform_normalization = MaxNormTransform()
elif args.input_norm == "gaussian":
transform_normalization = GaussianNormTransform()
data_loaders = {}
for split in ["train", "validation"]:
dataset = MuMapPatchDataset(
args.dataset_dir,
patches_per_image=args.number_of_patches,
patch_size=args.patch_size,
patch_offset=args.patch_offset,
shuffle=not args.no_shuffle,
split_name=split,
transform_normalization=transform_normalization,
logger=logger,
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=1,
)
data_loaders[split] = data_loader
training = cGANTraining(
discriminator=discriminator,
generator=generator,
data_loaders=data_loaders,
epochs=args.epochs,
device=device,
        lr_d=args.lr,
        lr_decay_factor_d=args.lr_decay_factor,
        lr_decay_epoch_d=args.lr_decay_epoch,
        lr_g=args.lr,
        lr_decay_factor_g=args.lr_decay_factor,
        lr_decay_epoch_g=args.lr_decay_epoch,
l2_weight=args.mse_loss_weight,
gdl_weight=args.gdl_loss_weight,
adv_weight=args.adv_loss_weight,
snapshot_dir=args.snapshot_dir,
snapshot_epoch=args.snapshot_epoch,
logger=logger,
)
losses_d, losses_g = training.run()
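    # plot the per-batch loss curves of both networks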
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
axs[0].plot(losses_d)
axs[0].set_title("Discriminator Loss")
axs[0].set_xlabel("Iteration")
axs[0].set_ylabel("Loss")
    axs[1].plot(losses_g)
axs[1].set_title("Generator Loss")
axs[1].set_xlabel("Iteration")
axs[1].set_ylabel("Loss")
    plt.savefig(os.path.join(args.output_dir, "losses.png"))