Commit aadd524d authored by Tamino Huxohl

first implementation of default training

parent 5ad12c55
import os

import torch


class Training:
    def __init__(self, model, data_loaders, epochs, logger):
        self.model = model
        self.data_loaders = data_loaders
        self.epochs = epochs

        self.device = torch.device("cpu")
        self.snapshot_dir = "tmp"
        self.snapshot_epoch = 5
        # snapshots are written into this directory, so make sure it exists
        os.makedirs(self.snapshot_dir, exist_ok=True)

        self.loss_func = torch.nn.MSELoss()
        # self.lr = 1e-3
        # self.lr_decay_factor = 0.99
        self.lr = 0.1
        self.lr_decay_factor = 0.5
        self.lr_decay_epoch = 1
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )
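        # note: with lr=0.1, step_size=1 and gamma=0.5, StepLR halves the
        # learning rate after every epoch: 0.1, 0.05, 0.025, ...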
        self.logger = logger

    def run(self):
        for epoch in range(1, self.epochs + 1):
            self.logger.debug(f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ...")
            self._run_epoch(self.data_loaders["train"], phase="train")

            # re-run both data loaders in eval mode to compute the epoch losses
            loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}")
            loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}")
            # ToDo: log outputs and time

            _previous = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            self.logger.debug(f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}")

            # store a snapshot every snapshot_epoch epochs
            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}")

    def _run_epoch(self, data_loader, phase):
        self.logger.debug(f"Run epoch in phase {phase}")
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()

        epoch_loss = 0
        for i, (inputs, labels) in enumerate(data_loader):
            print(f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r")
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            # gradients are only required in the training phase
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(inputs)
                loss = self.loss_func(outputs, labels)

                if phase == "train":
                    loss.backward()
                    self.optimizer.step()

            # accumulate the batch loss normalized by the batch size
            epoch_loss += loss.item() / inputs.shape[0]
        return epoch_loss

    def store_snapshot(self, epoch):
        # pad the epoch number so that snapshot files sort correctly
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(self.model.state_dict(), snapshot_file)
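
    def load_snapshot(self, snapshot_file):
        # Hypothetical counterpart to store_snapshot, sketched here for
        # illustration: restore model weights from a snapshot file so that
        # training or evaluation can continue from it (torch.load and
        # load_state_dict are standard PyTorch calls).
        self.model.load_state_dict(torch.load(snapshot_file, map_location=self.device))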
if __name__ == "__main__":
from mu_map.data.mock import MuMapMockDataset
from mu_map.logging import get_logger
from mu_map.models.unet import UNet
logger = get_logger(logfile="train.log", loglevel="DEBUG")
model = UNet(in_channels=1, features=[8, 16])
print(model)
dataset = MuMapMockDataset()
data_loader_train = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
data_loader_val = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
data_loaders = {"train": data_loader_train, "validation": data_loader_val}
training = Training(model, data_loaders, 10, logger)
training.run()
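    # A model trained this way could later be restored via the (hypothetical)
    # load_snapshot sketched above, e.g.:
    # training.load_snapshot(os.path.join("tmp", "10.pth"))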