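"""Training script for a UNet model that predicts attenuation maps (μ-maps)
from reconstructed scatter images by minimizing a voxel-wise distance loss."""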
import os
import torch
from mu_map.dataset.default import MuMapDataset
from mu_map.logging import get_logger
from mu_map.training.lib import TrainingParams, AbstractTraining
from mu_map.training.loss import WeightedLoss


class Training(AbstractTraining):
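    """Supervised training that fits a single model by minimizing a weighted
    distance (loss function) between predicted and ground-truth μ-maps."""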
def __init__(
self,
epochs: int,
dataset: MuMapDataset,
batch_size: int,
device: torch.device,
snapshot_dir: str,
snapshot_epoch: int,
params: TrainingParams,
loss_func: WeightedLoss,
logger,
):
        super().__init__(
            epochs, dataset, batch_size, device, snapshot_dir, snapshot_epoch, logger
        )
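        # register the parameters with the base class; the AbstractTraining
        # loop is assumed to handle optimizer steps and LR scheduling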
self.training_params.append(params)
self.loss_func = loss_func
self.model = params.model

    def _train_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
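        """Compute the configured loss on a batch and backpropagate; the
        optimizer step is assumed to be taken by the AbstractTraining loop."""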
mu_maps_predicted = self.model(recons)
loss = self.loss_func(mu_maps_predicted, mu_maps)
loss.backward()
return loss.item()

    def _eval_batch(self, recons: torch.Tensor, mu_maps: torch.Tensor) -> float:
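        """Evaluate a batch with the plain L1 distance, independent of the
        (possibly weighted) training criterion."""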
mu_maps_predicted = self.model(recons)
loss = torch.nn.functional.l1_loss(mu_maps_predicted, mu_maps)
return loss.item()


if __name__ == "__main__":
import argparse
import random
import sys
import numpy as np
from mu_map.dataset.patches import MuMapPatchDataset
from mu_map.dataset.normalization import (
MeanNormTransform,
MaxNormTransform,
GaussianNormTransform,
)
from mu_map.logging import add_logging_args, get_logger_by_args
from mu_map.models.unet import UNet
parser = argparse.ArgumentParser(
description="Train a UNet model to predict μ-maps from reconstructed scatter images",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
# Model Args
parser.add_argument(
"--features",
type=int,
nargs="+",
default=[64, 128, 256, 512],
help="number of features in the layers of the UNet structure",
)
# Dataset Args
parser.add_argument(
"--dataset_dir",
type=str,
default="data/initial/",
help="the directory where the dataset for training is found",
)
parser.add_argument(
"--input_norm",
type=str,
choices=["none", "mean", "max", "gaussian"],
default="mean",
help="type of normalization applied to the reconstructions",
)
parser.add_argument(
"--patch_size",
type=int,
default=32,
help="the size of patches extracted for each reconstruction",
)
parser.add_argument(
"--patch_offset",
type=int,
default=20,
help="offset to ignore the border of the image",
)
parser.add_argument(
"--number_of_patches",
type=int,
default=100,
help="number of patches extracted for each image",
)
parser.add_argument(
"--no_shuffle",
action="store_true",
help="do not shuffle patches in the dataset",
)
# Training Args
parser.add_argument(
"--seed",
type=int,
help="seed used for random number generation",
)
parser.add_argument(
"--batch_size",
type=int,
default=64,
help="the batch size used for training",
)
parser.add_argument(
"--output_dir",
type=str,
default="train_data",
help="directory in which results (snapshots and logs) of this training are saved",
)
parser.add_argument(
"--epochs",
type=int,
default=100,
help="the number of epochs for which the model is trained",
)
parser.add_argument(
"--device",
type=str,
default="cuda:0" if torch.cuda.is_available() else "cpu",
help="the device (cpu or gpu) with which the training is performed",
)
parser.add_argument(
"--loss_func",
type=str,
default="l1",
help="define the loss function used for training, e.g. 0.75*l1+0.25*gdl",
)
parser.add_argument(
"--decay_lr",
action="store_true",
help="decay the learning rate",
)
parser.add_argument(
"--lr", type=float, default=0.001, help="the initial learning rate for training"
)
parser.add_argument(
"--lr_decay_factor",
type=float,
default=0.99,
help="decay factor for the learning rate",
)
parser.add_argument(
"--lr_decay_epoch",
type=int,
default=1,
help="frequency in epochs at which the learning rate is decayed",
)
parser.add_argument(
"--snapshot_dir",
type=str,
default="snapshots",
help="directory under --output_dir where snapshots are stored",
)
parser.add_argument(
"--snapshot_epoch",
type=int,
default=10,
help="frequency in epochs at which snapshots are stored",
)
# Logging Args
add_logging_args(parser, defaults={"--logfile": "train.log"})
args = parser.parse_args()
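    # create the output directory, but refuse to run if the snapshot
    # directory already contains data that could be overwritten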
if not os.path.exists(args.output_dir):
os.mkdir(args.output_dir)
args.snapshot_dir = os.path.join(args.output_dir, args.snapshot_dir)
if not os.path.exists(args.snapshot_dir):
os.mkdir(args.snapshot_dir)
else:
if len(os.listdir(args.snapshot_dir)) > 0:
            print(
                f"ATTENTION: Snapshot directory [{args.snapshot_dir}] already exists and is not empty!"
            )
            print("           Exit so that data is not accidentally overwritten!")
            sys.exit(1)
args.logfile = os.path.join(args.output_dir, args.logfile)
device = torch.device(args.device)
logger = get_logger_by_args(args)
logger.info(args)
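    # seed Python, PyTorch and NumPy RNGs for reproducibility and log the
    # seed so that a run can be repeated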
args.seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
logger.info(f"Seed: {args.seed}")
random.seed(args.seed)
torch.manual_seed(args.seed)
np.random.seed(args.seed)
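    # select the normalization applied to the input reconstructions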
transform_normalization = None
if args.input_norm == "mean":
transform_normalization = MeanNormTransform()
elif args.input_norm == "max":
transform_normalization = MaxNormTransform()
elif args.input_norm == "gaussian":
transform_normalization = GaussianNormTransform()
dataset = MuMapPatchDataset(
args.dataset_dir,
patches_per_image=args.number_of_patches,
patch_size=args.patch_size,
patch_offset=args.patch_offset,
shuffle=not args.no_shuffle,
transform_normalization=transform_normalization,
logger=logger,
)
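    # single input channel: the reconstruction volume fed to the UNet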
model = UNet(in_channels=1, features=args.features).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.999))
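    # optionally decay the learning rate by --lr_decay_factor every
    # --lr_decay_epoch epochs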
lr_scheduler = (
torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=args.lr_decay_epoch, gamma=args.lr_decay_factor
)
if args.decay_lr
else None
)
    params = TrainingParams(
        name="Model", model=model, optimizer=optimizer, lr_scheduler=lr_scheduler
    )
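    # parse a loss description such as "0.75*l1+0.25*gdl" (see --loss_func)
    # into a weighted combination of loss terms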
criterion = WeightedLoss.from_str(args.loss_func)
training = Training(
epochs=args.epochs,
dataset=dataset,
batch_size=args.batch_size,
device=device,
snapshot_dir=args.snapshot_dir,
snapshot_epoch=args.snapshot_epoch,
params=params,
loss_func=criterion,
logger=logger,
)
training.run()
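
# Example invocation (illustrative; module path and values are assumptions):
#   python -m mu_map.training.distance \
#       --dataset_dir data/initial/ \
#       --input_norm mean \
#       --loss_func "0.75*l1+0.25*gdl" \
#       --output_dir train_data \
#       --seed 42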