From aadd524d309a6462085565f41a6bd479d0d366da Mon Sep 17 00:00:00 2001
From: Tamino Huxohl <thuxohl@techfak.uni-bielefeld.de>
Date: Mon, 26 Sep 2022 15:04:14 +0200
Subject: [PATCH] first implementation of default training

---
 mu_map/training/default.py | 73 ++++++++++++++++++++++++++++++++------
 1 file changed, 63 insertions(+), 10 deletions(-)

diff --git a/mu_map/training/default.py b/mu_map/training/default.py
index 1533098..d0d627b 100644
--- a/mu_map/training/default.py
+++ b/mu_map/training/default.py
@@ -1,44 +1,97 @@
+import os
+
+import torch
+
+
 class Training():
-    def __init__(self, epochs):
+    def __init__(self, model, data_loaders, epochs, logger):
+        self.model = model
+        self.data_loaders = data_loaders
         self.epochs = epochs
+        self.device = torch.device("cpu")
+        self.snapshot_dir = "tmp"
+        self.snapshot_epoch = 5
+        self.loss_func = torch.nn.MSELoss()
+
+        # self.lr = 1e-3
+        # self.lr_decay_factor = 0.99
+        self.lr = 0.1
+        self.lr_decay_factor = 0.5
+        self.lr_decay_epoch = 1
+
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
+        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor)
+
+        self.logger = logger
+
     def run(self):
         for epoch in range(1, self.epochs + 1):
-            self.run_epoch(self.data_loader["train"], phase="train")
-            loss_training = self.run_epoch(self.data_loader["train"], phase="eval")
-            loss_validation = self.run_epoch(self.data_loader["validation"], phase="eval")
+            self.logger.debug(f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ...")
+            self._run_epoch(self.data_loaders["train"], phase="train")
+
+            loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
+            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}")
+            loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
+            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}")
 
             # ToDo: log outputs and time
+            _previous = self.lr_scheduler.get_last_lr()[0]
             self.lr_scheduler.step()
+            self.logger.debug(f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}")
 
-            if epoch % self.snapshot_epoch:
+            if epoch % self.snapshot_epoch == 0:
                 self.store_snapshot(epoch)
 
+            self.logger.debug(f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}")
+
 
-    def run_epoch(self, data_loader, phase):
+    def _run_epoch(self, data_loader, phase):
+        self.logger.debug(f"Run epoch in phase {phase}")
         self.model.train() if phase == "train" else self.model.eval()
 
         epoch_loss = 0
-        for inputs, labels in self.data_loader:
+        for i, (inputs, labels) in enumerate(data_loader):
+            print(f"Batch {str(i + 1):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r")
             inputs = inputs.to(self.device)
             labels = labels.to(self.device)
 
             self.optimizer.zero_grad()
 
             with torch.set_grad_enabled(phase == "train"):
                 outputs = self.model(inputs)
-                loss = self.loss(outputs, labels)
+                loss = self.loss_func(outputs, labels)
 
                 if phase == "train":
                     loss.backward()
-                    optimizer.step()
+                    self.optimizer.step()
 
-            epoch_loss += loss.item() / inputs.size[0]
+            # MSELoss already averages over the batch, so average over
+            # batches to make the value independent of the dataset size
+            epoch_loss += loss.item() / len(data_loader)
         return epoch_loss
 
     def store_snapshot(self, epoch):
-        pass
+        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
+        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
+        self.logger.debug(f"Store snapshot at {snapshot_file}")
+        # create the snapshot directory if it does not exist yet
+        os.makedirs(self.snapshot_dir, exist_ok=True)
+        torch.save(self.model.state_dict(), snapshot_file)
+
+
+if __name__ == "__main__":
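+    # Quick check: train a small UNet on mock data; note that training
+    # and validation intentionally use the same dataset here.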
+    from mu_map.data.mock import MuMapMockDataset
+    from mu_map.logging import get_logger
+    from mu_map.models.unet import UNet
+
+    logger = get_logger(logfile="train.log", loglevel="DEBUG")
+
+    model = UNet(in_channels=1, features=[8, 16])
+    print(model)
+    dataset = MuMapMockDataset()
+    data_loader_train = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
+    data_loader_val = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
+    data_loaders = {"train": data_loader_train, "validation": data_loader_val}
+
+    training = Training(model, data_loaders, 10, logger)
+    training.run()
-- 
GitLab
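
Note: the schedule configured in the patch (Adam at lr=0.1, halved after every
epoch by StepLR) decays aggressively. Below is a minimal, self-contained sketch
of how the learning rate evolves under exactly these settings; torch.nn.Linear
merely stands in for the repository's UNet, so the model itself is a placeholder:

    import torch

    # stand-in model only; the patch trains mu_map.models.unet.UNet
    model = torch.nn.Linear(1, 1)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)

    for epoch in range(1, 5):
        optimizer.step()  # a real training epoch would run here
        print(f"epoch {epoch}: lr={scheduler.get_last_lr()[0]:.4f}")
        scheduler.step()
    # prints 0.1000, 0.0500, 0.0250, 0.0125; after the ten epochs used in
    # the __main__ block the learning rate is down to roughly 1e-4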