import os

import torch


class Training:
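    """Basic training routine: iterate over epochs, log train and validation
    losses, decay the learning rate and periodically store model snapshots.
    """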
    
    
    def __init__(self, model, data_loaders, epochs, logger):
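        """
        :param model: the PyTorch model to be trained
        :param data_loaders: dict with a "train" and a "validation" DataLoader
        :param epochs: the number of epochs to train for
        :param logger: logger used to report progress
        """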
        self.model = model
        self.data_loaders = data_loaders
        self.epochs = epochs
    
            self.device = torch.device("cpu")
            self.snapshot_dir = "tmp"
            self.snapshot_epoch = 5
            self.loss_func = torch.nn.MSELoss()
    
        # learning rate and step-decay schedule
        # self.lr = 1e-3
        # self.lr_decay_factor = 0.99
        self.lr = 0.1
        self.lr_decay_factor = 0.5
        self.lr_decay_epoch = 1
    
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )
    
        self.logger = logger
    
    
    def run(self):
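        """Run the full training loop: train for an epoch, re-evaluate both
        datasets in eval mode, decay the learning rate and store snapshots
        every `snapshot_epoch` epochs.
        """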
        for epoch in range(1, self.epochs + 1):
            self.logger.debug(f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ...")
            self._run_epoch(self.data_loaders["train"], phase="train")
    
            # evaluate on both datasets in eval mode so the losses are comparable
            loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}")
            loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}")
    
            # ToDo: log outputs and time

            # decay the learning rate once per epoch
            _previous = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            self.logger.debug(f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}")
    
            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}")
    
    
    def _run_epoch(self, data_loader, phase):
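        """Iterate once over the given data loader and return the accumulated
        loss. Gradients are computed and applied only in the "train" phase.
        """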
            logger.debug(f"Run epoch in phase {phase}")
    
    Tamino Huxohl's avatar
    Tamino Huxohl committed
            self.model.train() if phase == "train" else self.model.eval()
    
            epoch_loss = 0
    
        for i, (inputs, labels) in enumerate(data_loader):
            print(f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r")
    
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
    
            self.optimizer.zero_grad()
            # gradients are only tracked during training
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(inputs)
                loss = self.loss_func(outputs, labels)
    
                if phase == "train":
                    loss.backward()
                    self.optimizer.step()

            # accumulate the mean batch loss, additionally divided by the batch size
            epoch_loss += loss.item() / inputs.shape[0]
    
        return epoch_loss
    
    
    def store_snapshot(self, epoch):
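        """Save the model weights to <snapshot_dir>/<zero-padded epoch>.pth."""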
    
        # make sure the snapshot directory exists before saving
        os.makedirs(self.snapshot_dir, exist_ok=True)

        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(self.model.state_dict(), snapshot_file)
    
    
    if __name__ == "__main__":
        from mu_map.data.mock import MuMapMockDataset
        from mu_map.logging import get_logger
        from mu_map.models.unet import UNet
    
    logger = get_logger(logfile="train.log", loglevel="DEBUG")
    
    model = UNet(in_channels=1, features=[8, 16])
    print(model)

    # the same mock dataset is reused for training and validation
    dataset = MuMapMockDataset()
    data_loader_train = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
    data_loader_val = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
    data_loaders = {"train": data_loader_train, "validation": data_loader_val}
    
    training = Training(model, data_loaders, 10, logger)
    training.run()