Commit b36a8455 authored by Tamino Huxohl
make default training configurable via argparse

parent d66fd526
import os
from typing import Dict

import torch

from mu_map.logging import get_logger


class Training:
    def __init__(
        self,
        model: torch.nn.Module,
        data_loaders: Dict[str, torch.utils.data.DataLoader],
        epochs: int,
        device: torch.device,
        lr: float,
        lr_decay_factor: float,
        lr_decay_epoch: int,
        snapshot_dir: str,
        snapshot_epoch: int,
        logger=None,
    ):
        self.model = model
        self.data_loaders = data_loaders
        self.epochs = epochs
        self.device = device

        self.lr = lr
        self.lr_decay_factor = lr_decay_factor
        self.lr_decay_epoch = lr_decay_epoch

        self.snapshot_dir = snapshot_dir
        self.snapshot_epoch = snapshot_epoch

        self.logger = logger if logger is not None else get_logger()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
        )
        self.loss_func = torch.nn.MSELoss(reduction="mean")
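For context, StepLR multiplies the learning rate by gamma every step_size epochs, so with the argparse defaults below (lr=0.1, lr_decay_factor=0.99, lr_decay_epoch=1) the rate after n epochs is 0.1 * 0.99**n. A minimal standalone sketch of that schedule, with made-up values that are not part of this commit:

# Sketch only: illustrates the StepLR decay used by Training, with hypothetical values.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)

for epoch in range(3):
    # ... one epoch of training would run here ...
    scheduler.step()
    print(scheduler.get_last_lr())  # [0.099], then [0.09801], then [0.0970299]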
    def run(self):
        for epoch in range(1, self.epochs + 1):
            self.logger.debug(
                f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
            )
            self._run_epoch(self.data_loaders["train"], phase="train")

            loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}"
            )
            loss_validation = self._run_epoch(
                self.data_loaders["validation"], phase="val"
            )
            self.logger.info(
                f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}"
            )

            # ToDo: log outputs and time
            _previous = self.lr_scheduler.get_last_lr()[0]
            self.lr_scheduler.step()
            self.logger.debug(
                f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}"
            )

            if epoch % self.snapshot_epoch == 0:
                self.store_snapshot(epoch)

            self.logger.debug(
                f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
            )
    def _run_epoch(self, data_loader, phase):
        self.logger.debug(f"Run epoch in phase {phase}")
        self.model.train() if phase == "train" else self.model.eval()

        epoch_loss = 0
        loss_updates = 0
        for i, (inputs, labels) in enumerate(data_loader):
            print(
                f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
                end="\r",
            )
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)
@@ -66,9 +96,9 @@ class Training():
            loss.backward()
            self.optimizer.step()

            # with reduction="mean", loss.item() is already averaged over the batch,
            # so averaging over loss_updates yields the mean per-batch loss of the epoch
            epoch_loss += loss.item()
            loss_updates += 1
        return epoch_loss / loss_updates
    def store_snapshot(self, epoch):
        # zero-pad the epoch so snapshot files sort correctly, e.g. 010.pth when training for 100 epochs
        snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
@@ -78,20 +108,113 @@ class Training():
if __name__ == "__main__":
    import argparse

    from mu_map.dataset.mock import MuMapMockDataset
    from mu_map.logging import add_logging_args, get_logger_by_args
    from mu_map.models.unet import UNet

parser = argparse.ArgumentParser(
description="Train a UNet model to predict μ-maps from reconstructed scatter images",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Model Args
parser.add_argument(
"--features",
type=int,
nargs="+",
default=[8, 16],
help="number of features in the layers of the UNet structure",
)
# Dataset Args
# parser.add_argument("--features", type=int, nargs="+", default=[8, 16], help="number of features in the layers of the UNet structure")
# Training Args
parser.add_argument(
"--output_dir",
type=str,
default="train_data",
help="directory in which results (snapshots and logs) of this training are saved",
)
parser.add_argument(
"--epochs",
type=int,
default=10,
help="the number of epochs for which the model is trained",
)
parser.add_argument(
"--device",
type=str,
default="cuda:0" if torch.cuda.is_available() else "cpu",
help="the device (cpu or gpu) with which the training is performed",
)
parser.add_argument(
"--lr", type=float, default=0.1, help="the initial learning rate for training"
)
parser.add_argument(
"--lr_decay_factor",
type=float,
default=0.99,
help="decay factor for the learning rate",
)
parser.add_argument(
"--lr_decay_epoch",
type=int,
default=1,
help="frequency in epochs at which the learning rate is decayed",
)
parser.add_argument(
"--snapshot_dir",
type=str,
default="snapshots",
help="directory under --output_dir where snapshots are stored",
)
parser.add_argument(
"--snapshot_epoch",
type=int,
default=10,
help="frequency in epochs at which snapshots are stored",
)
# Logging Args
add_logging_args(parser, defaults={"--logfile": "train.log"})
args = parser.parse_args()
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir)
args.logfile = os.path.join(args.output_dir, args.logfile)
device = torch.device(args.device)
logger = get_logger_by_args(args)
model = UNet(in_channels=1, features=args.features)
dataset = MuMapMockDataset(logger=logger)
data_loader_train = torch.utils.data.DataLoader(
dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
)
data_loader_val = torch.utils.data.DataLoader(
dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
)
data_loaders = {"train": data_loader_train, "validation": data_loader_val}
training = Training(
model=model,
data_loaders=data_loaders,
epochs=args.epochs,
device=device,
lr=args.lr,
lr_decay_factor=args.lr_decay_factor,
lr_decay_epoch=args.lr_decay_epoch,
snapshot_dir=args.snapshot_dir,
snapshot_epoch=args.snapshot_epoch,
logger=logger,
)
training.run()
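For reference, a possible command line once these arguments are in place; the script name train.py is a placeholder (the file path is not part of this excerpt), while the flags are the ones defined by the parser above:

# Hypothetical invocation; the script name is an assumption, the flags come from the parser above.
python train.py --output_dir train_data --epochs 100 --lr 0.01 --lr_decay_factor 0.99 --features 16 32 64 --device cuda:0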