Commit b36a8455 authored by Tamino Huxohl

make default training configurable via argparse

parent d66fd526
import os
from typing import Dict

import torch

from mu_map.logging import get_logger


class Training:
    def __init__(
        self,
        model: torch.nn.Module,
        data_loaders: Dict[str, torch.utils.data.DataLoader],
        epochs: int,
        device: torch.device,
        lr: float,
        lr_decay_factor: float,
        lr_decay_epoch: int,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger=None,
    ):
        self.model = model
        self.data_loaders = data_loaders
        self.epochs = epochs
        self.device = device
        self.lr = lr
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_epoch = lr_decay_epoch
        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch
        self.logger = logger if logger is not None else get_logger()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )
        self.loss_func = torch.nn.MSELoss(reduction="mean")
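StepLR multiplies the learning rate by gamma once every step_size epochs, so with the argparse defaults introduced below (lr 0.1, decay factor 0.99, decay every epoch) the schedule decays geometrically. A quick arithmetic sketch of that schedule (illustration only, not part of the commit):

    # lr after e epochs under StepLR: lr * gamma ** (e // step_size)
    lr, gamma, step_size = 0.1, 0.99, 1
    for e in range(4):
        print(f"epoch {e}: lr = {lr * gamma ** (e // step_size):.6f}")
    # epoch 0: 0.100000, epoch 1: 0.099000, epoch 2: 0.098010, ...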
    def run(self):
        for epoch in range(1, self.epochs + 1):
            self.logger.debug(
                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
            )
            self._run_epoch(self.data_loaders["train"], phase="train")

            loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
            )
            loss_validation = self._run_epoch(
                self.data_loaders["validation"], phase="val"
            )
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
            )

            # ToDo: log outputs and time
            _previous = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            self.logger.debug(
                f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
            )

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(
                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
            )

    def _run_epoch(self, data_loader, phase):
        self.logger.debug(f"Run epoch in phase {phase}")
        # Switch between training and evaluation mode (affects layers such as
        # dropout and batch norm).
        self.model.train() if phase == "train" else self.model.eval()

        epoch_loss = 0
        loss_updates = 0
        for i, (inputs, labels) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
@@ -66,9 +96,9 @@ class Training():
            loss.backward()
            self.optimizer.step()

            epoch_loss += loss.item()
            loss_updates += 1
        return epoch_loss / loss_updates
    def store_snapshot(self, epoch):
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
@@ -78,20 +108,113 @@ class Training():
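The body of store_snapshot is collapsed in this diff. A minimal sketch of what such a method typically does, assuming the file is placed under snapshot_dir and the model's state dict is saved with torch.save (a hypothetical completion, not the author's actual code):

    # Hypothetical completion of store_snapshot; the real body is collapsed above.
    def store_snapshot(self, epoch):
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
        snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
        torch.save(self.model.state_dict(), snapshot_file)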
if __name__ == "__main__":
    import argparse

    from mu_map.dataset.mock import MuMapMockDataset
    from mu_map.logging import add_logging_args, get_logger_by_args
    from mu_map.models.unet import UNet

    parser = argparse.ArgumentParser(
        description="Train a UNet model to predict μ-maps from reconstructed scatter images",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Model Args
    parser.add_argument(
        "--features",
        type=int,
        nargs="+",
        default=[8, 16],
        help="number of features in the layers of the UNet structure",
    )

    # Dataset Args

    # Training Args
    parser.add_argument(
        "--output_dir",
        type=str,
        default="train_data",
        help="directory in which results (snapshots and logs) of this training are saved",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="the number of epochs for which the model is trained",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="the device (cpu or gpu) with which the training is performed",
    )
    parser.add_argument(
        "--lr", type=float, default=0.1, help="the initial learning rate for training"
    )
    parser.add_argument(
        "--lr_decay_factor",
        type=float,
        default=0.99,
        help="decay factor for the learning rate",
    )
    parser.add_argument(
        "--lr_decay_epoch",
        type=int,
        default=1,
        help="frequency in epochs at which the learning rate is decayed",
    )
    parser.add_argument(
        "--snapshot_dir",
        type=str,
        default="snapshots",
        help="directory under --output_dir where snapshots are stored",
    )
    parser.add_argument(
        "--snapshot_epoch",
        type=int,
        default=10,
        help="frequency in epochs at which snapshots are stored",
    )

    # Logging Args
    add_logging_args(parser, defaults={"--logfile": "train.log"})

    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.mkdir(args.output_dir)

    args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
    if not os.path.exists(args.snapshot_dir):
        os.mkdir(args.snapshot_dir)

    args.logfile = os.path.join(args.output_dir, args.logfile)

    device = torch.device(args.device)
    logger = get_logger_by_args(args)

    model = UNet(in_channels=1, features=args.features)
    dataset = MuMapMockDataset(logger=logger)

    data_loader_train = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
    )
    data_loader_val = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
    )
    data_loaders = {"train": data_loader_train, "validation": data_loader_val}

    training = Training(
        model=model,
        data_loaders=data_loaders,
        epochs=args.epochs,
        device=device,
        lr=args.lr,
        lr_decay_factor=args.lr_decay_factor,
        lr_decay_epoch=args.lr_decay_epoch,
        snapshot_dir=args.snapshot_dir,
        snapshot_epoch=args.snapshot_epoch,
        logger=logger,
    )
    training.run()
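With these arguments in place, a typical invocation might look like the following (the script's filename is an assumption; it is not visible on this page):

    python train.py --output_dir train_data --features 8 16 32 --epochs 100 --lr 0.01 --lr_decay_factor 0.99 --snapshot_epoch 10

Given the path handling above, this would write snapshots to train_data/snapshots and the log to train_data/train.log.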