from typing import Dict, Optional
import os
import sys
import torch
from torch import Tensor
from mu_map.training.loss import WeightedLoss
from mu_map.logging import get_logger
# Establish convention for real and fake labels during training
LABEL_REAL = 1.0
LABEL_FAKE = 0.0
from dataclasses import dataclass
@dataclass
class TrainingParams:
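"""Group a model with the optimizer and optional LR scheduler used to train it."""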
model: torch.nn.Module
optimizer: torch.optim.Optimizer
lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
class cGANTraining:
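"""
Conditional GAN training for mu-map prediction: the generator maps reconstructions
to mu-maps, while the discriminator judges concatenated (reconstruction, mu-map) pairs.
The generator is optimized with a weighted sum of an adversarial and a distance loss.
"""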
def __init__(
self,
data_loaders: Dict[str, torch.utils.data.DataLoader],
epochs: int,
device: torch.device,
snapshot_dir: str,
snapshot_epoch: int,
params_generator: TrainingParams,
params_discriminator: TrainingParams,
loss_func_dist: WeightedLoss,
weight_criterion_dist: float,
weight_criterion_adv: float,
logger=None,
):
self.generator = params_generator.model
self.discriminator = params_discriminator.model
self.data_loaders = data_loaders
self.epochs = epochs
self.device = device
self.snapshot_dir = snapshot_dir
self.snapshot_epoch = snapshot_epoch
self.logger = logger if logger is not None else get_logger()
self.params_g = params_generator
self.params_d = params_discriminator
self.weight_criterion_dist = weight_criterion_dist
self.weight_criterion_adv = weight_criterion_adv
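# adversarial criterion: mean squared error on the real/fake labels (a least-squares GAN objective)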
self.criterion_adv = torch.nn.MSELoss(reduction="mean")
self.criterion_dist = loss_func_dist

def run(self) -> float:
loss_val_min = sys.maxsize
for epoch in range(1, self.epochs + 1):
str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
self._train_epoch()
loss_train = self._eval_epoch("train")
self.logger.info(
f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
)
loss_val = self._eval_epoch("validation")
self.logger.info(
f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
)
if loss_val < loss_val_min:
loss_val_min = loss_val
self.logger.info(
f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
)
self.store_snapshot("val_min")
if epoch % self.snapshot_epoch == 0:
self._store_snapshot(epoch)
if self.params_d.lr_scheduler is not None:
logger.debug("Step LR scheduler of discriminator")
self.params_d.lr_scheduler.step()
if self.params_g.lr_scheduler is not None:
logger.debug("Step LR scheduler of generator")
self.params_g.lr_scheduler.step()
return loss_val
def _train_epoch(self):
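"""Run a single training epoch: first update the discriminator on real and
(detached) fake pairs, then update the generator."""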
logger.debug(f"Train epoch")
# setup training mode
torch.set_grad_enabled(True)
self.params_d.model.train()
self.params_g.model.train()
data_loader = self.data_loaders["train"]
for i, (recons, mu_maps) in enumerate(data_loader):
print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
batch_size = recons.shape[0]
recons = recons.to(self.device)
mu_maps = mu_maps.to(self.device)
self.params_d.optimizer.zero_grad()
self.params_g.optimizer.zero_grad()
# compute fake mu maps with generator
mu_maps_fake = self.params_g.model(recons)
# compute discriminator loss for fake mu maps
inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
outputs_d_fake = self.params_d.model(
inputs_d_fake.detach()
) # note the detach, so that gradients are not computed for the generator
labels_fake = torch.full(
(outputs_d_fake.shape), LABEL_FAKE, device=self.device
)
loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake)
# compute discriminator loss for real mu maps
inputs_d_real = torch.cat((recons, mu_maps), dim=1)
outputs_d_real = self.params_d.model(inputs_d_real)
labels_real = torch.full(
outputs_d_real.shape, LABEL_REAL, device=self.device
)
loss_d_real = self.criterion_adv(outputs_d_real, labels_real)
# update discriminator
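# averaging the fake and real terms halves the discriminator's learning signal relative to the generator (as in pix2pix)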
loss_d = 0.5 * (loss_d_fake + loss_d_real)
loss_d.backward() # compute gradients
self.params_d.optimizer.step()
# update generator: run the discriminator on the fake mu maps again, this time
# without detach, so that gradients flow back into the generator
inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
outputs_d_fake = self.params_d.model(inputs_d_fake)
loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real)
loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps)
loss_g = (
self.weight_criterion_adv * loss_g_adv
+ self.weight_criterion_dist * loss_g_dist
)
loss_g.backward()  # compute gradients
self.params_g.optimizer.step()
def _eval_epoch(self, split_name: str) -> float:
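"""Evaluate the generator on a dataset split and return its mean L1 distance to the true mu maps."""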
# setup evaluation mode
torch.set_grad_enabled(False)
self.discriminator = self.discriminator.eval()
self.generator = self.generator.eval()
data_loader = self.data_loaders[split_name]
loss = 0.0
updates = 0
for i, (recons, mu_maps) in enumerate(data_loader):
print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
recons = recons.to(self.device)
mu_maps = mu_maps.to(self.device)
outputs = self.generator(recons)
loss += torch.nn.functional.l1_loss(outputs, mu_maps).item()
updates += 1
return loss / updates
def _store_snapshot(self, epoch):
prefix = f"{epoch:0{len(str(self.epochs))}d}"
self.store_snapshot(prefix)
def store_snapshot(self, prefix: str):
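"""Store snapshots of the discriminator and generator under the given filename prefix."""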
snapshot_file_d = os.path.join(self.snapshot_dir, f"{prefix}_discriminator.pth")
snapshot_file_g = os.path.join(self.snapshot_dir, f"{prefix}_generator.pth")
logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}")
torch.save(self.discriminator.state_dict(), snapshot_file_d)
torch.save(self.generator.state_dict(), snapshot_file_g)
if __name__ == "__main__":
import argparse
import random
import sys
import numpy as np
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import (
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.dataset.transform import PadCropTranform, SequenceTransform
from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.models.unet import UNet
from mu_map.models.discriminator import Discriminator, PatchDiscriminator
parser = argparse.ArgumentParser(
description="Train a UNet model to predict μ-maps from reconstructed scatter images",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Model Args
parser.add_argument(
"--features",
type=int,
nargs="+",
default=[64, 128, 256, 512],
help="number of features in the layers of the UNet structure",
)
# Dataset Args
parser.add_argument(
"--dataset_dir",
type=str,
help="the directory where the dataset for training is found",
)
parser.add_argument(
"--input_norm",
type=str,
choices=["none", "mean", "max", "gaussian"],
default="mean",
help="type of normalization applied to the reconstructions",
)
parser.add_argument(
"--patch_size",
type=int,
default=32,
help="the size of patches extracted for each reconstruction",
)
parser.add_argument(
"--patch_offset",
type=int,
default=20,
help="offset to ignore the border of the image",
)
parser.add_argument(
"--number_of_patches",
type=int,
help="number of patches extracted for each image",
)
parser.add_argument(
"--no_shuffle",
action="store_true",
help="do not shuffle patches in the dataset",
)
parser.add_argument(
"scatter_correction",
action="store_true",
help="use the scatter corrected reconstructions in the dataset",
)
# Training Args
parser.add_argument(
"--seed",
type=int,
help="seed used for random number generation",
)
parser.add_argument(
"--batch_size",
type=int,
help="the batch size used for training",
)
parser.add_argument(
"--output_dir",
type=str,
default="train_data",
help="directory in which results (snapshots and logs) of this training are saved",
)
parser.add_argument(
"--epochs",
type=int,
default=100,
help="the number of epochs for which the model is trained",
)
parser.add_argument(
"--device",
type=str,
default="cuda:0" if torch.cuda.is_available() else "cpu",
help="the device (cpu or gpu) with which the training is performed",
)
"--dist_loss_func",
type=str,
default="l1",
help="define the loss function used as the distance loss of the generator , e.g. 0.75*l2+0.25*gdl",
"--dist_loss_weight",
default=100.0,
help="weight for the distance loss of the generator",
)
parser.add_argument(
"--adv_loss_weight",
type=float,
help="weight for the Adversarial-Loss of the generator",
)
parser.add_argument(
"--lr", type=float, default=0.001, help="the initial learning rate for training"
)
parser.add_argument(
"--decay_lr",
action="store_true",
help="decay the learning rate",
)
parser.add_argument(
"--lr_decay_factor",
type=float,
default=0.99,
help="decay factor for the learning rate",
)
parser.add_argument(
"--lr_decay_epoch",
type=int,
default=1,
help="frequency in epochs at which the learning rate is decayed",
)
parser.add_argument(
"--snapshot_dir",
type=str,
default="snapshots",
help="directory under --output_dir where snapshots are stored",
)
parser.add_argument(
"--snapshot_epoch",
type=int,
default=10,
help="frequency in epochs at which snapshots are stored",
)
parser.add_argument(
"--generator_weights",
type=str,
help="use pre-trained weights for the generator",
)
# Logging Args
add_logging_args(parser, defaults={"--logfile": "train.log"})
args = parser.parse_args()
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir)
else:
if len(os.listdir(args.snapshot_dir)) > 0:
print(
f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
)
print(f" Exit so that data is not accidentally overwritten!")
exit(1)
args.logfile = os.path.join(args.output_dir, args.logfile)
device = torch.device(args.device)
logger = get_logger_by_args(args)
logger.info(args)
args.seed = args.seed if args.seed is not None else random.randint(0, 2 ** 32 - 1)
logger.info(f"Seed: {args.seed}")
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
transform_normalization = None
if args.input_norm == "mean":
transform_normalization = MeanNormTransform()
elif args.input_norm == "max":
transform_normalization = MaxNormTransform()
elif args.input_norm == "gaussian":
transform_normalization = GaussianNormTransform()
transforms = [PadCropTranform(dim=3, size=32)]
if transform_normalization is not None:
transforms = [transform_normalization] + transforms
transform_normalization = SequenceTransform(transforms)
data_loaders = {}
for split in ["train", "validation"]:
dataset = MuMapPatchDataset(
args.dataset_dir,
patches_per_image=args.number_of_patches,
patch_size=args.patch_size,
patch_offset=args.patch_offset,
shuffle=not args.no_shuffle,
split_name=split,
transform_normalization=transform_normalization,
scatter_correction=args.scatter_correction,
logger=logger,
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=1,
)
data_loaders[split] = data_loader
discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
discriminator = discriminator.to(device)
optimizer = torch.optim.Adam(
discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999)
)
lr_scheduler = (
torch.optim.lr_scheduler.StepLR(
optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
)
if args.decay_lr
else None
)
params_d = TrainingParams(
model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
)
generator = UNet(in_channels=1, features=args.features)
generator = generator.to(device)
if args.generator_weights:
logger.debug(f"Load generator weights from {args.generator_weights}")
generator.load_state_dict(
torch.load(args.generator_weights, map_location=device)
)
optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
lr_scheduler = (
torch.optim.lr_scheduler.StepLR(
optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
)
if args.decay_lr
else None
)
params_g = TrainingParams(
model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
)
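# WeightedLoss.from_str parses expressions such as "0.75*l2+0.25*gdl" (see --dist_loss_func) into a weighted sum of loss terms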
dist_criterion = WeightedLoss.from_str(args.dist_loss_func)
logger.debug(f"Use distance criterion: {criterion}")
training = cGANTraining(
data_loaders=data_loaders,
epochs=args.epochs,
device=device,
snapshot_dir=args.snapshot_dir,
snapshot_epoch=args.snapshot_epoch,
logger=logger,
params_generator=params_g,
params_discriminator=params_d,
loss_func_dist=dist_criterion,
weight_criterion_dist=args.dist_loss_weight,
weight_criterion_adv=args.adv_loss_weight,
)
training.run()