import os
from typing import Dict

import torch

from mu_map.logging import get_logger
from mu_map.training.loss import GradientDifferenceLoss


class Training:
    def __init__(
        self,
        model: torch.nn.Module,
        data_loaders: Dict[str, torch.utils.data.DataLoader],
        epochs: int,
        device: torch.device,
        lr: float,
        lr_decay_factor: float,
        lr_decay_epoch: int,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger=None,
    ):
        self.model = model
        self.data_loaders = data_loaders
        self.epochs = epochs
        self.device = device
        self.lr = lr
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_epoch = lr_decay_epoch
        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch
        self.logger = logger if logger is not None else get_logger()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )

        # Alternative single-term losses, kept for reference:
        # self.loss_func = torch.nn.MSELoss(reduction="mean")
        # self.loss_func = torch.nn.L1Loss(reduction="mean")
        _loss1 = torch.nn.MSELoss()
        _loss2 = GradientDifferenceLoss()

        def _loss_func(outputs, targets):
            # Combined loss: voxel-wise MSE plus a gradient difference term
            # that penalizes blurry predictions.
            return _loss1(outputs, targets) + _loss2(outputs, targets)

        self.loss_func = _loss_func

    def run(self):
        for epoch in range(1, self.epochs + 1):
            self.logger.debug(
                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
            )
            self._run_epoch(self.data_loaders["train"], phase="train")

            # Re-run the training split in eval mode so that the reported
            # training loss is computed with the weights frozen after the
            # update step.
            loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
            )

            loss_validation = self._run_epoch(
                self.data_loaders["validation"], phase="val"
            )
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
            )

            # ToDo: log outputs and time

            _previous = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            self.logger.debug(
                f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
            )

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(
                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
            )

    def _run_epoch(self, data_loader, phase):
        self.logger.debug(f"Run epoch in phase {phase}")
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()

        epoch_loss = 0
        loss_updates = 0
        for i, (inputs, labels) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            # Gradients are only tracked during the training phase.
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(inputs)
                loss = self.loss_func(outputs, labels)

                if phase == "train":
                    loss.backward()
                    self.optimizer.step()

            epoch_loss += loss.item()
            loss_updates += 1
        return epoch_loss / loss_updates

    def store_snapshot(self, epoch):
        # Zero-pad the epoch number so snapshot files sort chronologically.
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(self.model.state_dict(), snapshot_file)
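# Minimal usage sketch for the Training class (illustrative only; the
# __main__ block below is the actual entry point, and `dl_train`/`dl_val`
# are placeholder DataLoaders built elsewhere):
#
#     training = Training(
#         model=model,
#         data_loaders={"train": dl_train, "validation": dl_val},
#         epochs=100,
#         device=torch.device("cpu"),
#         lr=1e-3,
#         lr_decay_factor=0.99,
#         lr_decay_epoch=1,
#         snapshot_dir="snapshots",
#         snapshot_epoch=10,
#     )
#     training.run()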
scatter images", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) # Model Args parser.add_argument( "--features", type=int, nargs="+", default=[64, 128, 256, 512], help="number of features in the layers of the UNet structure", ) # Dataset Args parser.add_argument( "--dataset_dir", type=str, default="data/initial/", help="the directory where the dataset for training is found", ) parser.add_argument( "--output_scale", type=float, default=1.0, help="scale the attenuation map by this coefficient", ) parser.add_argument( "--input_norm", type=str, choices=["none", "mean", "max", "gaussian"], default="mean", help="type of normalization applied to the reconstructions", ) parser.add_argument( "--patch_size", type=int, default=32, help="the size of patches extracted for each reconstruction", ) parser.add_argument( "--patch_offset", type=int, default=20, help="offset to ignore the border of the image", ) parser.add_argument( "--number_of_patches", type=int, default=100, help="number of patches extracted for each image", ) parser.add_argument( "--no_shuffle", action="store_true", help="do not shuffle patches in the dataset", ) # Training Args parser.add_argument( "--seed", type=int, help="seed used for random number generation", ) parser.add_argument( "--batch_size", type=int, default=64, help="the batch size used for training", ) parser.add_argument( "--output_dir", type=str, default="train_data", help="directory in which results (snapshots and logs) of this training are saved", ) parser.add_argument( "--epochs", type=int, default=100, help="the number of epochs for which the model is trained", ) parser.add_argument( "--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu", help="the device (cpu or gpu) with which the training is performed", ) parser.add_argument( "--lr", type=float, default=0.001, help="the initial learning rate for training" ) parser.add_argument( "--lr_decay_factor", type=float, default=0.99, help="decay factor for the learning rate", ) parser.add_argument( "--lr_decay_epoch", type=int, default=1, help="frequency in epochs at which the learning rate is decayed", ) parser.add_argument( "--snapshot_dir", type=str, default="snapshots", help="directory under --output_dir where snapshots are stored", ) parser.add_argument( "--snapshot_epoch", type=int, default=10, help="frequency in epochs at which snapshots are stored", ) # Logging Args add_logging_args(parser, defaults={"--logfile": "train.log"}) args = parser.parse_args() if not os.path.exists(args.output_dir): os.mkdir(args.output_dir) args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir) if not os.path.exists(args.snapshot_dir): os.mkdir(args.snapshot_dir) else: if len(os.listdir(args.snapshot_dir)) > 0: print( f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!" 
) print(f" Exit so that data is not accidentally overwritten!") exit(1) args.logfile = os.path.join(args.output_dir, args.logfile) device = torch.device(args.device) logger = get_logger_by_args(args) logger.info(args) args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) logger.info(f"Seed: {args.seed}") random.seed(args.seed) torch.manual_seed(args.seed) np.random.seed(args.seed) model = UNet(in_channels=1, features=args.features) model = model.to(device) transform_normalization = None if args.input_norm == "mean": transform_normalization = MeanNormTransform() elif args.input_norm == "max": transform_normalization = MaxNormTransform() elif args.input_norm == "gaussian": transform_normalization = GaussianNormTransform() transform_augmentation = ScaleTransform(scale_outputs=args.output_scale) data_loaders = {} for split in ["train", "validation"]: dataset = MuMapPatchDataset( args.dataset_dir, patches_per_image=args.number_of_patches, patch_size=args.patch_size, patch_offset=args.patch_offset, shuffle=not args.no_shuffle, split_name=split, transform_normalization=transform_normalization, transform_augmentation=transform_augmentation, logger=logger, ) data_loader = torch.utils.data.DataLoader( dataset=dataset, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=1, ) data_loaders[split] = data_loader training = Training( model=model, data_loaders=data_loaders, epochs=args.epochs, device=device, lr=args.lr, lr_decay_factor=args.lr_decay_factor, lr_decay_epoch=args.lr_decay_epoch, snapshot_dir=args.snapshot_dir, snapshot_epoch=args.snapshot_epoch, logger=logger, ) training.run()