import os

import torch


class Training:
    """Minimal supervised training loop for a regression model.

    The model is optimized with Adam on ``torch.nn.MSELoss``. The learning rate
    is decayed with a ``StepLR`` scheduler after every epoch. After each training
    pass, both the training and the validation set are evaluated in eval mode and
    their mean losses are logged. Model snapshots (``state_dict``) are written to
    ``snapshot_dir`` every ``snapshot_epoch`` epochs.
    """

    def __init__(
        self,
        model,
        data_loaders,
        epochs,
        logger,
        *,
        device="cpu",
        snapshot_dir="tmp",
        snapshot_epoch=5,
        lr=0.1,
        lr_decay_factor=0.5,
        lr_decay_epoch=1,
    ):
        """Set up optimizer, scheduler and loss for training ``model``.

        Args:
            model: The ``torch.nn.Module`` to train. It is moved to ``device``.
            data_loaders: Dict with the keys ``"train"`` and ``"validation"``.
                Each value is an iterable of ``(inputs, labels)`` batches that
                also supports ``len()`` (e.g. a ``torch.utils.data.DataLoader``).
            epochs: Number of epochs to train.
            logger: Logger used for progress output (``debug`` and ``info``).
            device: Device the model and all batches are moved to.
            snapshot_dir: Directory for model snapshots; created when the first
                snapshot is stored.
            snapshot_epoch: Store a snapshot every ``snapshot_epoch`` epochs;
                a non-positive value disables snapshots.
            lr: Initial learning rate of the Adam optimizer.
            lr_decay_factor: Factor the learning rate is multiplied by every
                ``lr_decay_epoch`` epochs.
            lr_decay_epoch: Number of epochs between learning-rate decays.
        """
        self.model = model
        self.data_loaders = data_loaders
        self.epochs = epochs
        self.logger = logger

        self.device = torch.device(device)
        # Move the model before creating the optimizer so the optimizer is
        # built on the parameters as they live on the target device.
        self.model.to(self.device)

        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch

        self.loss_func = torch.nn.MSELoss()

        self.lr = lr
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_epoch = lr_decay_epoch
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )

    def _epoch_progress(self, epoch):
        """Format ``epoch`` as ``"<epoch>/<epochs>"``, right-aligned to the width of ``epochs``."""
        return f"{epoch:>{len(str(self.epochs))}}/{self.epochs}"

    def run(self):
        """Train for ``self.epochs`` epochs, logging losses and storing snapshots."""
        for epoch in range(1, self.epochs + 1):
            progress = self._epoch_progress(epoch)
            self.logger.debug(f"Run epoch {progress} ...")
            self._run_epoch(self.data_loaders["train"], phase="train")

            # Re-evaluate the training set in eval mode so the reported loss is
            # computed with this epoch's final weights (and without train-only
            # behaviour such as dropout, if the model has any), making it
            # comparable to the validation loss.
            loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
            self.logger.info(f"Epoch {progress} - Loss TRAIN: {loss_training:.4f}")

            loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
            self.logger.info(f"Epoch {progress} - Loss VAL: {loss_validation:.4f}")

            # TODO: log outputs and time
            previous_lr = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            self.logger.debug(
                f"Update learning rate from {previous_lr:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
            )

            if self.snapshot_epoch > 0 and epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(f"Finished epoch {progress}")

    def _run_epoch(self, data_loader, phase):
        """Run a single pass over ``data_loader``.

        Args:
            data_loader: Iterable of ``(inputs, labels)`` batches supporting ``len()``.
            phase: ``"train"`` back-propagates and updates the parameters; any
                other value (e.g. ``"eval"``) only evaluates the model.

        Returns:
            The sample-weighted mean loss over all batches, or ``0.0`` if the
            loader yields no samples.
        """
        self.logger.debug(f"Run epoch in phase {phase}")
        is_training = phase == "train"
        # train(False) is equivalent to eval().
        self.model.train(is_training)

        n_batches = len(data_loader)
        width = len(str(n_batches))
        total_loss = 0.0
        total_samples = 0
        for i, (inputs, labels) in enumerate(data_loader, start=1):
            print(f"Batch {i:>{width}}/{n_batches}", end="\r")
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            if is_training:
                self.optimizer.zero_grad()
            with torch.set_grad_enabled(is_training):
                outputs = self.model(inputs)
                loss = self.loss_func(outputs, labels)
                if is_training:
                    loss.backward()
                    self.optimizer.step()

            # MSELoss() uses reduction="mean", so `loss` is already averaged over
            # the batch. Weight it by the batch size so the epoch loss is a mean
            # over samples, independent of batch size and batch count.
            batch_size = inputs.shape[0]
            total_loss += loss.item() * batch_size
            total_samples += batch_size

        if total_samples == 0:
            return 0.0
        return total_loss / total_samples

    def store_snapshot(self, epoch):
        """Save the model's ``state_dict`` to ``<snapshot_dir>/<epoch>.pth``.

        The epoch number is zero-padded to the width of ``self.epochs`` so the
        files sort in chronological order.
        """
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        os.makedirs(self.snapshot_dir, exist_ok=True)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(self.model.state_dict(), snapshot_file)


if __name__ == "__main__":
    from mu_map.data.mock import MuMapMockDataset
    from mu_map.logging import get_logger
    from mu_map.models.unet import UNet

    logger = get_logger(logfile="train.log", loglevel="DEBUG")

    model = UNet(in_channels=1, features=[8, 16])
    print(model)

    dataset = MuMapMockDataset()
    data_loader_train = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
    )
    # Shuffling does not change the validation loss; keep evaluation deterministic.
    data_loader_val = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=2, shuffle=False, pin_memory=True, num_workers=1
    )
    data_loaders = {"train": data_loader_train, "validation": data_loader_val}

    training = Training(model, data_loaders, 10, logger)
    training.run()