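"""Training script for a UNet predicting attenuation maps (μ-maps) from
reconstructed scatter images.

Example invocation (hypothetical; the module path assumes this file lives
under mu_map/training/ as default.py):

    python -m mu_map.training.default --dataset_dir data/initial/ --epochs 100
"""
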
import os
from typing import Dict

import torch

from mu_map.logging import get_logger
from mu_map.training.loss import WeightedLoss


class Training:
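    """
    Training routine for a model: trains for a number of epochs on a
    training split, evaluates on a validation split, decays the learning
    rate on a fixed schedule and periodically stores snapshots.
    """
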
def __init__(
self,
model: torch.nn.Module,
data_loaders: Dict[str, torch.utils.data.DataLoader],
epochs: int,
device: torch.device,
loss_func: WeightedLoss,
lr: float,
lr_decay_factor: float,
lr_decay_epoch: int,
snapshot_dir: str,
snapshot_epoch: int,
logger=None,
):
self.model = model
self.data_loaders = data_loaders
self.epochs = epochs
self.device = device
self.lr = lr
self.lr_decay_factor = lr_decay_factor
self.lr_decay_epoch = lr_decay_epoch
self.snapshot_dir = snapshot_dir
self.snapshot_epoch = snapshot_epoch
self.logger = logger if logger is not None else get_logger()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor
)
self.loss_func = loss_func

    def run(self):
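        """
        Run the complete training: each epoch trains the model, evaluates
        the loss on the training and validation splits, decays the
        learning rate and stores a snapshot when due.
        """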
for epoch in range(1, self.epochs + 1):
            self.logger.debug(
f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ..."
)
self._run_epoch(self.data_loaders["train"], phase="train")
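            # re-run the training split in eval mode so that the reported
            # training loss is computed with fixed weights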
loss_training = self._run_epoch(self.data_loaders["train"], phase="val")
            self.logger.info(
f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss train: {loss_training:.6f}"
)
loss_validation = self._run_epoch(
self.data_loaders["validation"], phase="val"
)
            self.logger.info(
f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss validation: {loss_validation:.6f}"
)
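            # decay the learning rate according to the configured schedule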
_previous = self.lr_scheduler.get_last_lr()[0]
self.lr_scheduler.step()
            self.logger.debug(
f"Update learning rate from {_previous:.6f} to {self.lr_scheduler.get_last_lr()[0]:.6f}"
)
if epoch % self.snapshot_epoch == 0:
self.store_snapshot(epoch)
            self.logger.debug(
f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs}"
)

    def _run_epoch(self, data_loader, phase):
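        """
        Run a single epoch over a data loader.

        :param data_loader: data loader providing input and target batches
        :param phase: "train" (weights are updated) or "val" (gradients disabled)
        :return: the loss averaged over all batches
        """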
        self.logger.debug(f"Run epoch in phase {phase}")
        if phase == "train":
            self.model.train()
        else:
            self.model.eval()
epoch_loss = 0
loss_updates = 0
for i, (inputs, labels) in enumerate(data_loader):
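            # single-line progress indicator, overwritten via carriage return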
print(
f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}",
end="\r",
)
inputs = inputs.to(self.device)
labels = labels.to(self.device)
self.optimizer.zero_grad()
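            # gradients are only needed during training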
with torch.set_grad_enabled(phase == "train"):
outputs = self.model(inputs)
loss = self.loss_func(outputs, labels)
if phase == "train":
loss.backward()
self.optimizer.step()
epoch_loss += loss.item()
loss_updates += 1
return epoch_loss / loss_updates

    def store_snapshot(self, epoch):
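        """
        Store the model's state dict in the snapshot directory, named by
        the zero-padded epoch number.

        :param epoch: the epoch for which the snapshot is stored
        """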
snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
        self.logger.debug(f"Store snapshot at {snapshot_file}")
torch.save(self.model.state_dict(), snapshot_file)


if __name__ == "__main__":
    import argparse
    import random
    import sys

    import numpy as np

    from mu_map.dataset.patches import MuMapPatchDataset
    from mu_map.dataset.normalization import (
        MeanNormTransform,
        MaxNormTransform,
        GaussianNormTransform,
    )
    from mu_map.logging import add_logging_args, get_logger_by_args
    from mu_map.models.unet import UNet

    parser = argparse.ArgumentParser(
description="Train a UNet model to predict μ-maps from reconstructed scatter images",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

    # Model Args
parser.add_argument(
"--features",
type=int,
nargs="+",
default=[64, 128, 256, 512],
help="number of features in the layers of the UNet structure",
)

    # Dataset Args
parser.add_argument(
"--dataset_dir",
type=str,
default="data/initial/",
help="the directory where the dataset for training is found",
)
parser.add_argument(
"--input_norm",
type=str,
choices=["none", "mean", "max", "gaussian"],
default="mean",
help="type of normalization applied to the reconstructions",
)
parser.add_argument(
"--patch_size",
type=int,
default=32,
help="the size of patches extracted for each reconstruction",
)
parser.add_argument(
"--patch_offset",
type=int,
default=20,
help="offset to ignore the border of the image",
)
parser.add_argument(
"--number_of_patches",
type=int,
default=100,
help="number of patches extracted for each image",
)
parser.add_argument(
"--no_shuffle",
action="store_true",
help="do not shuffle patches in the dataset",
)

    # Training Args
parser.add_argument(
"--seed",
type=int,
help="seed used for random number generation",
)
parser.add_argument(
"--batch_size",
type=int,
default=64,
help="the batch size used for training",
)
parser.add_argument(
"--output_dir",
type=str,
default="train_data",
help="directory in which results (snapshots and logs) of this training are saved",
)
parser.add_argument(
"--epochs",
type=int,
default=100,
help="the number of epochs for which the model is trained",
)
parser.add_argument(
"--device",
type=str,
default="cuda:0" if torch.cuda.is_available() else "cpu",
help="the device (cpu or gpu) with which the training is performed",
)
parser.add_argument(
"--loss_func",
type=str,
default="l1",
help="define the loss function used for training, e.g. 0.75*l1+0.25*gdl",
)
parser.add_argument(
"--lr", type=float, default=0.001, help="the initial learning rate for training"
)
parser.add_argument(
"--lr_decay_factor",
type=float,
default=0.99,
help="decay factor for the learning rate",
)
parser.add_argument(
"--lr_decay_epoch",
type=int,
default=1,
help="frequency in epochs at which the learning rate is decayed",
)
parser.add_argument(
"--snapshot_dir",
type=str,
default="snapshots",
help="directory under --output_dir where snapshots are stored",
)
parser.add_argument(
"--snapshot_epoch",
type=int,
default=10,
help="frequency in epochs at which snapshots are stored",
)

    # Logging Args
add_logging_args(parser, defaults={"--logfile": "train.log"})
args = parser.parse_args()
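
    # create the output directory and refuse to run if the snapshot
    # directory already contains data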
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir)
else:
if len(os.listdir(args.snapshot_dir)) > 0:
            print(
                f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
            )
            print("           Exiting so that data is not accidentally overwritten!")
            sys.exit(1)
args.logfile = os.path.join(args.output_dir, args.logfile)
device = torch.device(args.device)
logger = get_logger_by_args(args)
logger.info(args)
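
    # seed all random number generators for reproducibility; if no seed is
    # given, a random one is drawn and logged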
args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {args.seed}")
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
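
    # build the UNet and move it to the training device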
model = UNet(in_channels=1, features=args.features)
model = model.to(device)
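
    # select the normalization applied to the input reconstructions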
transform_normalization = None
if args.input_norm == "mean":
transform_normalization = MeanNormTransform()
elif args.input_norm == "max":
transform_normalization = MaxNormTransform()
elif args.input_norm == "gaussian":
transform_normalization = GaussianNormTransform()
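
    # create patch datasets and data loaders for the training and validation splits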
data_loaders = {}
for split in ["train", "validation"]:
dataset = MuMapPatchDataset(
args.dataset_dir,
patches_per_image=args.number_of_patches,
patch_size=args.patch_size,
patch_offset=args.patch_offset,
shuffle=not args.no_shuffle,
split_name=split,
transform_normalization=transform_normalization,
logger=logger,
)
data_loader = torch.utils.data.DataLoader(
dataset=dataset,
batch_size=args.batch_size,
shuffle=True,
pin_memory=True,
num_workers=1,
)
data_loaders[split] = data_loader
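
    # parse the loss function description, e.g. 0.75*l1+0.25*gdl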
criterion = WeightedLoss.from_str(args.loss_func)
logger.debug(f"Criterion: {criterion}")
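
    # assemble the training routine and run it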
training = Training(
model=model,
data_loaders=data_loaders,
epochs=args.epochs,
device=device,
loss_func=criterion,
lr=args.lr,
lr_decay_factor=args.lr_decay_factor,
lr_decay_epoch=args.lr_decay_epoch,
snapshot_dir=args.snapshot_dir,
snapshot_epoch=args.snapshot_epoch,
logger=logger,
)
training.run()