From 730f362325d3bdb34132c202e6e43db361ef5440 Mon Sep 17 00:00:00 2001 From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de> Date: Fri, 6 Jan 2023 11:59:18 +0100 Subject: [PATCH] replace old cgan with new cgan --- mu_map/training/cgan.py | 320 +++++++++++++----------------- mu_map/training/cgan2.py | 412 --------------------------------------- 2 files changed, 137 insertions(+), 595 deletions(-) delete mode 100644 mu_map/training/cgan2.py diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py index c2c9f16..8880c46 100644 --- a/mu_map/training/cgan.py +++ b/mu_map/training/cgan.py @@ -1,51 +1,91 @@ -from dataclasses import dataclass -import os -from typing import Dict, Optional -import sys +from logging import Logger +from typing import Optional import torch -from torch import Tensor +from mu_map.dataset.default import MuMapDataset +from mu_map.training.lib import TrainingParams, AbstractTraining from mu_map.training.loss import WeightedLoss -from mu_map.logging import get_logger + # Establish convention for real and fake labels during training LABEL_REAL = 1.0 LABEL_FAKE = 0.0 -@dataclass -class TrainingParams: - model: torch.nn.Module - optimizer: torch.optim.Optimizer - lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] +class DiscriminatorParams(TrainingParams): + """ + Wrap training parameters to always carry the name 'Discriminator'. + """ + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], + ): + super().__init__( + name="Discriminator", + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) + +class GeneratorParams(TrainingParams): + """ + Wrap training parameters to always carry the name 'Generator'. + """ + def __init__( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], + ): + super().__init__( + name="Generator", + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + ) + -class cGANTraining: +class cGANTraining(AbstractTraining): + """ + Implementation of a conditional generative adversarial network training. + """ def __init__( self, - data_loaders: Dict[str, torch.utils.data.DataLoader], epochs: int, + dataset: MuMapDataset, + batch_size: int, device: torch.device, snapshot_dir: str, snapshot_epoch: int, - params_generator: torch.nn.Module, - params_discriminator: torch.nn.Module, + params_generator: GeneratorParams, + params_discriminator: DiscriminatorParams, loss_func_dist: WeightedLoss, weight_criterion_dist: float, weight_criterion_adv: float, - logger=None, + logger: Optional[Logger] = None, ): - self.data_loaders = data_loaders - self.epochs = epochs - self.device = device + """ + :param params_generator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator + :param params_discriminator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator + :param loss_func_dist: distance loss function for the generator + :param weight_criterion_dist: weight of the distance loss when training the generator + :param weight_criterion_adv: weight of the adversarial loss when training the generator + """ + super().__init__( + epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger + ) + self.training_params.append(params_generator) + self.training_params.append(params_discriminator) - self.snapshot_dir = snapshot_dir - self.snapshot_epoch = snapshot_epoch - self.logger = logger if logger is not None else get_logger(name=cGANTraining.__name__) + self.generator = params_generator.model + self.discriminator = params_discriminator.model - self.params_g = params_generator - self.params_d = params_discriminator + self.optim_g = params_generator.optimizer + self.optim_d = params_discriminator.optimizer self.weight_criterion_dist = weight_criterion_dist self.weight_criterion_adv = weight_criterion_adv @@ -53,134 +93,66 @@ class cGANTraining: self.criterion_adv = torch.nn.MSELoss(reduction="mean") self.criterion_dist = loss_func_dist - def run(self): - loss_val_min = sys.maxsize - for epoch in range(1, self.epochs + 1): - str_epoch = f"{str(epoch):>{len(str(self.epochs))}}" - self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...") - - self._train_epoch() - loss_train = self._eval_epoch("train") - self.logger.info( - f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}" - ) - loss_val = self._eval_epoch("validation") - self.logger.info( - f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}" - ) - - if loss_val < loss_val_min: - loss_val_min = loss_val - self.logger.info( - f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss" - ) - self.store_snapshot("val_min") - if epoch % self.snapshot_epoch == 0: - self._store_snapshot(epoch) - - if self.params_d.lr_scheduler is not None: - self.params_d.lr_scheduler.step() - if self.params_g.lr_scheduler is not None: - self.params_g.lr_scheduler.step() - return loss_val_min - - def _train_epoch(self): - # setup training mode - torch.set_grad_enabled(True) - self.params_d.model.train() - self.params_g.model.train() - - data_loader = self.data_loaders["train"] - for i, (recons, mu_maps_real) in enumerate(data_loader): - print( - f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", - end="\r", - ) - batch_size = recons.shape[0] - - recons = recons.to(self.device) - mu_maps_real = mu_maps_real.to(self.device) - - self.params_d.optimizer.zero_grad() - self.params_g.optimizer.zero_grad() - - # compute fake mu maps with generator - mu_maps_fake = self.params_g.model(recons) - - # compute discriminator loss for fake mu maps - inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1) - outputs_d_fake = self.params_d.model( - inputs_d_fake.detach() - ) # note the detach, so that gradients are not computed for the generator - labels_fake = torch.full( - (outputs_d_fake.shape), LABEL_FAKE, device=self.device - ) - loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake) - - # compute discriminator loss for real mu maps - inputs_d_real = torch.cat((recons, mu_maps_real), dim=1) - outputs_d_real = self.params_d.model( - inputs_d_real - ) # note the detach, so that gradients are not computed for the generator - labels_real = torch.full( - (outputs_d_fake.shape), LABEL_REAL, device=self.device - ) - loss_d_real = self.criterion_adv(outputs_d_real, labels_real) - - # update discriminator - loss_d = 0.5 * (loss_d_fake + loss_d_real) - loss_d.backward() # compute gradients - self.params_d.optimizer.step() - - # update generator - inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1) - outputs_d_fake = self.params_d.model(inputs_d_fake) - loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real) - loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real) - loss_g = ( - self.weight_criterion_adv * loss_g_adv - + self.weight_criterion_dist * loss_g_dist - ) - loss_g.backward() - self.params_g.optimizer.step() - - def _eval_epoch(self, split_name): - # setup evaluation mode - torch.set_grad_enabled(False) - self.params_d.model = self.params_d.model.eval() - self.params_g.model = self.params_g.model.eval() - data_loader = self.data_loaders[split_name] - - loss = 0.0 - updates = 0 - for i, (recons, mu_maps) in enumerate(data_loader): - print( - f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", - end="\r", - ) - recons = recons.to(self.device) - mu_maps = mu_maps.to(self.device) - - outputs = self.params_g.model(recons) - - loss += torch.nn.functional.l1_loss(outputs, mu_maps) - updates += 1 - return loss / updates + def _after_train_batch(self): + """ + Overwrite calling step on all optimizers as this needs to be done + separately for the generator and discriminator during the training of + a batch. + """ + pass + + def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float: + mu_maps_real = mu_maps # rename real mu maps for clarification + # compute fake mu maps with generator + mu_maps_fake = self.generator(recons) + + # prepare inputs for the discriminator + inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1) + inputs_d_real = torch.cat((recons, mu_maps_real), dim=1) + + # prepare labels/targets for the discriminator + labels_fake = torch.full(self.discriminator.get_output_shape(inputs_d_fake.shape), LABEL_FAKE, device=self.device) + labels_real = torch.full(self.discriminator.get_output_shape(inputs_d_real.shape), LABEL_REAL, device=self.device) + + # ======================= Discriminator ===================================== + # compute discriminator loss for fake mu maps + # detach is called so that gradients are not computed for the generator + outputs_d_fake = self.discriminator(inputs_d_fake.detach()) + loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake) + + # compute discriminator loss for real mu maps + outputs_d_real = self.discriminator(inputs_d_real) + loss_d_real = self.criterion_adv(outputs_d_real, labels_real) + + # update discriminator + loss_d = 0.5 * (loss_d_fake + loss_d_real) + loss_d.backward() # compute gradients + self.optim_d.step() + # =========================================================================== + + # ======================= Generator ========================================= + outputs_d_fake = self.discriminator(inputs_d_fake) # this time no detach + loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real) + loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real) + loss_g = ( + self.weight_criterion_adv * loss_g_adv + + self.weight_criterion_dist * loss_g_dist + ) + loss_g.backward() + self.optim_g.step() + # =========================================================================== - def _store_snapshot(self, epoch): - prefix = f"{epoch:0{len(str(self.epochs))}d}" - self.store_snapshot(prefix) + return loss_g.item() - def store_snapshot(self, prefix: str): - snapshot_file_d = os.path.join(self.snapshot_dir, f"{prefix}_discriminator.pth") - snapshot_file_g = os.path.join(self.snapshot_dir, f"{prefix}_generator.pth") - self.logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}") - torch.save(self.params_d.model.state_dict(), snapshot_file_d) - torch.save(self.params_g.model.state_dict(), snapshot_file_g) + def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float: + mu_maps_fake = self.generator(recons) + loss = torch.nn.functional.l1_loss(mu_maps_fake, mu_maps) + return loss.item() if __name__ == "__main__": import argparse + import os import random import sys @@ -198,7 +170,7 @@ if __name__ == "__main__": from mu_map.models.discriminator import Discriminator, PatchDiscriminator parser = argparse.ArgumentParser( - description="Train a UNet model to predict μ-maps from reconstructed scatter images", + description="Train a UNet model to predict μ-maps from reconstructed images", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) @@ -334,11 +306,6 @@ if __name__ == "__main__": default=10, help="frequency in epochs at which snapshots are stored", ) - parser.add_argument( - "--generator_weights", - type=str, - help="use pre-trained weights for the generator", - ) # Logging Args add_logging_args(parser, defaults={"--logfile": "train.log"}) @@ -361,10 +328,11 @@ if __name__ == "__main__": args.logfile = os.path.join(args.output_dir, args.logfile) - device = torch.device(args.device) logger = get_logger_by_args(args) logger.info(args) + device = torch.device(args.device) + args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {args.seed}") random.seed(args.seed) @@ -382,29 +350,19 @@ if __name__ == "__main__": [transform_normalization, PadCropTranform(dim=3, size=32)] ) - data_loaders = {} - for split in ["train", "validation"]: - dataset = MuMapPatchDataset( - args.dataset_dir, - patches_per_image=args.number_of_patches, - patch_size=args.patch_size, - patch_offset=args.patch_offset, - shuffle=not args.no_shuffle, - split_name=split, - transform_normalization=transform_normalization, - scatter_correction=args.scatter_correction, - logger=logger, - ) - data_loader = torch.utils.data.DataLoader( - dataset=dataset, - batch_size=args.batch_size, - shuffle=True, - pin_memory=True, - num_workers=1, - ) - data_loaders[split] = data_loader + dataset = MuMapPatchDataset( + args.dataset_dir, + patches_per_image=args.number_of_patches, + patch_size=args.patch_size, + patch_offset=args.patch_offset, + shuffle=not args.no_shuffle, + transform_normalization=transform_normalization, + scatter_correction=args.scatter_correction, + logger=logger, + ) discriminator = Discriminator(in_channels=2, input_size=args.patch_size) + # discriminator = PatchDiscriminator(in_channels=2) discriminator = discriminator.to(device) optimizer = torch.optim.Adam( discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999) @@ -416,17 +374,12 @@ if __name__ == "__main__": if args.decay_lr else None ) - params_d = TrainingParams( + params_d = DiscriminatorParams( model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler ) generator = UNet(in_channels=1, features=args.features) generator = generator.to(device) - if args.generator_weights: - logger.debug(f"Load generator weights from {args.generator_weights}") - generator.load_state_dict( - torch.load(args.generator_weights, map_location=device) - ) optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999)) lr_scheduler = ( torch.optim.lr_scheduler.StepLR( @@ -435,7 +388,7 @@ if __name__ == "__main__": if args.decay_lr else None ) - params_g = TrainingParams( + params_g = GeneratorParams( model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler ) @@ -443,16 +396,17 @@ if __name__ == "__main__": logger.debug(f"Use distance criterion: {dist_criterion}") training = cGANTraining( - data_loaders=data_loaders, epochs=args.epochs, + dataset=dataset, + batch_size=args.batch_size, device=device, snapshot_dir=args.snapshot_dir, snapshot_epoch=args.snapshot_epoch, - logger=logger, params_generator=params_g, params_discriminator=params_d, loss_func_dist=dist_criterion, weight_criterion_dist=args.dist_loss_weight, weight_criterion_adv=args.adv_loss_weight, + logger=logger, ) training.run() diff --git a/mu_map/training/cgan2.py b/mu_map/training/cgan2.py deleted file mode 100644 index 8880c46..0000000 --- a/mu_map/training/cgan2.py +++ /dev/null @@ -1,412 +0,0 @@ -from logging import Logger -from typing import Optional - -import torch - -from mu_map.dataset.default import MuMapDataset -from mu_map.training.lib import TrainingParams, AbstractTraining -from mu_map.training.loss import WeightedLoss - - -# Establish convention for real and fake labels during training -LABEL_REAL = 1.0 -LABEL_FAKE = 0.0 - - -class DiscriminatorParams(TrainingParams): - """ - Wrap training parameters to always carry the name 'Discriminator'. - """ - def __init__( - self, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], - ): - super().__init__( - name="Discriminator", - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - ) - -class GeneratorParams(TrainingParams): - """ - Wrap training parameters to always carry the name 'Generator'. - """ - def __init__( - self, - model: torch.nn.Module, - optimizer: torch.optim.Optimizer, - lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler], - ): - super().__init__( - name="Generator", - model=model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - ) - - - -class cGANTraining(AbstractTraining): - """ - Implementation of a conditional generative adversarial network training. - """ - def __init__( - self, - epochs: int, - dataset: MuMapDataset, - batch_size: int, - device: torch.device, - snapshot_dir: str, - snapshot_epoch: int, - params_generator: GeneratorParams, - params_discriminator: DiscriminatorParams, - loss_func_dist: WeightedLoss, - weight_criterion_dist: float, - weight_criterion_adv: float, - logger: Optional[Logger] = None, - ): - """ - :param params_generator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator - :param params_discriminator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator - :param loss_func_dist: distance loss function for the generator - :param weight_criterion_dist: weight of the distance loss when training the generator - :param weight_criterion_adv: weight of the adversarial loss when training the generator - """ - super().__init__( - epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger - ) - self.training_params.append(params_generator) - self.training_params.append(params_discriminator) - - self.generator = params_generator.model - self.discriminator = params_discriminator.model - - self.optim_g = params_generator.optimizer - self.optim_d = params_discriminator.optimizer - - self.weight_criterion_dist = weight_criterion_dist - self.weight_criterion_adv = weight_criterion_adv - - self.criterion_adv = torch.nn.MSELoss(reduction="mean") - self.criterion_dist = loss_func_dist - - def _after_train_batch(self): - """ - Overwrite calling step on all optimizers as this needs to be done - separately for the generator and discriminator during the training of - a batch. - """ - pass - - def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float: - mu_maps_real = mu_maps # rename real mu maps for clarification - # compute fake mu maps with generator - mu_maps_fake = self.generator(recons) - - # prepare inputs for the discriminator - inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1) - inputs_d_real = torch.cat((recons, mu_maps_real), dim=1) - - # prepare labels/targets for the discriminator - labels_fake = torch.full(self.discriminator.get_output_shape(inputs_d_fake.shape), LABEL_FAKE, device=self.device) - labels_real = torch.full(self.discriminator.get_output_shape(inputs_d_real.shape), LABEL_REAL, device=self.device) - - # ======================= Discriminator ===================================== - # compute discriminator loss for fake mu maps - # detach is called so that gradients are not computed for the generator - outputs_d_fake = self.discriminator(inputs_d_fake.detach()) - loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake) - - # compute discriminator loss for real mu maps - outputs_d_real = self.discriminator(inputs_d_real) - loss_d_real = self.criterion_adv(outputs_d_real, labels_real) - - # update discriminator - loss_d = 0.5 * (loss_d_fake + loss_d_real) - loss_d.backward() # compute gradients - self.optim_d.step() - # =========================================================================== - - # ======================= Generator ========================================= - outputs_d_fake = self.discriminator(inputs_d_fake) # this time no detach - loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real) - loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real) - loss_g = ( - self.weight_criterion_adv * loss_g_adv - + self.weight_criterion_dist * loss_g_dist - ) - loss_g.backward() - self.optim_g.step() - # =========================================================================== - - return loss_g.item() - - def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float: - mu_maps_fake = self.generator(recons) - loss = torch.nn.functional.l1_loss(mu_maps_fake, mu_maps) - return loss.item() - - -if __name__ == "__main__": - import argparse - import os - import random - import sys - - import numpy as np - - from mu_map.dataset.patches import MuMapPatchDataset - from mu_map.dataset.normalization import ( - MeanNormTransform, - MaxNormTransform, - GaussianNormTransform, - ) - from mu_map.dataset.transform import PadCropTranform, SequenceTransform - from mu_map.logging import add_logging_args, get_logger_by_args - from mu_map.models.unet import UNet - from mu_map.models.discriminator import Discriminator, PatchDiscriminator - - parser = argparse.ArgumentParser( - description="Train a UNet model to predict μ-maps from reconstructed images", - formatter_class=argparse.ArgumentDefaultsHelpFormatter, - ) - - # Model Args - parser.add_argument( - "--features", - type=int, - nargs="+", - default=[64, 128, 256, 512], - help="number of features in the layers of the UNet structure", - ) - - # Dataset Args - parser.add_argument( - "--dataset_dir", - type=str, - default="data/second/", - help="the directory where the dataset for training is found", - ) - parser.add_argument( - "--input_norm", - type=str, - choices=["none", "mean", "max", "gaussian"], - default="mean", - help="type of normalization applied to the reconstructions", - ) - parser.add_argument( - "--patch_size", - type=int, - default=32, - help="the size of patches extracted for each reconstruction", - ) - parser.add_argument( - "--patch_offset", - type=int, - default=20, - help="offset to ignore the border of the image", - ) - parser.add_argument( - "--number_of_patches", - type=int, - default=100, - help="number of patches extracted for each image", - ) - parser.add_argument( - "--no_shuffle", - action="store_true", - help="do not shuffle patches in the dataset", - ) - parser.add_argument( - "--scatter_correction", - action="store_true", - help="use the scatter corrected reconstructions in the dataset", - ) - - # Training Args - parser.add_argument( - "--seed", - type=int, - help="seed used for random number generation", - ) - parser.add_argument( - "--batch_size", - type=int, - default=64, - help="the batch size used for training", - ) - parser.add_argument( - "--output_dir", - type=str, - default="train_data", - help="directory in which results (snapshots and logs) of this training are saved", - ) - parser.add_argument( - "--epochs", - type=int, - default=100, - help="the number of epochs for which the model is trained", - ) - parser.add_argument( - "--device", - type=str, - default="cuda:0" if torch.cuda.is_available() else "cpu", - help="the device (cpu or gpu) with which the training is performed", - ) - parser.add_argument( - "--dist_loss_func", - type=str, - default="l1", - help="define the loss function used as the distance loss of the generator , e.g. 0.75*l2+0.25*gdl", - ) - parser.add_argument( - "--dist_loss_weight", - type=float, - default=100.0, - help="weight for the distance loss of the generator", - ) - parser.add_argument( - "--adv_loss_weight", - type=float, - default=1.0, - help="weight for the Adversarial-Loss of the generator", - ) - parser.add_argument( - "--lr", type=float, default=0.001, help="the initial learning rate for training" - ) - parser.add_argument( - "--decay_lr", - action="store_true", - help="decay the learning rate", - ) - parser.add_argument( - "--lr_decay_factor", - type=float, - default=0.99, - help="decay factor for the learning rate", - ) - parser.add_argument( - "--lr_decay_epoch", - type=int, - default=1, - help="frequency in epochs at which the learning rate is decayed", - ) - parser.add_argument( - "--snapshot_dir", - type=str, - default="snapshots", - help="directory under --output_dir where snapshots are stored", - ) - parser.add_argument( - "--snapshot_epoch", - type=int, - default=10, - help="frequency in epochs at which snapshots are stored", - ) - - # Logging Args - add_logging_args(parser, defaults={"--logfile": "train.log"}) - - args = parser.parse_args() - - if not os.path.exists(args.output_dir): - os.mkdir(args.output_dir) - - args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir) - if not os.path.exists(args.snapshot_dir): - os.mkdir(args.snapshot_dir) - else: - if len(os.listdir(args.snapshot_dir)) > 0: - print( - f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!" - ) - print(f" Exit so that data is not accidentally overwritten!") - exit(1) - - args.logfile = os.path.join(args.output_dir, args.logfile) - - logger = get_logger_by_args(args) - logger.info(args) - - device = torch.device(args.device) - - args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) - logger.info(f"Seed: {args.seed}") - random.seed(args.seed) - torch.manual_seed(args.seed) - np.random.seed(args.seed) - - transform_normalization = None - if args.input_norm == "mean": - transform_normalization = MeanNormTransform() - elif args.input_norm == "max": - transform_normalization = MaxNormTransform() - elif args.input_norm == "gaussian": - transform_normalization = GaussianNormTransform() - transform_normalization = SequenceTransform( - [transform_normalization, PadCropTranform(dim=3, size=32)] - ) - - dataset = MuMapPatchDataset( - args.dataset_dir, - patches_per_image=args.number_of_patches, - patch_size=args.patch_size, - patch_offset=args.patch_offset, - shuffle=not args.no_shuffle, - transform_normalization=transform_normalization, - scatter_correction=args.scatter_correction, - logger=logger, - ) - - discriminator = Discriminator(in_channels=2, input_size=args.patch_size) - # discriminator = PatchDiscriminator(in_channels=2) - discriminator = discriminator.to(device) - optimizer = torch.optim.Adam( - discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999) - ) - lr_scheduler = ( - torch.optim.lr_scheduler.StepLR( - optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor - ) - if args.decay_lr - else None - ) - params_d = DiscriminatorParams( - model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler - ) - - generator = UNet(in_channels=1, features=args.features) - generator = generator.to(device) - optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999)) - lr_scheduler = ( - torch.optim.lr_scheduler.StepLR( - optimizer, step_size=args.lr_decay_factor, gamma=args.lr_decay_factor - ) - if args.decay_lr - else None - ) - params_g = GeneratorParams( - model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler - ) - - dist_criterion = WeightedLoss.from_str(args.dist_loss_func) - logger.debug(f"Use distance criterion: {dist_criterion}") - - training = cGANTraining( - epochs=args.epochs, - dataset=dataset, - batch_size=args.batch_size, - device=device, - snapshot_dir=args.snapshot_dir, - snapshot_epoch=args.snapshot_epoch, - params_generator=params_g, - params_discriminator=params_d, - loss_func_dist=dist_criterion, - weight_criterion_dist=args.dist_loss_weight, - weight_criterion_adv=args.adv_loss_weight, - logger=logger, - ) - training.run() -- GitLab