From b36a84552c67555a2841329c1b39d8f45757e413 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 27 Sep 2022 14:43:41 +0200
Subject: [PATCH] make default training configurable via argparse

---
 mu_map/training/default.py | 203 +++++++++++++++++++++++++++++--------
 1 file changed, 163 insertions(+), 40 deletions(-)

diff --git a/mu_map/training/default.py b/mu_map/training/default.py
index 7f41f43..adf89ce 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/default.py
@@ -1,59 +1,89 @@
 import os
+from typing import Dict
 
 import torch
 
-class Training():
-
-    def __init__(self, model, data_loaders, epochs, logger):
+from mu_map.logging import get_logger
+
+
+class Training:
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        data_loaders: Dict[str, torch.utils.data.DataLoader],
+        epochs: int,
+        device: torch.device,
+        lr: float,
+        lr_decay_factor: float,
+        lr_decay_epoch: int,
+        snapshot_dir: str,
+        snapshot_epoch: int,
+        logger=None,
+    ):
         self.model = model
         self.data_loaders = data_loaders
         self.epochs = epochs
-        self.device = torch.device("cpu")
-        self.snapshot_dir = "tmp"
-        self.snapshot_epoch = 5
-        self.loss_func = torch.nn.MSELoss()
+        self.device = device
 
-        # self.lr = 1e-3
-        # self.lr_decay_factor = 0.99
-        self.lr = 0.1
-        self.lr_decay_factor = 0.5
-        self.lr_decay_epoch = 1
+        self.lr = lr
+        self.lr_decay_factor = lr_decay_factor
+        self.lr_decay_epoch = lr_decay_epoch
 
-        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
-        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor)
+        self.snapshot_dir = snapshot_dir
+        self.snapshot_epoch = snapshot_epoch
 
-        self.logger = logger
+        self.logger = logger if logger is not None else get_logger()
+
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
+            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
+        )
+        self.loss_func = torch.nn.MSELoss(reduction="mean")
 
 
     def run(self):
         for epoch in range(1, self.epochs + 1):
-            logger.debug(f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ...")
+            logger.debug(
+                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
+            )
             self._run_epoch(self.data_loaders["train"], phase="train")
 
-            loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
-            logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}")
-            loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
-            logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}")
+            loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
+            logger.info(
+                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
+            )
+            loss_validation = self._run_epoch(
+                self.data_loaders["validation"], phase="val"
+            )
+            logger.info(
+                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
+            )
 
             # ToDo: log outputs and time
             _previous = self.lr_scheduler.get_last_lr()[0]
             self.lr_scheduler.step()
-            logger.debug(f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}")
+            logger.debug(
+                f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
+            )
 
             if epoch % self.snapshot_epoch == 0:
                 self.store_snapshot(epoch)
 
-            logger.debug(f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs + 1}")
-
-            
+            logger.debug(
+                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
+            )
 
     def _run_epoch(self, data_loader, phase):
         logger.debug(f"Run epoch in phase {phase}")
         self.model.train() if phase == "train" else self.model.eval()
 
         epoch_loss = 0
+        loss_updates = 0
         for i, (inputs, labels) in enumerate(data_loader):
-            print(f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r")
+            print(
+                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
+                end="\r",
+            )
             inputs = inputs.to(self.device)
             labels = labels.to(self.device)
 
@@ -66,9 +96,9 @@ class Training():
                     loss.backward()
                     self.optimizer.step()
 
-            epoch_loss += loss.item() / inputs.shape[0]
-        return epoch_loss
-
+            epoch_loss += loss.item()
+            loss_updates += 1
+        return epoch_loss / loss_updates
 
     def store_snapshot(self, epoch):
         snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
@@ -78,20 +108,113 @@ class Training():
 
 
 if __name__ == "__main__":
-    from mu_map.data.mock import MuMapMockDataset
-    from mu_map.logging import get_logger
-    from mu_map.models.unet import UNet
+    import argparse
 
-    logger = get_logger(logfile="train.log", loglevel="DEBUG")
+    from mu_map.dataset.mock import MuMapMockDataset
+    from mu_map.logging import add_logging_args, get_logger_by_args
+    from mu_map.models.unet import UNet
 
-    model = UNet(in_channels=1, features=[8, 16])
-    print(model)
-    dataset = MuMapMockDataset()
-    data_loader_train = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
-    data_loader_val = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
+    parser = argparse.ArgumentParser(
+        description="Train a UNet model to predict μ-maps from reconstructed scatter images",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Model Args
+    parser.add_argument(
+        "--features",
+        type=int,
+        nargs="+",
+        default=[8, 16],
+        help="number of features in the layers of the UNet structure",
+    )
+
+    # Dataset Args
+    # parser.add_argument("--features", type=int, nargs="+", default=[8, 16], help="number of features in the layers of the UNet structure")
+
+    # Training Args
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="train_data",
+        help="directory in which results (snapshots and logs) of this training are saved",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=10,
+        help="the number of epochs for which the model is trained",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0" if torch.cuda.is_available() else "cpu",
+        help="the device (cpu or gpu) with which the training is performed",
+    )
+    parser.add_argument(
+        "--lr", type=float, default=0.1, help="the initial learning rate for training"
+    )
+    parser.add_argument(
+        "--lr_decay_factor",
+        type=float,
+        default=0.99,
+        help="decay factor for the learning rate",
+    )
+    parser.add_argument(
+        "--lr_decay_epoch",
+        type=int,
+        default=1,
+        help="frequency in epochs at which the learning rate is decayed",
+    )
+    parser.add_argument(
+        "--snapshot_dir",
+        type=str,
+        default="snapshots",
+        help="directory under --output_dir where snapshots are stored",
+    )
+    parser.add_argument(
+        "--snapshot_epoch",
+        type=int,
+        default=10,
+        help="frequency in epochs at which snapshots are stored",
+    )
+
+    # Logging Args
+    add_logging_args(parser, defaults={"--logfile": "train.log"})
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.output_dir):
+        os.mkdir(args.output_dir)
+
+    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
+    if not os.path.exists(args.snapshot_dir):
+        os.mkdir(args.snapshot_dir)
+
+    args.logfile = os.path.join(args.output_dir, args.logfile)
+
+    device = torch.device(args.device)
+    logger = get_logger_by_args(args)
+
+    model = UNet(in_channels=1, features=args.features)
+    dataset = MuMapMockDataset(logger=logger)
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
+    )
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
+    )
     data_loaders = {"train": data_loader_train, "validation": data_loader_val}
 
-    training = Training(model, data_loaders, 10, logger)
+    training = Training(
+        model=model,
+        data_loaders=data_loaders,
+        epochs=args.epochs,
+        device=device,
+        lr=args.lr,
+        lr_decay_factor=args.lr_decay_factor,
+        lr_decay_epoch=args.lr_decay_epoch,
+        snapshot_dir=args.snapshot_dir,
+        snapshot_epoch=args.snapshot_epoch,
+        logger=logger,
+    )
     training.run()
-
-
-- 
GitLab