    # distance.py: train a UNet to predict μ-maps from reconstructed scatter
    # images with a configurable distance loss (e.g. 0.75*l1+0.25*gdl).
    from logging import Logger
    from typing import Optional

    import torch

    # NOTE(review): module path of MuMapDataset assumed — confirm against mu_map.dataset
    from mu_map.dataset.default import MuMapDataset
    from mu_map.training.lib import TrainingParams, AbstractTraining
    from mu_map.training.loss import WeightedLoss
    
    class DistanceTraining(AbstractTraining):
        """
        Train a single model to predict mu-maps from reconstructions.

        The training loss is the configurable ``loss_func``. Evaluation always
        reports the plain L1 distance between prediction and target, regardless
        of ``loss_func``.
        """

        def __init__(
            self,
            epochs: int,
            dataset: MuMapDataset,
            batch_size: int,
            device: torch.device,
            snapshot_dir: str,
            snapshot_epoch: int,
            params: TrainingParams,
            loss_func: WeightedLoss,
            logger: Optional[Logger] = None,
        ):
            """
            :param epochs: number of epochs to train
            :param dataset: dataset the training draws its batches from
                (presumably (reconstruction, mu-map) pairs — confirm)
            :param batch_size: batch size used for training
            :param device: device on which training is performed
            :param snapshot_dir: directory where snapshots are stored
            :param snapshot_epoch: frequency in epochs at which snapshots are stored
            :param params: model, optimizer and optional lr scheduler of the trained model
            :param loss_func: loss applied to the model predictions during training
            :param logger: optional logger, forwarded to ``AbstractTraining``
            """
            super().__init__(
                epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
            )
            self.training_params.append(params)

            self.model = params.model
            # _train_batch reads self.loss_func, so the parameter must be stored
            self.loss_func = loss_func

        def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
            """
            Compute the training loss of one batch and back-propagate it.

            The optimizer step is not performed here
            (NOTE(review): presumably done by ``AbstractTraining`` — confirm).

            :param recons: batch of reconstructions fed to the model
            :param mu_maps: batch of target mu-maps
            :return: scalar loss value of this batch
            """
            mu_maps_predicted = self.model(recons)
            loss = self.loss_func(mu_maps_predicted, mu_maps)
            loss.backward()
            return loss.item()

        def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
            """
            Compute the L1 distance between prediction and target for one batch.

            :param recons: batch of reconstructions fed to the model
            :param mu_maps: batch of target mu-maps
            :return: mean absolute error of this batch
            """
            # evaluation needs no gradients; avoid building an autograd graph
            with torch.no_grad():
                mu_maps_predicted = self.model(recons)
                loss = torch.nn.functional.l1_loss(mu_maps_predicted, mu_maps)
            return loss.item()
    
    
    
    if __name__ == "__main__":
        import argparse
        import os
        import random
        import sys

        import numpy as np

        from mu_map.dataset.patches import MuMapPatchDataset
        from mu_map.dataset.normalization import (
            MeanNormTransform,
            MaxNormTransform,
            GaussianNormTransform,
        )
        from mu_map.logging import add_logging_args, get_logger_by_args
        from mu_map.models.unet import UNet

        parser = argparse.ArgumentParser(
            description="Train a UNet model to predict μ-maps from reconstructed scatter images",
            formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        )

        # Model Args
        parser.add_argument(
            "--features",
            type=int,
            nargs="+",
            default=[64, 128, 256, 512],
            help="number of features in the layers of the UNet structure",
        )

        # Dataset Args
        parser.add_argument(
            "--dataset_dir",
            type=str,
            default="data/second/",
            help="the directory where the dataset for training is found",
        )
        parser.add_argument(
            "--input_norm",
            type=str,
            choices=["none", "mean", "max", "gaussian"],
            default="mean",
            help="type of normalization applied to the reconstructions",
        )
        parser.add_argument(
            "--patch_size",
            type=int,
            default=32,
            help="the size of patches extracted for each reconstruction",
        )
        parser.add_argument(
            "--patch_offset",
            type=int,
            default=20,
            help="offset to ignore the border of the image",
        )
        parser.add_argument(
            "--number_of_patches",
            type=int,
            default=100,
            help="number of patches extracted for each image",
        )
        parser.add_argument(
            "--no_shuffle",
            action="store_true",
            help="do not shuffle patches in the dataset",
        )

        # Training Args
        parser.add_argument(
            "--seed",
            type=int,
            help="seed used for random number generation",
        )
        parser.add_argument(
            "--batch_size",
            type=int,
            # without a default the batch size would silently be None
            default=64,
            help="the batch size used for training",
        )
        parser.add_argument(
            "--output_dir",
            type=str,
            default="train_data",
            help="directory in which results (snapshots and logs) of this training are saved",
        )
        parser.add_argument(
            "--epochs",
            type=int,
            default=100,
            help="the number of epochs for which the model is trained",
        )
        parser.add_argument(
            "--device",
            type=str,
            default="cuda:0" if torch.cuda.is_available() else "cpu",
            help="the device (cpu or gpu) with which the training is performed",
        )
        parser.add_argument(
            "--loss_func",
            type=str,
            default="l1",
            help="define the loss function used for training, e.g. 0.75*l1+0.25*gdl",
        )
        parser.add_argument(
            "--decay_lr",
            action="store_true",
            help="decay the learning rate",
        )
        parser.add_argument(
            "--lr", type=float, default=0.001, help="the initial learning rate for training"
        )
        parser.add_argument(
            "--lr_decay_factor",
            type=float,
            default=0.99,
            help="decay factor for the learning rate",
        )
        parser.add_argument(
            "--lr_decay_epoch",
            type=int,
            default=1,
            help="frequency in epochs at which the learning rate is decayed",
        )
        parser.add_argument(
            "--snapshot_dir",
            type=str,
            default="snapshots",
            help="directory under --output_dir where snapshots are stored",
        )
        parser.add_argument(
            "--snapshot_epoch",
            type=int,
            default=10,
            help="frequency in epochs at which snapshots are stored",
        )

        # Logging Args
        add_logging_args(parser, defaults={"--logfile": "train.log"})

        args = parser.parse_args()

        os.makedirs(args.output_dir, exist_ok=True)

        # refuse to train into a non-empty snapshot directory so that snapshots
        # of a previous training are never overwritten
        args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
        if not os.path.exists(args.snapshot_dir):
            os.mkdir(args.snapshot_dir)
        elif len(os.listdir(args.snapshot_dir)) > 0:
            print(
                f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
            )
            print("           Exit so that data is not accidentally overwritten!")
            sys.exit(1)

        args.logfile = os.path.join(args.output_dir, args.logfile)

        device = torch.device(args.device)
        logger = get_logger_by_args(args)
        logger.info(args)

        # draw a seed if none is given and log it so the run can be reproduced
        args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
        logger.info(f"Seed: {args.seed}")
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)

        # "none" has no entry and therefore disables input normalization
        norm_transforms = {
            "mean": MeanNormTransform,
            "max": MaxNormTransform,
            "gaussian": GaussianNormTransform,
        }
        norm_cls = norm_transforms.get(args.input_norm)
        transform_normalization = norm_cls() if norm_cls is not None else None

        dataset = MuMapPatchDataset(
            args.dataset_dir,
            patches_per_image=args.number_of_patches,
            patch_size=args.patch_size,
            patch_offset=args.patch_offset,
            shuffle=not args.no_shuffle,
            transform_normalization=transform_normalization,
            logger=logger,
        )

        model = UNet(in_channels=1, features=args.features).to(device)

        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
        # StepLR's step_size is the decay period in epochs (an int); passing the
        # float decay factor here would effectively never decay the rate
        lr_scheduler = (
            torch.optim.lr_scheduler.StepLR(
                optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
            )
            if args.decay_lr
            else None
        )
        params = TrainingParams(
            name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
        )

        # NOTE(review): assumes WeightedLoss.from_str parses strings such as
        # "0.75*l1+0.25*gdl" (see --loss_func) — confirm against mu_map.training.loss
        loss_func = WeightedLoss.from_str(args.loss_func)

        training = DistanceTraining(
            epochs=args.epochs,
            dataset=dataset,
            batch_size=args.batch_size,
            device=device,
            snapshot_dir=args.snapshot_dir,
            snapshot_epoch=args.snapshot_epoch,
            params=params,
            loss_func=loss_func,
            logger=logger,
        )
        training.run()