    import argparse
    import os
    from typing import Dict

    import torch

    from mu_map.logging import get_logger
    from mu_map.training.loss import GradientDifferenceLoss


    class Training:
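        """Training routine for a model: iterate over epochs, compute the loss
        on the training and validation split, decay the learning rate and
        periodically store snapshots of the model.
        """
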
        def __init__(
            self,
            model: torch.nn.Module,
            data_loaders: Dict[str, torch.utils.data.DataLoader],
            epochs: int,
            device: torch.device,
            lr: float,
            lr_decay_factor: float,
            lr_decay_epoch: int,
            snapshot_dir: str,
            snapshot_epoch: int,
            logger=None,
        ):
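            """Set up the optimizer, learning rate scheduler and loss function.

            data_loaders is expected to contain a loader for both the "train"
            and the "validation" split.
            """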
    
            self.model = model
            self.data_loaders = data_loaders
    
            self.epochs = epochs
    
            self.device = device
    
            self.lr = lr
            self.lr_decay_factor = lr_decay_factor
            self.lr_decay_epoch = lr_decay_epoch
    
            self.snapshot_dir = snapshot_dir
            self.snapshot_epoch = snapshot_epoch
    
            self.logger = logger if logger is not None else get_logger()
    
            # Adam optimizer with a learning rate that decays by lr_decay_factor
            # every lr_decay_epoch epochs
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
            self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
                self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
            )
    
            # alternative loss functions, kept for reference:
            # self.loss_func = torch.nn.MSELoss(reduction="mean")
            # self.loss_func = torch.nn.L1Loss(reduction="mean")

            # combined loss: voxel-wise MSE plus a gradient difference term
            _mse = torch.nn.MSELoss()
            _gdl = GradientDifferenceLoss()

            def _loss_func(outputs, targets):
                return _mse(outputs, targets) + _gdl(outputs, targets)

            self.loss_func = _loss_func
    
        def run(self):
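            """Run the training over all epochs."""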
            for epoch in range(1, self.epochs + 1):
    
                self.logger.debug(
                    f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
                )
    
                # train for one epoch
                self._run_epoch(self.data_loaders["train"], phase="train")

                # compute and log the loss on the training and validation split
                loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
                self.logger.info(
                    f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
                )
                loss_validation = self._run_epoch(
                    self.data_loaders["validation"], phase="val"
                )
                self.logger.info(
                    f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
                )
    
                # TODO: log outputs and time

                # decay the learning rate and log the change
                _previous = self.lr_scheduler.get_last_lr()[0]
                self.lr_scheduler.step()
                self.logger.debug(
                    f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
                )
    
                # periodically store a snapshot of the model weights
                if epoch % self.snapshot_epoch == 0:
                    self.store_snapshot(epoch)

                self.logger.debug(
                    f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
                )
    
        def _run_epoch(self, data_loader, phase):
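            """Run a single epoch over a data loader.

            In phase "train" gradients are computed and the model is updated,
            while in phase "val" only the loss is accumulated. Returns the
            average loss per batch.
            """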
            logger.debug(f"Run epoch in phase {phase}")
    
            # set the model into training or evaluation mode (affects, e.g.,
            # batch normalization layers)
            if phase == "train":
                self.model.train()
            else:
                self.model.eval()
    
            epoch_loss = 0.0
            loss_updates = 0
    
            for i, (inputs, labels) in enumerate(data_loader):
                # display a progress line for the current batch
                print(
                    f"Batch {str(i + 1):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                    end="\r",
                )
    
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)
    
                self.optimizer.zero_grad()
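                # gradients are only required in the training phase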
                with torch.set_grad_enabled(phase == "train"):
                    outputs = self.model(inputs)
    
                    loss = self.loss_func(outputs, labels)
    
                    if phase == "train":
                        loss.backward()
    
                        self.optimizer.step()
    
                epoch_loss += loss.item()
                loss_updates += 1
            return epoch_loss / loss_updates
    
        def store_snapshot(self, epoch):
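            """Store the model's current weights in the snapshot directory,
            named by the zero-padded epoch number."""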
    
            snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
            snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
            logger.debug(f"Store snapshot at {snapshot_file}")
            torch.save(self.model.state_dict(), snapshot_file)
    
    
    if __name__ == "__main__":
    
        import random
        import sys
    
        import numpy as np
    
        from mu_map.dataset.patches import MuMapPatchDataset
    
        from mu_map.dataset.normalization import (
            MeanNormTransform,
            MaxNormTransform,
            GaussianNormTransform,
        )
        from mu_map.dataset.transform import ScaleTransform
    
        from mu_map.logging import add_logging_args, get_logger_by_args
        from mu_map.models.unet import UNet
    
        parser = argparse.ArgumentParser(
            description="Train a UNet model to predict μ-maps from reconstructed scatter images",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )
    
        # Model Args
        parser.add_argument(
            "--features",
            type=int,
            nargs="+",
            default=[64, 128, 256, 512],
            help="number of features in the layers of the UNet structure",
        )
    
        # Dataset Args
    
        parser.add_argument(
            "--dataset_dir",
            type=str,
            default="data/initial/",
            help="the directory where the dataset for training is found",
        )
    
        parser.add_argument(
            "--output_scale",
            type=float,
            default=1.0,
            help="scale the attenuation map by this coefficient",
        )
        parser.add_argument(
            "--input_norm",
            type=str,
            choices=["none", "mean", "max", "gaussian"],
            default="mean",
            help="type of normalization applied to the reconstructions",
        )
    
        parser.add_argument(
            "--patch_size",
            type=int,
            default=32,
            help="the size of patches extracted for each reconstruction",
        )
        parser.add_argument(
            "--patch_offset",
            type=int,
            default=20,
            help="offset to ignore the border of the image",
        )
        parser.add_argument(
            "--number_of_patches",
            type=int,
            default=100,
            help="number of patches extracted for each image",
        )
        parser.add_argument(
            "--no_shuffle",
            action="store_true",
            help="do not shuffle patches in the dataset",
        )
        parser.add_argument(
            "--seed",
            type=int,
            help="seed used for random number generation",
        )
    
        parser.add_argument(
            "--batch_size",
            type=int,
            help="the batch size used for training",
        )
    
        parser.add_argument(
            "--output_dir",
            type=str,
            default="train_data",
            help="directory in which results (snapshots and logs) of this training are saved",
        )
        parser.add_argument(
            "--epochs",
            type=int,
            default=100,
            help="the number of epochs for which the model is trained",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda:0" if torch.cuda.is_available() else "cpu",
            help="the device (cpu or gpu) with which the training is performed",
        )
        parser.add_argument(
            "--lr", type=float, default=0.001, help="the initial learning rate for training"
        )
        parser.add_argument(
            "--lr_decay_factor",
            type=float,
            default=0.99,
            help="decay factor for the learning rate",
        )
        parser.add_argument(
            "--lr_decay_epoch",
            type=int,
            default=1,
            help="frequency in epochs at which the learning rate is decayed",
        )
        parser.add_argument(
            "--snapshot_dir",
            type=str,
            default="snapshots",
            help="directory under --output_dir where snapshots are stored",
        )
        parser.add_argument(
            "--snapshot_epoch",
            type=int,
            default=10,
            help="frequency in epochs at which snapshots are stored",
        )
    
        # Logging Args
        add_logging_args(parser, defaults={"--logfile": "train.log"})
    
        args = parser.parse_args()
    
        if not os.path.exists(args.output_dir):
            os.mkdir(args.output_dir)

        args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
        if not os.path.exists(args.snapshot_dir):
            os.mkdir(args.snapshot_dir)
        elif len(os.listdir(args.snapshot_dir)) > 0:
            print(
                f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
            )
            print("           Exiting so that data is not accidentally overwritten!")
            sys.exit(1)
    
        args.logfile = os.path.join(args.output_dir, args.logfile)
    
        device = torch.device(args.device)
        logger = get_logger_by_args(args)
    
        logger.info(args)
    
    
        # seed all random number generators; if no seed is given, draw one at
        # random and log it so that the run can be reproduced
        args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
        logger.info(f"Seed: {args.seed}")
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
    
    
        # create the model and move it to the training device
        model = UNet(in_channels=1, features=args.features)
        model = model.to(device)
    
        transform_normalization = None
        if args.input_norm == "mean":
            transform_normalization = MeanNormTransform()
        elif args.input_norm == "max":
            transform_normalization = MaxNormTransform()
        elif args.input_norm == "gaussian":
            transform_normalization = GaussianNormTransform()
    
        transform_augmentation = ScaleTransform(scale_outputs=args.output_scale)
    
    
        # create a data loader for the training and the validation split; note
        # that the dataset is constructed with identical arguments for both splits
        data_loaders = {}
        for split in ["train", "validation"]:
            dataset = MuMapPatchDataset(
                args.dataset_dir,
                patches_per_image=args.number_of_patches,
                patch_size=args.patch_size,
                patch_offset=args.patch_offset,
                shuffle=not args.no_shuffle,
                transform_normalization=transform_normalization,
                transform_augmentation=transform_augmentation,
                logger=logger,
            )
            data_loader = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=args.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=1,
            )
            data_loaders[split] = data_loader
    
        training = Training(
            model=model,
            data_loaders=data_loaders,
            epochs=args.epochs,
            device=device,
            lr=args.lr,
            lr_decay_factor=args.lr_decay_factor,
            lr_decay_epoch=args.lr_decay_epoch,
            snapshot_dir=args.snapshot_dir,
            snapshot_epoch=args.snapshot_epoch,
            logger=logger,
        )
    
        training.run()