import os
import sys
from dataclasses import dataclass
from logging import Logger
from typing import Optional

import torch

from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger  # assumed path: get_logger is used below but its import is missing
    
@dataclass
class TrainingParams:
    """Bundle a named model with its optimizer and an optional LR scheduler."""

    name: str
    model: torch.nn.Module
    optimizer: torch.optim.Optimizer
    lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler]
    
    
    
class AbstractTraining:  # name assumed: the class header is missing from the snippet
    """Generic training loop over one or more models registered as TrainingParams.

    Subclasses implement _train_batch and _eval_batch.
    """

    def __init__(
        self,
        epochs: int,
        dataset: MuMapDataset,
        batch_size: int,
        device: torch.device,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger: Optional[Logger] = None,
    ):
        self.epochs = epochs
        self.batch_size = batch_size
        self.dataset = dataset
        self.device = device

        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch
    
        self.logger = (
            logger if logger is not None else get_logger(name=self.__class__.__name__)
        )
    
    
        self.training_params = []  # filled by subclasses with TrainingParams entries
    
        self.data_loaders = {
            split_name: torch.utils.data.DataLoader(
                dataset.split_copy(split_name),
                batch_size=self.batch_size,
                shuffle=True,
                pin_memory=True,
                num_workers=1,
            )
            for split_name in ["train", "validation"]
        }
    
    
    def run(self) -> float:
        """Run the training for all epochs and return the minimal validation loss."""
        loss_val_min = sys.maxsize
        for epoch in range(1, self.epochs + 1):
            str_epoch = f"{str(epoch):>{len(str(self.epochs))}}"
            self.logger.debug(f"Run epoch {str_epoch}/{self.epochs} ...")

            loss_train = self._train_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss train: {loss_train:.6f}"
            )
            loss_val = self._eval_epoch()
            self.logger.info(
                f"Epoch {str_epoch}/{self.epochs} - Loss validation: {loss_val:.6f}"
            )

            if loss_val < loss_val_min:
                loss_val_min = loss_val
                self.logger.info(
                    f"Store snapshot val_min of epoch {str_epoch} with minimal validation loss"
                )
                self.store_snapshot("val_min")

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(f"{epoch:0{len(str(self.epochs))}d}")

            for param in self.training_params:
                if param.lr_scheduler is not None:
                    param.lr_scheduler.step()
        return loss_val_min
    
    
    def _after_train_batch(self):
        """
        Function called after the loss computation on a batch during training.
        It is responsible for stepping all optimizers.
        """
        for param in self.training_params:
            param.optimizer.step()
    
    
    def _train_epoch(self) -> float:
        """Train all registered models for one epoch and return the average loss."""
        torch.set_grad_enabled(True)
        for param in self.training_params:
            param.model.train()

        loss = 0.0
        data_loader = self.data_loaders["train"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            for param in self.training_params:
                param.optimizer.zero_grad()

            # Accumulate the batch loss and step all optimizers afterwards.
            loss = loss + self._train_batch(inputs, targets)
            self._after_train_batch()
        return loss / len(data_loader)
    
    
    def _eval_epoch(self) -> float:
        """Evaluate all registered models on the validation split and return the average loss."""
        torch.set_grad_enabled(False)
        for param in self.training_params:
            param.model.eval()

        loss = 0.0
        data_loader = self.data_loaders["validation"]
        for i, (inputs, targets) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )

            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            loss = loss + self._eval_batch(inputs, targets)
        return loss / len(data_loader)
    
    def store_snapshot(self, prefix: str):
        """Save the state dict of every registered model under snapshot_dir."""
        for param in self.training_params:
            snapshot_file = os.path.join(
                self.snapshot_dir, f"{prefix}_{param.name.lower()}.pth"
            )
            self.logger.debug(f"Store snapshot at {snapshot_file}")
            torch.save(param.model.state_dict(), snapshot_file)
    
    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """Compute the loss on a batch and backpropagate; overridden by subclasses."""
        return 0.0

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        """Compute and return the loss on a validation batch; overridden by subclasses."""
        return 0.0
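

# Usage sketch (hypothetical, not part of this module): a minimal concrete
# training with a single model and MSE loss. The model, the loss choice, the
# learning rate, and the name "ExampleTraining" are illustrative assumptions.
class ExampleTraining(AbstractTraining):
    def __init__(self, model: torch.nn.Module, **kwargs):
        super().__init__(**kwargs)
        self.model = model.to(self.device)
        self.criterion = torch.nn.MSELoss()
        # Register the model so the loop trains, snapshots, and schedules it.
        self.training_params.append(
            TrainingParams(
                name="model",
                model=self.model,
                optimizer=torch.optim.Adam(self.model.parameters(), lr=1e-3),
                lr_scheduler=None,
            )
        )

    def _train_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        # Gradients were zeroed by _train_epoch; _after_train_batch steps the optimizer.
        loss = self.criterion(self.model(inputs), targets)
        loss.backward()
        return loss.item()

    def _eval_batch(self, inputs: torch.Tensor, targets: torch.Tensor) -> float:
        return self.criterion(self.model(inputs), targets).item()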