# NOTE: removed "Newer"/"Older" web-viewer navigation residue (not Python code)
def __init__(self, model, data_loaders, epochs, logger):
self.model = model
self.data_loaders = data_loaders
self.device = torch.device("cpu")
self.snapshot_dir = "tmp"
self.snapshot_epoch = 5
self.loss_func = torch.nn.MSELoss()
# self.lr = 1e-3
# self.lr_decay_factor = 0.99
self.lr = 0.1
self.lr_decay_factor = 0.5
self.lr_decay_epoch = 1
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=self.lr_decay_epoch, gamma=self.lr_decay_factor)
self.logger = logger
def run(self):
for epoch in range(1, self.epochs + 1):
logger.debug(f"Run epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} ...")
self._run_epoch(self.data_loaders["train"], phase="train")
loss_training = self._run_epoch(self.data_loaders["train"], phase="eval")
logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss TRAIN: {loss_training:.4f}")
loss_validation = self._run_epoch(self.data_loaders["validation"], phase="eval")
logger.info(f"Epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs} - Loss VAL: {loss_validation:.4f}")
_previous = self.lr_scheduler.get_last_lr()[0]
logger.debug(f"Update learning rate from {_previous:.4f} to {self.lr_scheduler.get_last_lr()[0]:.4f}")
if epoch % self.snapshot_epoch == 0:
logger.debug(f"Finished epoch {str(epoch):>{len(str(self.epochs))}}/{self.epochs + 1}")
def _run_epoch(self, data_loader, phase):
logger.debug(f"Run epoch in phase {phase}")
self.model.train() if phase == "train" else self.model.eval()
epoch_loss = 0
for i, (inputs, labels) in enumerate(data_loader):
print(f"Batch {str(i):>{len(str(len(data_loader)))}}/{len(data_loader)}", end="\r")
inputs = inputs.to(self.device)
labels = labels.to(self.device)
self.optimizer.zero_grad()
with torch.set_grad_enabled(phase == "train"):
outputs = self.model(inputs)
loss = self.loss_func(outputs, labels)
epoch_loss += loss.item() / inputs.shape[0]
return epoch_loss
def store_snapshot(self, epoch):
snapshot_file = f"{epoch:0{len(str(self.epochs))}d}.pth"
snapshot_file = os.path.join(self.snapshot_dir, snapshot_file)
logger.debug(f"Store snapshot at {snapshot_file}")
torch.save(self.model.state_dict(), snapshot_file)
if __name__ == "__main__":
    from mu_map.data.mock import MuMapMockDataset
    from mu_map.logging import get_logger
    from mu_map.models.unet import UNet

    # Smoke test: train a small UNet on mock data for ten epochs.
    logger = get_logger(logfile="train.log", loglevel="DEBUG")

    model = UNet(in_channels=1, features=[8, 16])
    print(model)

    dataset = MuMapMockDataset()
    loader_kwargs = dict(
        dataset=dataset, batch_size=2, shuffle=True, pin_memory=True, num_workers=1
    )
    data_loaders = {
        "train": torch.utils.data.DataLoader(**loader_kwargs),
        "validation": torch.utils.data.DataLoader(**loader_kwargs),
    }

    training = Training(model, data_loaders, 10, logger)
    training.run()