From 730f362325d3bdb34132c202e6e43db361ef5440 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Fri, 6 Jan 2023 11:59:18 +0100
Subject: [PATCH] replace old cgan with new cgan

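The cGAN training that was developed in mu_map/training/cgan2.py
replaces the previous implementation in mu_map/training/cgan.py, and
the now redundant cgan2.py is removed.

The new implementation builds on AbstractTraining and TrainingParams
from mu_map.training.lib: cGANTraining only implements _train_batch and
_eval_batch, overrides _after_train_batch because the generator and
discriminator optimizers are stepped separately, and registers its
models via DiscriminatorParams/GeneratorParams, which wrap
TrainingParams with fixed names. The training is set up from a dataset
and a batch size instead of pre-built data loaders, and the
--generator_weights option for loading pre-trained generator weights is
dropped.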
---
 mu_map/training/cgan.py  | 320 +++++++++++++-----------------
 mu_map/training/cgan2.py | 412 ---------------------------------------
 2 files changed, 137 insertions(+), 595 deletions(-)
 delete mode 100644 mu_map/training/cgan2.py

diff --git a/mu_map/training/cgan.py b/mu_map/training/cgan.py
index c2c9f16..8880c46 100644
--- a/mu_map/training/cgan.py
+++ b/mu_map/training/cgan.py
@@ -1,51 +1,91 @@
-from dataclasses import dataclass
-import os
-from typing import Dict, Optional
-import sys
+from logging import Logger
+from typing import Optional
 
 import torch
-from torch import Tensor
 
+from mu_map.dataset.default import MuMapDataset
+from mu_map.training.lib import TrainingParams, AbstractTraining
 from mu_map.training.loss import WeightedLoss
-from mu_map.logging import get_logger
+
 
 # Establish convention for real and fake labels during training
 LABEL_REAL = 1.0
 LABEL_FAKE = 0.0
 
 
-@dataclass
-class TrainingParams:
-    model: torch.nn.Module
-    optimizer: torch.optim.Optimizer
-    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
+class DiscriminatorParams(TrainingParams):
+    """
+    Wrap training parameters to always carry the name 'Discriminator'.
+    """
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
+    ):
+        super().__init__(
+            name="Discriminator",
+            model=model,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+        )
+
+class GeneratorParams(TrainingParams):
+    """
+    Wrap training parameters to always carry the name 'Generator'.
+    """
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
+    ):
+        super().__init__(
+            name="Generator",
+            model=model,
+            optimizer=optimizer,
+            lr_scheduler=lr_scheduler,
+        )
+
 
 
-class cGANTraining:
+class cGANTraining(AbstractTraining):
+    """
+    Implementation of conditional generative adversarial network (cGAN) training.
+    """
     def __init__(
         self,
-        data_loaders: Dict[str, torch.utils.data.DataLoader],
         epochs: int,
+        dataset: MuMapDataset,
+        batch_size: int,
         device: torch.device,
         snapshot_dir: str,
         snapshot_epoch: int,
-        params_generator: torch.nn.Module,
-        params_discriminator: torch.nn.Module,
+        params_generator: GeneratorParams,
+        params_discriminator: DiscriminatorParams,
         loss_func_dist: WeightedLoss,
         weight_criterion_dist: float,
         weight_criterion_adv: float,
-        logger=None,
+        logger: Optional[Logger] = None,
     ):
-        self.data_loaders = data_loaders
-        self.epochs = epochs
-        self.device = device
+        """
+        :param params_generator: training parameters containing a model, a corresponding optimizer, and optionally a learning rate scheduler for the generator
+        :param params_discriminator: training parameters containing a model, a corresponding optimizer, and optionally a learning rate scheduler for the discriminator
+        :param loss_func_dist: distance loss function for the generator
+        :param weight_criterion_dist: weight of the distance loss when training the generator
+        :param weight_criterion_adv: weight of the adversarial loss when training the generator
+        """
+        super().__init__(
+            epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
+        )
+        self.training_params.append(params_generator)
+        self.training_params.append(params_discriminator)
 
-        self.snapshot_dir = snapshot_dir
-        self.snapshot_epoch = snapshot_epoch
-        self.logger = logger if logger is not None else get_logger(name=cGANTraining.__name__)
+        self.generator = params_generator.model
+        self.discriminator = params_discriminator.model
 
-        self.params_g = params_generator
-        self.params_d = params_discriminator
+        self.optim_g = params_generator.optimizer
+        self.optim_d = params_discriminator.optimizer
 
         self.weight_criterion_dist = weight_criterion_dist
         self.weight_criterion_adv = weight_criterion_adv
@@ -53,134 +93,66 @@ class cGANTraining:
         self.criterion_adv = torch.nn.MSELoss(reduction="mean")
         self.criterion_dist = loss_func_dist
 
-    def run(self):
-        loss_val_min = sys.maxsize
-        for epoch in range(1, self.epochs + 1):
-            str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
-            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")
-
-            self._train_epoch()
-            loss_train = self._eval_epoch("train")
-            self.logger.info(
-                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
-            )
-            loss_val = self._eval_epoch("validation")
-            self.logger.info(
-                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
-            )
-
-            if loss_val < loss_val_min:
-                loss_val_min = loss_val
-                self.logger.info(
-                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
-                )
-                self.store_snapshot("val_min")
-            if epoch % self.snapshot_epoch == 0:
-                self._store_snapshot(epoch)
-
-            if self.params_d.lr_scheduler is not None:
-                self.params_d.lr_scheduler.step()
-            if self.params_g.lr_scheduler is not None:
-                self.params_g.lr_scheduler.step()
-        return loss_val_min
-
-    def _train_epoch(self):
-        # setup training mode
-        torch.set_grad_enabled(True)
-        self.params_d.model.train()
-        self.params_g.model.train()
-
-        data_loader = self.data_loaders["train"]
-        for i, (recons, mu_maps_real) in enumerate(data_loader):
-            print(
-                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
-                end="\r",
-            )
-            batch_size = recons.shape[0]
-
-            recons = recons.to(self.device)
-            mu_maps_real = mu_maps_real.to(self.device)
-
-            self.params_d.optimizer.zero_grad()
-            self.params_g.optimizer.zero_grad()
-
-            # compute fake mu maps with generator
-            mu_maps_fake = self.params_g.model(recons)
-
-            # compute discriminator loss for fake mu maps
-            inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
-            outputs_d_fake = self.params_d.model(
-                inputs_d_fake.detach()
-            )  # note the detach, so that gradients are not computed for the generator
-            labels_fake = torch.full(
-                (outputs_d_fake.shape), LABEL_FAKE, device=self.device
-            )
-            loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake)
-
-            # compute discriminator loss for real mu maps
-            inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
-            outputs_d_real = self.params_d.model(
-                inputs_d_real
-            )  # note the detach, so that gradients are not computed for the generator
-            labels_real = torch.full(
-                (outputs_d_fake.shape), LABEL_REAL, device=self.device
-            )
-            loss_d_real = self.criterion_adv(outputs_d_real, labels_real)
-
-            # update discriminator
-            loss_d = 0.5 * (loss_d_fake + loss_d_real)
-            loss_d.backward()  # compute gradients
-            self.params_d.optimizer.step()
-
-            # update generator
-            inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
-            outputs_d_fake = self.params_d.model(inputs_d_fake)
-            loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real)
-            loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real)
-            loss_g = (
-                self.weight_criterion_adv * loss_g_adv
-                + self.weight_criterion_dist * loss_g_dist
-            )
-            loss_g.backward()
-            self.params_g.optimizer.step()
-
-    def _eval_epoch(self, split_name):
-        # setup evaluation mode
-        torch.set_grad_enabled(False)
-        self.params_d.model = self.params_d.model.eval()
-        self.params_g.model = self.params_g.model.eval()
-        data_loader = self.data_loaders[split_name]
-
-        loss = 0.0
-        updates = 0
-        for i, (recons, mu_maps) in enumerate(data_loader):
-            print(
-                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
-                end="\r",
-            )
-            recons = recons.to(self.device)
-            mu_maps = mu_maps.to(self.device)
-
-            outputs = self.params_g.model(recons)
-
-            loss += torch.nn.functional.l1_loss(outputs, mu_maps)
-            updates += 1
-        return loss / updates
+    def _after_train_batch(self):
+        """
+        Override the default call to step on all optimizers, as this is done
+        separately for the generator and the discriminator while training a
+        batch.
+        """
+        pass
+
+    def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
+        mu_maps_real = mu_maps  # rename real mu maps for clarity
+        # compute fake mu maps with generator
+        mu_maps_fake = self.generator(recons)
+
+        # prepare inputs for the discriminator
+        inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
+        inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
+
+        # prepare labels/targets for the discriminator
+        labels_fake = torch.full(self.discriminator.get_output_shape(inputs_d_fake.shape), LABEL_FAKE, device=self.device)
+        labels_real = torch.full(self.discriminator.get_output_shape(inputs_d_real.shape), LABEL_REAL, device=self.device)
+
+        # ======================= Discriminator =====================================
+        # compute discriminator loss for fake mu maps
+        # detach is called so that gradients are not computed for the generator
+        outputs_d_fake = self.discriminator(inputs_d_fake.detach())
+        loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake)
+
+        # compute discriminator loss for real mu maps
+        outputs_d_real = self.discriminator(inputs_d_real)
+        loss_d_real = self.criterion_adv(outputs_d_real, labels_real)
+
+        # update discriminator
+        loss_d = 0.5 * (loss_d_fake + loss_d_real)
+        loss_d.backward()  # compute gradients
+        self.optim_d.step()
+        # ===========================================================================
+
+        # ======================= Generator =========================================
+        outputs_d_fake = self.discriminator(inputs_d_fake)  # no detach this time, so gradients reach the generator
+        loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real)
+        loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real)
+        loss_g = (
+            self.weight_criterion_adv * loss_g_adv
+            + self.weight_criterion_dist * loss_g_dist
+        )
+        loss_g.backward()
+        self.optim_g.step()
+        # ===========================================================================
 
-    def _store_snapshot(self, epoch):
-        prefix = f"{epoch:0{len(str(self.epochs))}d}"
-        self.store_snapshot(prefix)
+        return loss_g.item()
 
-    def store_snapshot(self, prefix: str):
-        snapshot_file_d = os.path.join(self.snapshot_dir, f"{prefix}_discriminator.pth")
-        snapshot_file_g = os.path.join(self.snapshot_dir, f"{prefix}_generator.pth")
-        self.logger.debug(f"Store snapshots at {snapshot_file_d} and {snapshot_file_g}")
-        torch.save(self.params_d.model.state_dict(), snapshot_file_d)
-        torch.save(self.params_g.model.state_dict(), snapshot_file_g)
+    def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
+        mu_maps_fake = self.generator(recons)
+        loss = torch.nn.functional.l1_loss(mu_maps_fake, mu_maps)
+        return loss.item()
 
 
 if __name__ == "__main__":
     import argparse
+    import os
     import random
     import sys
 
@@ -198,7 +170,7 @@ if __name__ == "__main__":
     from mu_map.models.discriminator import Discriminator, PatchDiscriminator
 
     parser = argparse.ArgumentParser(
-        description="Train a UNet model to predict μ-maps from reconstructed scatter images",
+        description="Train a UNet model to predict μ-maps from reconstructed images",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
@@ -334,11 +306,6 @@ if __name__ == "__main__":
         default=10,
         help="frequency in epochs at which snapshots are stored",
     )
-    parser.add_argument(
-        "--generator_weights",
-        type=str,
-        help="use pre-trained weights for the generator",
-    )
 
     # Logging Args
     add_logging_args(parser, defaults={"--logfile": "train.log"})
@@ -361,10 +328,11 @@ if __name__ == "__main__":
 
     args.logfile = os.path.join(args.output_dir, args.logfile)
 
-    device = torch.device(args.device)
     logger = get_logger_by_args(args)
     logger.info(args)
 
+    device = torch.device(args.device)
+
     args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
     logger.info(f"Seed: {args.seed}")
     random.seed(args.seed)
@@ -382,29 +350,19 @@ if __name__ == "__main__":
         [transform_normalization, PadCropTranform(dim=3, size=32)]
     )
 
-    data_loaders = {}
-    for split in ["train", "validation"]:
-        dataset = MuMapPatchDataset(
-            args.dataset_dir,
-            patches_per_image=args.number_of_patches,
-            patch_size=args.patch_size,
-            patch_offset=args.patch_offset,
-            shuffle=not args.no_shuffle,
-            split_name=split,
-            transform_normalization=transform_normalization,
-            scatter_correction=args.scatter_correction,
-            logger=logger,
-        )
-        data_loader = torch.utils.data.DataLoader(
-            dataset=dataset,
-            batch_size=args.batch_size,
-            shuffle=True,
-            pin_memory=True,
-            num_workers=1,
-        )
-        data_loaders[split] = data_loader
+    dataset = MuMapPatchDataset(
+        args.dataset_dir,
+        patches_per_image=args.number_of_patches,
+        patch_size=args.patch_size,
+        patch_offset=args.patch_offset,
+        shuffle=not args.no_shuffle,
+        transform_normalization=transform_normalization,
+        scatter_correction=args.scatter_correction,
+        logger=logger,
+    )
 
     discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
+    # discriminator = PatchDiscriminator(in_channels=2)
     discriminator = discriminator.to(device)
     optimizer = torch.optim.Adam(
         discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999)
@@ -416,17 +374,12 @@ if __name__ == "__main__":
         if args.decay_lr
         else None
     )
-    params_d = TrainingParams(
+    params_d = DiscriminatorParams(
         model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
     )
 
     generator = UNet(in_channels=1, features=args.features)
     generator = generator.to(device)
-    if args.generator_weights:
-        logger.debug(f"Load generator weights from {args.generator_weights}")
-        generator.load_state_dict(
-            torch.load(args.generator_weights, map_location=device)
-        )
     optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
     lr_scheduler = (
         torch.optim.lr_scheduler.StepLR(
@@ -435,7 +388,7 @@ if __name__ == "__main__":
         if args.decay_lr
         else None
     )
-    params_g = TrainingParams(
+    params_g = GeneratorParams(
         model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
     )
 
@@ -443,16 +396,17 @@ if __name__ == "__main__":
     logger.debug(f"Use distance criterion: {dist_criterion}")
 
     training = cGANTraining(
-        data_loaders=data_loaders,
         epochs=args.epochs,
+        dataset=dataset,
+        batch_size=args.batch_size,
         device=device,
         snapshot_dir=args.snapshot_dir,
         snapshot_epoch=args.snapshot_epoch,
-        logger=logger,
         params_generator=params_g,
         params_discriminator=params_d,
         loss_func_dist=dist_criterion,
         weight_criterion_dist=args.dist_loss_weight,
         weight_criterion_adv=args.adv_loss_weight,
+        logger=logger,
     )
     training.run()
diff --git a/mu_map/training/cgan2.py b/mu_map/training/cgan2.py
deleted file mode 100644
index 8880c46..0000000
--- a/mu_map/training/cgan2.py
+++ /dev/null
@@ -1,412 +0,0 @@
-from logging import Logger
-from typing import Optional
-
-import torch
-
-from mu_map.dataset.default import MuMapDataset
-from mu_map.training.lib import TrainingParams, AbstractTraining
-from mu_map.training.loss import WeightedLoss
-
-
-# Establish convention for real and fake labels during training
-LABEL_REAL = 1.0
-LABEL_FAKE = 0.0
-
-
-class DiscriminatorParams(TrainingParams):
-    """
-    Wrap training parameters to always carry the name 'Discriminator'.
-    """
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        optimizer: torch.optim.Optimizer,
-        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
-    ):
-        super().__init__(
-            name="Discriminator",
-            model=model,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-        )
-
-class GeneratorParams(TrainingParams):
-    """
-    Wrap training parameters to always carry the name 'Generator'.
-    """
-    def __init__(
-        self,
-        model: torch.nn.Module,
-        optimizer: torch.optim.Optimizer,
-        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
-    ):
-        super().__init__(
-            name="Generator",
-            model=model,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-        )
-
-
-
-class cGANTraining(AbstractTraining):
-    """
-    Implementation of a conditional generative adversarial network training.
-    """
-    def __init__(
-        self,
-        epochs: int,
-        dataset: MuMapDataset,
-        batch_size: int,
-        device: torch.device,
-        snapshot_dir: str,
-        snapshot_epoch: int,
-        params_generator: GeneratorParams,
-        params_discriminator: DiscriminatorParams,
-        loss_func_dist: WeightedLoss,
-        weight_criterion_dist: float,
-        weight_criterion_adv: float,
-        logger: Optional[Logger] = None,
-    ):
-        """
-        :param params_generator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the generator
-        :param params_discriminator: training parameters containing a model an according optimizer and optionally a learning rate scheduler for the discriminator
-        :param loss_func_dist: distance loss function for the generator
-        :param weight_criterion_dist: weight of the distance loss when training the generator
-        :param weight_criterion_adv: weight of the adversarial loss when training the generator
-        """
-        super().__init__(
-            epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
-        )
-        self.training_params.append(params_generator)
-        self.training_params.append(params_discriminator)
-
-        self.generator = params_generator.model
-        self.discriminator = params_discriminator.model
-
-        self.optim_g = params_generator.optimizer
-        self.optim_d = params_discriminator.optimizer
-
-        self.weight_criterion_dist = weight_criterion_dist
-        self.weight_criterion_adv = weight_criterion_adv
-
-        self.criterion_adv = torch.nn.MSELoss(reduction="mean")
-        self.criterion_dist = loss_func_dist
-
-    def _after_train_batch(self):
-        """
-        Overwrite calling step on all optimizers as this needs to be done
-        separately for the generator and discriminator during the training of
-        a batch.
-        """
-        pass
-
-    def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
-        mu_maps_real = mu_maps # rename real mu maps for clarification
-        # compute fake mu maps with generator
-        mu_maps_fake = self.generator(recons)
-
-        # prepare inputs for the discriminator
-        inputs_d_fake = torch.cat((recons, mu_maps_fake), dim=1)
-        inputs_d_real = torch.cat((recons, mu_maps_real), dim=1)
-
-        # prepare labels/targets for the discriminator
-        labels_fake = torch.full(self.discriminator.get_output_shape(inputs_d_fake.shape), LABEL_FAKE, device=self.device)
-        labels_real = torch.full(self.discriminator.get_output_shape(inputs_d_real.shape), LABEL_REAL, device=self.device)
-
-        # ======================= Discriminator =====================================
-        # compute discriminator loss for fake mu maps
-        # detach is called so that gradients are not computed for the generator
-        outputs_d_fake = self.discriminator(inputs_d_fake.detach())
-        loss_d_fake = self.criterion_adv(outputs_d_fake, labels_fake)
-
-        # compute discriminator loss for real mu maps
-        outputs_d_real = self.discriminator(inputs_d_real)
-        loss_d_real = self.criterion_adv(outputs_d_real, labels_real)
-
-        # update discriminator
-        loss_d = 0.5 * (loss_d_fake + loss_d_real)
-        loss_d.backward()  # compute gradients
-        self.optim_d.step()
-        # ===========================================================================
-
-        # ======================= Generator =========================================
-        outputs_d_fake = self.discriminator(inputs_d_fake) # this time no detach
-        loss_g_adv = self.criterion_adv(outputs_d_fake, labels_real)
-        loss_g_dist = self.criterion_dist(mu_maps_fake, mu_maps_real)
-        loss_g = (
-            self.weight_criterion_adv * loss_g_adv
-            + self.weight_criterion_dist * loss_g_dist
-        )
-        loss_g.backward()
-        self.optim_g.step()
-        # ===========================================================================
-
-        return loss_g.item()
-
-    def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
-        mu_maps_fake = self.generator(recons)
-        loss = torch.nn.functional.l1_loss(mu_maps_fake, mu_maps)
-        return loss.item()
-
-
-if __name__ == "__main__":
-    import argparse
-    import os
-    import random
-    import sys
-
-    import numpy as np
-
-    from mu_map.dataset.patches import MuMapPatchDataset
-    from mu_map.dataset.normalization import (
-        MeanNormTransform,
-        MaxNormTransform,
-        GaussianNormTransform,
-    )
-    from mu_map.dataset.transform import PadCropTranform, SequenceTransform
-    from mu_map.logging import add_logging_args, get_logger_by_args
-    from mu_map.models.unet import UNet
-    from mu_map.models.discriminator import Discriminator, PatchDiscriminator
-
-    parser = argparse.ArgumentParser(
-        description="Train a UNet model to predict μ-maps from reconstructed images",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Model Args
-    parser.add_argument(
-        "--features",
-        type=int,
-        nargs="+",
-        default=[64, 128, 256, 512],
-        help="number of features in the layers of the UNet structure",
-    )
-
-    # Dataset Args
-    parser.add_argument(
-        "--dataset_dir",
-        type=str,
-        default="data/second/",
-        help="the directory where the dataset for training is found",
-    )
-    parser.add_argument(
-        "--input_norm",
-        type=str,
-        choices=["none", "mean", "max", "gaussian"],
-        default="mean",
-        help="type of normalization applied to the reconstructions",
-    )
-    parser.add_argument(
-        "--patch_size",
-        type=int,
-        default=32,
-        help="the size of patches extracted for each reconstruction",
-    )
-    parser.add_argument(
-        "--patch_offset",
-        type=int,
-        default=20,
-        help="offset to ignore the border of the image",
-    )
-    parser.add_argument(
-        "--number_of_patches",
-        type=int,
-        default=100,
-        help="number of patches extracted for each image",
-    )
-    parser.add_argument(
-        "--no_shuffle",
-        action="store_true",
-        help="do not shuffle patches in the dataset",
-    )
-    parser.add_argument(
-        "--scatter_correction",
-        action="store_true",
-        help="use the scatter corrected reconstructions in the dataset",
-    )
-
-    # Training Args
-    parser.add_argument(
-        "--seed",
-        type=int,
-        help="seed used for random number generation",
-    )
-    parser.add_argument(
-        "--batch_size",
-        type=int,
-        default=64,
-        help="the batch size used for training",
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="train_data",
-        help="directory in which results (snapshots and logs) of this training are saved",
-    )
-    parser.add_argument(
-        "--epochs",
-        type=int,
-        default=100,
-        help="the number of epochs for which the model is trained",
-    )
-    parser.add_argument(
-        "--device",
-        type=str,
-        default="cuda:0" if torch.cuda.is_available() else "cpu",
-        help="the device (cpu or gpu) with which the training is performed",
-    )
-    parser.add_argument(
-        "--dist_loss_func",
-        type=str,
-        default="l1",
-        help="define the loss function used as the distance loss of the generator , e.g. 0.75*l2+0.25*gdl",
-    )
-    parser.add_argument(
-        "--dist_loss_weight",
-        type=float,
-        default=100.0,
-        help="weight for the distance loss of the generator",
-    )
-    parser.add_argument(
-        "--adv_loss_weight",
-        type=float,
-        default=1.0,
-        help="weight for the Adversarial-Loss of the generator",
-    )
-    parser.add_argument(
-        "--lr", type=float, default=0.001, help="the initial learning rate for training"
-    )
-    parser.add_argument(
-        "--decay_lr",
-        action="store_true",
-        help="decay the learning rate",
-    )
-    parser.add_argument(
-        "--lr_decay_factor",
-        type=float,
-        default=0.99,
-        help="decay factor for the learning rate",
-    )
-    parser.add_argument(
-        "--lr_decay_epoch",
-        type=int,
-        default=1,
-        help="frequency in epochs at which the learning rate is decayed",
-    )
-    parser.add_argument(
-        "--snapshot_dir",
-        type=str,
-        default="snapshots",
-        help="directory under --output_dir where snapshots are stored",
-    )
-    parser.add_argument(
-        "--snapshot_epoch",
-        type=int,
-        default=10,
-        help="frequency in epochs at which snapshots are stored",
-    )
-
-    # Logging Args
-    add_logging_args(parser, defaults={"--logfile": "train.log"})
-
-    args = parser.parse_args()
-
-    if not os.path.exists(args.output_dir):
-        os.mkdir(args.output_dir)
-
-    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
-    if not os.path.exists(args.snapshot_dir):
-        os.mkdir(args.snapshot_dir)
-    else:
-        if len(os.listdir(args.snapshot_dir)) > 0:
-            print(
-                f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
-            )
-            print(f"           Exit so that data is not accidentally overwritten!")
-            exit(1)
-
-    args.logfile = os.path.join(args.output_dir, args.logfile)
-
-    logger = get_logger_by_args(args)
-    logger.info(args)
-
-    device = torch.device(args.device)
-
-    args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
-    logger.info(f"Seed: {args.seed}")
-    random.seed(args.seed)
-    torch.manual_seed(args.seed)
-    np.random.seed(args.seed)
-
-    transform_normalization = None
-    if args.input_norm == "mean":
-        transform_normalization = MeanNormTransform()
-    elif args.input_norm == "max":
-        transform_normalization = MaxNormTransform()
-    elif args.input_norm == "gaussian":
-        transform_normalization = GaussianNormTransform()
-    transform_normalization = SequenceTransform(
-        [transform_normalization, PadCropTranform(dim=3, size=32)]
-    )
-
-    dataset = MuMapPatchDataset(
-        args.dataset_dir,
-        patches_per_image=args.number_of_patches,
-        patch_size=args.patch_size,
-        patch_offset=args.patch_offset,
-        shuffle=not args.no_shuffle,
-        transform_normalization=transform_normalization,
-        scatter_correction=args.scatter_correction,
-        logger=logger,
-    )
-
-    discriminator = Discriminator(in_channels=2, input_size=args.patch_size)
-    # discriminator = PatchDiscriminator(in_channels=2)
-    discriminator = discriminator.to(device)
-    optimizer = torch.optim.Adam(
-        discriminator.parameters(), lr=args.lr, betas=(0.5, 0.999)
-    )
-    lr_scheduler = (
-        torch.optim.lr_scheduler.StepLR(
-            optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
-        )
-        if args.decay_lr
-        else None
-    )
-    params_d = DiscriminatorParams(
-        model=discriminator, optimizer=optimizer, lr_scheduler=lr_scheduler
-    )
-
-    generator = UNet(in_channels=1, features=args.features)
-    generator = generator.to(device)
-    optimizer = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(0.5, 0.999))
-    lr_scheduler = (
-        torch.optim.lr_scheduler.StepLR(
-            optimizer, step_size=args.lr_decay_factor, gamma=args.lr_decay_factor
-        )
-        if args.decay_lr
-        else None
-    )
-    params_g = GeneratorParams(
-        model=generator, optimizer=optimizer, lr_scheduler=lr_scheduler
-    )
-
-    dist_criterion = WeightedLoss.from_str(args.dist_loss_func)
-    logger.debug(f"Use distance criterion: {dist_criterion}")
-
-    training = cGANTraining(
-        epochs=args.epochs,
-        dataset=dataset,
-        batch_size=args.batch_size,
-        device=device,
-        snapshot_dir=args.snapshot_dir,
-        snapshot_epoch=args.snapshot_epoch,
-        params_generator=params_g,
-        params_discriminator=params_d,
-        loss_func_dist=dist_criterion,
-        weight_criterion_dist=args.dist_loss_weight,
-        weight_criterion_adv=args.adv_loss_weight,
-        logger=logger,
-    )
-    training.run()
-- 
GitLab