From b36a84552c67555a2841329c1b39d8f45757e413 Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Tue, 27 Sep 2022 14:43:41 +0200
Subject: [PATCH] make default training configurable via argparse

---
 mu_map/training/default.py | 203 +++++++++++++++++++++++++++++--------
 1 file changed, 163 insertions(+), 40 deletions(-)

diff --git a/mu_map/training/default.py b/mu_map/training/default.py
index 7f41f43..adf89ce 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/default.py
@@ -1,59 +1,89 @@
 import os
+from typing import Dict

 import torch

-class Training():
-
-    def __init__(self, model, data_loaders, epochs, logger):
+from mu_map.logging import get_logger
+
+
+class Training:
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        data_loaders: Dict[str, torch.utils.data.DataLoader],
+        epochs: int,
+        device: torch.device,
+        lr: float,
+        lr_decay_factor: float,
+        lr_decay_epoch: int,
+        snapshot_dir: str,
+        snapshot_epoch: int,
+        logger=None,
+    ):
         self.model = model
         self.data_loaders = data_loaders
         self.epochs = epochs

-        self.device = torch.device("cpu")
-        self.snapshot_dir = "tmp"
-        self.snapshot_epoch = 5
-        self.loss_func = torch.nn.MSELoss()
+        self.device = device

-        # self.lr = 1e-3
-        # self.lr_decay_factor = 0.99
-        self.lr = 0.1
-        self.lr_decay_factor = 0.5
-        self.lr_decay_epoch = 1
+        self.lr = lr
+        self.lr_decay_factor = lr_decay_factor
+        self.lr_decay_epoch = lr_decay_epoch

-        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
-        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor)
+        self.snapshot_dir = snapshot_dir
+        self.snapshot_epoch = snapshot_epoch

-        self.logger = logger
+        self.logger = logger if logger is not None else get_logger()
+
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
+            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
+        )
+        self.loss_func = torch.nn.MSELoss(reduction="mean")

     def run(self):
         for epoch in range(1, self.epochs + 1):
-            logger.debug(f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ...")
+            logger.debug(
+                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
+            )
             self._run_epoch(self.data_loaders["train"], phase="train")
-            loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
-            logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}")
-            loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
-            logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}")
+            loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
+            logger.info(
+                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
+            )
+            loss_validation = self._run_epoch(
+                self.data_loaders["validation"], phase="val"
+            )
+            logger.info(
+                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
+            )

             # ToDo: log outputs and time

             _previous = self.lr_scheduler.get_last_lr()[0]
             self.lr_scheduler.step()
-            logger.debug(f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}")
+            logger.debug(
+                f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
+            )

             if epoch % self.snapshot_epoch == 0:
                 self.store_snapshot(epoch)

-            logger.debug(f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs + 1}")
-
-
+            logger.debug(
+                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
+            )

     def _run_epoch(self, data_loader, phase):
         logger.debug(f"Run epoch in phase {phase}")
         self.model.train() if phase == "train" else self.model.eval()

         epoch_loss = 0
+        loss_updates = 0
         for i, (inputs, labels) in enumerate(data_loader):
-            print(f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r")
+            print(
+                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
+                end="\r",
+            )

             inputs = inputs.to(self.device)
             labels = labels.to(self.device)
@@ -66,9 +96,9 @@ class Training():
             loss.backward()
             self.optimizer.step()

-            epoch_loss += loss.item() / inputs.shape[0]
-        return epoch_loss
-
+            epoch_loss += loss.item()
+            loss_updates += 1
+        return epoch_loss / loss_updates

     def store_snapshot(self, epoch):
         snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
@@ -78,20 +108,113 @@ class Training():


 if __name__ == "__main__":
-    from mu_map.data.mock import MuMapMockDataset
-    from mu_map.logging import get_logger
-    from mu_map.models.unet import UNet
+    import argparse

-    logger = get_logger(logfile="train.log", loglevel="DEBUG")
+    from mu_map.dataset.mock import MuMapMockDataset
+    from mu_map.logging import add_logging_args, get_logger_by_args
+    from mu_map.models.unet import UNet

-    model = UNet(in_channels=1, features=[8, 16])
-    print(model)
-    dataset = MuMapMockDataset()
-    data_loader_train = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
-    data_loader_val = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
+    parser = argparse.ArgumentParser(
+        description="Train a UNet model to predict μ-maps from reconstructed scatter images",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Model Args
+    parser.add_argument(
+        "--features",
+        type=int,
+        nargs="+",
+        default=[8, 16],
+        help="number of features in the layers of the UNet structure",
+    )
+
+    # Dataset Args
+    # parser.add_argument("--features", type=int, nargs="+", default=[8, 16], help="number of features in the layers of the UNet structure")
+
+    # Training Args
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="train_data",
+        help="directory in which results (snapshots and logs) of this training are saved",
+    )
+    parser.add_argument(
+        "--epochs",
+        type=int,
+        default=10,
+        help="the number of epochs for which the model is trained",
+    )
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="cuda:0" if torch.cuda.is_available() else "cpu",
+        help="the device (cpu or gpu) with which the training is performed",
+    )
+    parser.add_argument(
+        "--lr", type=float, default=0.1, help="the initial learning rate for training"
+    )
+    parser.add_argument(
+        "--lr_decay_factor",
+        type=float,
+        default=0.99,
+        help="decay factor for the learning rate",
+    )
+    parser.add_argument(
+        "--lr_decay_epoch",
+        type=int,
+        default=1,
+        help="frequency in epochs at which the learning rate is decayed",
+    )
+    parser.add_argument(
+        "--snapshot_dir",
+        type=str,
+        default="snapshots",
+        help="directory under --output_dir where snapshots are stored",
+    )
+    parser.add_argument(
+        "--snapshot_epoch",
+        type=int,
+        default=10,
+        help="frequency in epochs at which snapshots are stored",
+    )
+
+    # Logging Args
+    add_logging_args(parser, defaults={"--logfile": "train.log"})
+
+    args = parser.parse_args()
+
+    if not os.path.exists(args.output_dir):
+        os.mkdir(args.output_dir)
+
+    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
+    if not os.path.exists(args.snapshot_dir):
+        os.mkdir(args.snapshot_dir)
+
+    args.logfile = os.path.join(args.output_dir, args.logfile)
+
+    device = torch.device(args.device)
+    logger = get_logger_by_args(args)
+
+    model = UNet(in_channels=1, features=args.features)
+    dataset = MuMapMockDataset(logger=logger)
+    data_loader_train = torch.utils.data.DataLoader(
+        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
+    )
+    data_loader_val = torch.utils.data.DataLoader(
+        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
+    )
     data_loaders = {"train": data_loader_train, "validation": data_loader_val}

-    training = Training(model, data_loaders, 10, logger)
+    training = Training(
+        model=model,
+        data_loaders=data_loaders,
+        epochs=args.epochs,
+        device=device,
+        lr=args.lr,
+        lr_decay_factor=args.lr_decay_factor,
+        lr_decay_epoch=args.lr_decay_epoch,
+        snapshot_dir=args.snapshot_dir,
+        snapshot_epoch=args.snapshot_epoch,
+        logger=logger,
+    )
     training.run()
-
-
--
GitLab
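
For context, the patched script is meant to be invoked directly, e.g. `python -m mu_map.training.default --epochs 100 --lr 0.01` (these flag values are only illustrative; the defaults are listed in the argparse setup above). A snapshot written by `store_snapshot` during such a run can later be restored roughly as in the sketch below. This is an assumption-laden sketch: the hunks above do not show the body of `store_snapshot`, so it is assumed here that it saves `model.state_dict()` with `torch.save`, and the paths simply mirror the `--output_dir`/`--snapshot_dir` defaults.

    import os

    import torch

    from mu_map.models.unet import UNet

    # Illustrative locations: --output_dir defaults to "train_data" and
    # --snapshot_dir to "snapshots" in the argparse setup above.
    snapshot_dir = os.path.join("train_data", "snapshots")
    # Snapshot files are zero-padded epoch numbers, so the lexicographically
    # last name is the latest snapshot, e.g. "10.pth".
    snapshot_file = sorted(os.listdir(snapshot_dir))[-1]

    # Rebuild the model with the same --features used during training.
    model = UNet(in_channels=1, features=[8, 16])

    # Assumption: store_snapshot saved model.state_dict() via torch.save.
    state_dict = torch.load(os.path.join(snapshot_dir, snapshot_file), map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()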