Commit aadd524d authored by Tamino Huxohl

first implementation of default training

parent 5ad12c55
import os

import torch


class Training:
    def __init__(self, model, data_loaders, epochs, logger):
        self.model = model
        self.data_loaders = data_loaders
        self.epochs = epochs

        self.device = torch.device("cpu")
        self.snapshot_dir = "tmp"
        self.snapshot_epoch = 5
        # Make sure the snapshot directory exists so that store_snapshot does not fail.
        os.makedirs(self.snapshot_dir, exist_ok=True)

        self.loss_func = torch.nn.MSELoss()
        # self.lr = 1e-3
        # self.lr_decay_factor = 0.99
        self.lr = 0.1
        self.lr_decay_factor = 0.5
        self.lr_decay_epoch = 1

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        # Decay the learning rate by lr_decay_factor every lr_decay_epoch epochs.
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )

        self.logger = logger
    def run(self):
        for epoch in range(1, self.epochs + 1):
            self.logger.debug(f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ...")
            self._run_epoch(self.data_loaders["train"], phase="train")

            # Re-run both splits in eval mode to get comparable losses.
            loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}")
            loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
            self.logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}")

            # ToDo: log outputs and time

            _previous = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            self.logger.debug(f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}")

            # Store a snapshot every snapshot_epoch epochs.
            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}")
    def _run_epoch(self, data_loader, phase):
        self.logger.debug(f"Run epoch in phase {phase}")
        # Put the model into training or evaluation mode depending on the phase.
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()

        epoch_loss = 0
        for i, (inputs, labels) in enumerate(data_loader):
            print(f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r")
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()

            # Gradients are only computed during training.
            with torch.set_grad_enabled(phase == "train"):
                outputs = self.model(inputs)
                loss = self.loss_func(outputs, labels)

                if phase == "train":
                    loss.backward()
                    self.optimizer.step()

            # Accumulate the batch loss normalized by the batch size.
            epoch_loss += loss.item() / inputs.shape[0]
        return epoch_loss
    def store_snapshot(self, epoch):
        # Zero-pad the epoch number so that snapshot files sort correctly.
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(self.model.state_dict(), snapshot_file)
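
    # Sketch of a counterpart to store_snapshot (an assumption, not defined
    # elsewhere in this commit): rebuild the snapshot file name with the same
    # zero-padding scheme and load the stored weights back into the model.
    def load_snapshot(self, epoch):
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        self.logger.debug(f"Load snapshot from {snapshot_file}")
        state_dict = torch.load(snapshot_file, map_location=self.device)
        self.model.load_state_dict(state_dict)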
if __name__ == "__main__":
    from mu_map.data.mock import MuMapMockDataset
    from mu_map.logging import get_logger
    from mu_map.models.unet import UNet

    logger = get_logger(logfile="train.log", loglevel="DEBUG")

    model = UNet(in_channels=1, features=[8, 16])
    print(model)

    # The mock dataset serves as both training and validation split for now.
    dataset = MuMapMockDataset()
    data_loader_train = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
    data_loader_val = torch.utils.data.DataLoader(dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1)
    data_loaders = {"train": data_loader_train, "validation": data_loader_val}

    training = Training(model, data_loaders, 10, logger)
    training.run()
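
    # Illustrative check (an assumption, not part of the original script):
    # reload the snapshot stored after the final epoch into the trained model.
    model.load_state_dict(torch.load(os.path.join("tmp", "10.pth")))
    model.eval()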